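# f5tts/app/f5tts.py
#
# 287 lines, 7.7 KiB, Python (captured from a web "Raw / Normal View / History" page;
# the date-stamp lines that follow appear to be commit dates from that view)
#
# 2024-12-25 16:04:39 +08:00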
import os
2024-10-22 19:13:23 +08:00
import sys
import argparse
2024-10-21 18:29:57 +08:00
import codecs
import re
2024-10-22 19:13:23 +08:00
from pathlib import Path
2024-12-25 16:04:39 +08:00
from functools import partial
2024-10-22 19:13:23 +08:00
2024-10-21 18:29:57 +08:00
import numpy as np
import soundfile as sf
2024-12-19 16:50:53 +08:00
# import tomli
2024-10-21 18:29:57 +08:00
from cached_path import cached_path
2024-12-25 16:04:39 +08:00
import pycld2 as cld
import cn2an
2024-10-21 18:29:57 +08:00
2024-12-19 16:50:53 +08:00
from f5_tts.model import DiT, UNetT
2024-12-25 16:04:39 +08:00
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
infer_process,
2024-10-22 22:59:29 +08:00
load_model,
2024-12-25 16:04:39 +08:00
load_vocoder,
2024-10-22 22:59:29 +08:00
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
2024-10-22 19:13:23 +08:00
)
2024-10-21 18:29:57 +08:00
import json
2024-12-25 16:04:39 +08:00
from time import time, sleep
2024-10-21 18:29:57 +08:00
from appPublic.dictObject import DictObject
2024-10-22 22:59:29 +08:00
from appPublic.folderUtils import temp_file
2024-10-21 18:29:57 +08:00
from appPublic.jsonConfig import getConfig
2024-12-19 16:50:53 +08:00
from appPublic.worker import awaitify
2024-12-25 16:04:39 +08:00
from appPublic.uniqueID import getID
from appPublic.log import debug, info
from appPublic.background import Background
2024-12-19 16:50:53 +08:00
from ahserver.webapp import webapp
2024-12-25 16:04:39 +08:00
from ahserver.serverenv import ServerEnv
from ahserver.filestorage import FileStorage
2024-10-21 18:29:57 +08:00
# Module-level inference defaults. target_rms, nfe_step, cfg_strength,
# sway_sampling_coef and speed deliberately rebind the names imported from
# f5_tts.infer.utils_infer above, so the values below are the ones this
# module sees. `speed` is the fallback used when config.speed is unset.
n_mel_channels: int = 100
hop_length: int = 256
target_rms: float = 0.1
nfe_step: int = 32  # 16, 32
cfg_strength: float = 2.0
ode_method: str = "euler"
sway_sampling_coef: float = -1.0
speed: float = 1.0
def write_wav_buffer(wav, nchannels, framerate):
    """Write ``wav`` to a new uniquely named .wav file and return its web path.

    Args:
        wav: audio samples in any form accepted by ``soundfile.write``.
        nchannels: unused; kept for interface compatibility (soundfile derives
            the channel count from the shape of ``wav``).
        framerate: sample rate in Hz passed to ``soundfile.write``.

    Returns:
        The file location as produced by ``FileStorage.webpath``.
    """
    fs = FileStorage()
    fn = fs._name2path(f'{getID()}.wav', userid='tmp')
    # The target directory may already exist (e.g. from an earlier call);
    # without exist_ok=True os.makedirs raises FileExistsError in that case.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    debug(fn)
    # Let soundfile open the path itself (format inferred from ".wav")
    # instead of opening the same file twice.
    sf.write(fn, wav, framerate)
    return fs.webpath(fn)


# Awaitable wrapper used by the async inference paths.
async_write_wav_buffer = awaitify(write_wav_buffer)
def detect_language(txt):
    """Return the language code pycld2 ranks highest for ``txt``.

    The code is the second field of pycld2's top ``details`` entry (e.g.
    ``'zh'``, which split_text checks for). The reliability flag is only
    logged, not acted upon.
    """
    is_reliable, bytes_found, details = cld.detect(txt)
    debug(f' detect_language():isReliable={is_reliable!r}, '
          f'textBytesFound={bytes_found!r}, details={details!r} ')
    best_match = details[0]
    return best_match[1]
class F5TTS:
    """Text-to-speech engine (F5-TTS or E2-TTS) with named reference voices.

    Construction loads the vocoder and TTS model selected by the application
    config (``getConfig()``) and prepares the reference voices. The ``main``
    voice comes from ``config.ref_audio``/``config.ref_text``; additional
    speakers come from the JSON file ``config.speakers_file``.
    """

    def __init__(self):
        self.config = getConfig()
        self.load_model()
        self.setup_voices()

    def load_model(self):
        """Load ``self.vocoder`` and ``self.model`` according to the config.

        Reads ``config.vocoder_name``, ``config.vocoder_local_path``,
        ``config.modelname`` (``"F5-TTS"`` or ``"E2-TTS"``),
        ``config.vocab_file`` and ``config.device``. It also reads the
        optional ``config.ckpt_file``, a local checkpoint path that
        replaces the default download.

        Raises:
            ValueError: for an unsupported ``config.modelname``. Also raised
                when F5-TTS has no default checkpoint for
                ``config.vocoder_name`` and no ``ckpt_file`` is configured.
        """
        vocoder_name = self.config.vocoder_name
        modelname = self.config.modelname
        self.vocoder = load_vocoder(vocoder_name=vocoder_name,
                                    is_local=True,
                                    local_path=self.config.vocoder_local_path)
        # Optional local checkpoint (.pt | .safetensors). When unset, the
        # official weights are resolved through cached_path from SWivid's
        # Hugging Face repos below.
        ckpt_file = self.config.ckpt_file or ''
        if modelname == "F5-TTS":
            model_cls = DiT
            model_cfg = dict(dim=1024, depth=22, heads=16,
                             ff_mult=2, text_dim=512, conv_layers=4)
            if not ckpt_file:
                repo_name = "F5-TTS"
                if vocoder_name == "vocos":
                    exp_name = "F5TTS_Base"
                    ckpt_step = 1200000
                    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
                elif vocoder_name == "bigvgan":
                    exp_name = "F5TTS_Base_bigvgan"
                    ckpt_step = 1250000
                    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
                else:
                    raise ValueError(
                        f"F5-TTS: no default checkpoint for vocoder {vocoder_name!r}; "
                        f"set ckpt_file in the config")
        elif modelname == "E2-TTS":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            if not ckpt_file:
                repo_name = "E2-TTS"
                exp_name = "E2TTS_Base"
                ckpt_step = 1200000
                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        else:
            # Previously this fell through to a NameError on model_cls.
            raise ValueError(
                f"unsupported modelname {modelname!r}; expected 'F5-TTS' or 'E2-TTS'")
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                                mel_spec_type=vocoder_name,
                                vocab_file=self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    def f5tts_infer(self, ref_audio, ref_text, gen_text):
        """Synthesize ``gen_text`` in the voice of ``ref_audio``/``ref_text``.

        This call is synchronous (the async callers wrap it with ``awaitify``).
        The speed is ``config.speed`` when truthy, otherwise the module-level
        ``speed``.

        Returns:
            dict with keys ``'audio'``, ``'sample_rate'`` and
            ``'spectragram'`` (spelling kept for existing callers), taken
            from ``infer_process``'s return tuple.
        """
        audio, final_sample_rate, spectragram = \
            infer_process(ref_audio,
                          ref_text,
                          gen_text,
                          self.model,
                          self.vocoder,
                          mel_spec_type=self.config.vocoder_name,
                          speed=self.config.speed or speed)
        return {
            'audio': audio,
            'sample_rate': final_sample_rate,
            'spectragram': spectragram
        }

    def get_speakers(self):
        """Return ``[{'value': name, 'text': name}, ...]`` for every saved speaker plus ``main``."""
        t = [{'value': s, 'text': s} for s in self.speakers.keys()]
        t.append({'value': 'main', 'text': 'main'})
        return t

    async def split_text(self, text_gen, speaker):
        """Split ``text_gen`` into inference chunks, each bound to a voice.

        A ``[[name]]`` tag at the start of a chunk selects voice ``name``;
        untagged chunks use ``speaker``. Unknown voices fall back to
        ``"main"``. If ``config.language`` has an entry for the detected
        language, its ``sentence_splitter`` regex replaces the default
        tag splitter. Chinese text has Arabic numerals converted with
        ``cn2an.transform(..., 'an2cn')``. Blank chunks and chunks that
        contain only a tag are dropped.

        Returns:
            list of dicts with keys ``'text'``, ``'ref_audio'`` and ``'ref_text'``.
        """
        # Split right before every [[name]] tag so each chunk starts with at
        # most one tag. The old single-bracket lookahead split one character
        # too late ("[" + "[name]]..."), so tags were never recognised.
        reg1 = r"(?=\[\[\w+\]\])"
        lang = await awaitify(detect_language)(text_gen)
        lang_cfg = (self.config.language or {}).get(lang)
        if lang_cfg:
            reg1 = r"{}".format(lang_cfg.sentence_splitter)
        if lang == 'zh':
            text_gen = await awaitify(cn2an.transform)(text_gen, 'an2cn')
        chunks = re.split(reg1, text_gen)
        reg2 = r"\[\[(\w+)\]\]"
        ret = []
        for text in chunks:
            text = text.strip()
            # The old `text == ['\r', '']` compared a str with a list and was
            # always False, so empty pieces reached inference.
            if not text:
                continue
            voice = speaker
            match = re.match(reg2, text)
            if match:
                debug(f'{text=}, match {reg2=}')
                voice = match[1]
            if voice not in self.voices:
                voice = "main"
            debug(f'{text} inferences with speaker({voice})..{reg2=}')
            gen_text = re.sub(reg2, "", text).strip()
            if not gen_text:
                # Tag with no text after it: nothing to synthesize.
                continue
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
        return ret

    async def infer_stream(self, prompt, speaker):
        """Yield the web path of one .wav file per synthesized chunk of ``prompt``."""
        async for a in self._inference_stream(prompt, speaker):
            wavdata = a['audio']
            samplerate = a['sample_rate']
            b = await async_write_wav_buffer(wavdata, 1, samplerate)
            yield b

    async def _inference_stream(self, prompt, speaker):
        """Synthesize ``prompt`` chunk by chunk, yielding each result dict.

        Yields the dicts returned by ``f5tts_infer``. A chunk whose inference
        raises is logged and skipped, so the remaining chunks are still
        produced.
        """
        import traceback

        chunks = await self.split_text(prompt, speaker)
        infer = awaitify(self.f5tts_infer)
        for chunk in chunks:
            try:
                d = await infer(chunk['ref_audio'], chunk['ref_text'], chunk['text'])
            except Exception:
                # Best effort: skip the failing chunk, but leave a trace
                # instead of silently dropping audio.
                info(f'F5TTS: inference failed, chunk skipped: {chunk["text"]!r}\n'
                     f'{traceback.format_exc()}')
                continue
            # Yield outside the try so GeneratorExit/CancelledError raised at
            # the yield point (consumer closed or task cancelled) propagate.
            yield d

    def setup_voices(self):
        """Build ``self.speakers`` and ``self.voices`` from the config.

        ``self.speakers`` is the raw JSON mapping loaded from
        ``config.speakers_file`` (name -> {'ref_audio', 'ref_text'}). It
        starts empty when the file does not exist yet; add_voice writes it.
        ``self.voices`` maps each speaker, plus ``"main"`` built from
        ``config.ref_audio``/``config.ref_text``, to the output of
        ``preprocess_ref_audio_text``.
        """
        config = getConfig()
        try:
            with open(config.speakers_file, 'r', encoding='utf-8') as f:
                self.speakers = json.load(f)
        except FileNotFoundError:
            self.speakers = {}
        ref_audio, ref_text = preprocess_ref_audio_text(config.ref_audio, config.ref_text)
        self.voices = {
            "main": {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
        }
        for k, v in self.speakers.items():
            ref_audio, ref_text = preprocess_ref_audio_text(v['ref_audio'], v['ref_text'])
            self.voices[k] = {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }

    async def add_voice(self, speaker, ref_audio, ref_text):
        """Register (or replace) voice ``speaker`` and persist it to the speakers file.

        Args:
            speaker: voice name, usable as ``[[speaker]]`` in prompts.
            ref_audio: web path of the reference audio; it is resolved with
                ``FileStorage().realPath``.
            ref_text: transcript of the reference audio.

        Returns:
            None.
        """
        debug(f'{speaker=}, {ref_audio=}, {ref_text=}')
        config = getConfig()
        ref_audio = FileStorage().realPath(ref_audio)
        # Preprocess first: if it fails, speakers/voices stay untouched and
        # no unloadable entry is written to the speakers file.
        f = awaitify(preprocess_ref_audio_text)
        prep_audio, prep_text = await f(ref_audio, ref_text)
        self.speakers[speaker] = {
            'ref_text': ref_text,
            'ref_audio': ref_audio
        }
        self.voices[speaker] = {
            'ref_text': prep_text,
            'ref_audio': prep_audio
        }
        with open(config.speakers_file, 'w', encoding='utf-8') as fp:
            json.dump(self.speakers, fp, indent=4, ensure_ascii=False)
        return None

    async def _inference(self, prompt, speaker):
        """Synthesize the whole ``prompt`` into a single .wav file.

        When ``config.remove_silence`` is truthy, silence is stripped from the
        written file with ``remove_silence_for_generated_wav``.

        Returns:
            Web path of the written file, or None if no chunk produced audio.
        """
        generated_audio_segments = []
        final_sample_rate = 24000
        async for d in self._inference_stream(prompt, speaker):
            generated_audio_segments.append(d['audio'])
            final_sample_rate = d['sample_rate']
        if not generated_audio_segments:
            return None
        final_wave = np.concatenate(generated_audio_segments)
        webpath = await async_write_wav_buffer(final_wave, 1, final_sample_rate)
        if self.config.remove_silence:
            # remove_silence_for_generated_wav takes a filesystem path, not
            # the web path. NOTE(review): assumes FileStorage.realPath
            # inverts FileStorage.webpath; confirm.
            fn = FileStorage().realPath(webpath)
            await awaitify(remove_silence_for_generated_wav)(fn)
        return webpath
def UiError(title="出错", message="出错啦", timeout=5):
    """Return a widget description dict of type "Error" for the front end.

    Defaults are Chinese UI text: title "出错" means "Error" and message
    "出错啦" means roughly "Something went wrong". The option keys are
    consumed by the client-side widget; their exact semantics are not
    visible from this module.
    """
    options = dict(
        author="tr",
        timeout=timeout,
        cwidth=15,
        cheight=10,
        title=title,
        auto_open=True,
        auto_dismiss=True,
        auto_destroy=True,
        message=message,
    )
    return dict(widgettype="Error", options=options)
def UiMessage(title="消息", message="后台消息", timeout=5):
    """Return a widget description dict of type "Message" for the front end.

    Defaults are Chinese UI text: title "消息" means "Message" and message
    "后台消息" means "Backend message". The option keys are consumed by the
    client-side widget; their exact semantics are not visible from this
    module.
    """
    options = dict(
        author="tr",
        timeout=timeout,
        cwidth=15,
        cheight=10,
        title=title,
        auto_open=True,
        auto_dismiss=True,
        auto_destroy=True,
        message=message,
    )
    return dict(widgettype="Message", options=options)
def test1():
    """Block the calling thread for 36000 s (ten hours), then return ``{}``.

    init() exposes it, wrapped with ``awaitify``, as ``g.test1``. It
    presumably serves as a manual long-running-task probe; confirm before
    relying on it.
    """
    ten_hours = 10 * 60 * 60  # 36000 seconds
    sleep(ten_hours)
    return dict()
def init():
    """Create the F5TTS engine and publish its entry points on ServerEnv.

    Attributes set on the ServerEnv instance: ``infer_stream``,
    ``get_speakers``, ``infer`` (``F5TTS._inference``), ``test1`` (awaitified),
    ``add_voice``, ``UiError`` and ``UiMessage``.
    """
    env = ServerEnv()
    engine = F5TTS()
    exported = (
        ('infer_stream', engine.infer_stream),
        ('get_speakers', engine.get_speakers),
        ('infer', engine._inference),
        ('test1', awaitify(test1)),
        ('add_voice', engine.add_voice),
        ('UiError', UiError),
        ('UiMessage', UiMessage),
    )
    for attr_name, handler in exported:
        setattr(env, attr_name, handler)
if __name__ == '__main__':
    # Start the ahserver application. init() is handed to webapp(),
    # presumably to be called during startup to register the TTS handlers
    # on ServerEnv.
    webapp(init)