f5tts/app/f5tts.py
2024-12-25 16:04:39 +08:00

287 lines
7.7 KiB
Python

import os
import sys
import argparse
import codecs
import re
from pathlib import Path
from functools import partial
import numpy as np
import soundfile as sf
# import tomli
from cached_path import cached_path
import pycld2 as cld
import cn2an
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
)
import json
from time import time, sleep
from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from appPublic.uniqueID import getID
from appPublic.log import debug, info
from appPublic.background import Background
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from ahserver.filestorage import FileStorage
# Module-level synthesis defaults.  target_rms, nfe_step, cfg_strength,
# sway_sampling_coef and speed rebind names that were imported from
# f5_tts.infer.utils_infer above.  ``speed`` is the fallback that
# F5TTS.f5tts_infer uses when config.speed is unset.
n_mel_channels, hop_length = 100, 256
target_rms = 0.1
nfe_step = 32  # typical choices: 16 or 32
cfg_strength, ode_method = 2.0, "euler"
sway_sampling_coef, speed = -1.0, 1.0
def write_wav_buffer(wav, nchannels, framerate):
    """Write ``wav`` to a new temporary .wav file and return its web path.

    Args:
        wav: Audio samples, passed unchanged to ``soundfile.write``.
        nchannels: Unused; the channel count is not passed to
            ``soundfile.write``. Kept for call compatibility.
        framerate: Sample rate handed to ``soundfile.write``.

    Returns:
        The path produced by ``FileStorage.webpath`` for the written file.
    """
    fs = FileStorage()
    fn = fs._name2path(f'{getID()}.wav', userid='tmp')
    # The target directory may already exist (e.g. from an earlier clip);
    # without exist_ok, os.makedirs raises FileExistsError in that case.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    debug(fn)
    # soundfile opens the path itself; no separate file handle is needed.
    sf.write(fn, wav, framerate)
    return fs.webpath(fn)


# Awaitable wrapper so coroutine code can call the blocking writer.
async_write_wav_buffer = awaitify(write_wav_buffer)
def detect_language(txt):
    """Return the language code of pycld2's top-ranked guess for ``txt``.

    ``cld.detect`` returns ``(isReliable, textBytesFound, details)``. Per the
    pycld2 docs, each ``details`` entry is ``(name, code, percent, score)``.
    The code of the first entry is returned. The reliability flag is only
    logged, not acted on.
    """
    isReliable, textBytesFound, details = cld.detect(txt)
    debug(f' detect_language():{isReliable=}, {textBytesFound=}, {details=} ')
    best_guess = details[0]
    return best_guess[1]
class F5TTS:
    """Speech-synthesis service wrapping an F5-TTS or E2-TTS model.

    On construction it reads the application config (``getConfig()``) and
    loads the vocoder and model it names. It then prepares the reference
    voices: ``main`` from ``config.ref_audio``/``config.ref_text``, plus one
    voice per entry in ``config.speakers_file``.
    """

    def __init__(self):
        self.config = getConfig()
        self.load_model()
        self.setup_voices()

    def load_model(self):
        """Load ``self.vocoder`` and ``self.model`` as selected by the config.

        Reads ``vocoder_name``, ``vocoder_local_path``, ``modelname``,
        ``vocab_file`` and ``device`` from ``self.config``. Checkpoints are
        fetched through ``cached_path`` from the SWivid Hugging Face repos.

        Raises:
            ValueError: ``modelname`` is neither ``"F5-TTS"`` nor
                ``"E2-TTS"``, or F5-TTS is paired with a vocoder other than
                ``"vocos"``/``"bigvgan"`` (no checkpoint is known for it).
        """
        vocoder_name = self.config.vocoder_name
        self.vocoder = load_vocoder(vocoder_name=vocoder_name,
                is_local=True,
                local_path=self.config.vocoder_local_path)
        modelname = self.config.modelname
        if modelname == "F5-TTS":
            model_cls = DiT
            model_cfg = dict(dim=1024, depth=22, heads=16,
                    ff_mult=2, text_dim=512, conv_layers=4)
            repo_name = "F5-TTS"
            if vocoder_name == "vocos":
                exp_name, ckpt_step, ext = "F5TTS_Base", 1200000, "safetensors"
            elif vocoder_name == "bigvgan":
                exp_name, ckpt_step, ext = "F5TTS_Base_bigvgan", 1250000, "pt"
            else:
                # Previously fell through with ckpt_file == '' and failed
                # later inside load_model with an obscure error.
                raise ValueError(
                    f"F5-TTS has no checkpoint for vocoder {vocoder_name!r}")
        elif modelname == "E2-TTS":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            repo_name = "E2-TTS"
            exp_name, ckpt_step, ext = "E2TTS_Base", 1200000, "safetensors"
        else:
            # Previously a NameError on model_cls further down.
            raise ValueError(f"unsupported modelname {modelname!r}")
        # A local checkpoint also works, e.g. f"ckpts/{exp_name}/model_{ckpt_step}.pt".
        ckpt_file = str(cached_path(
            f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.{ext}"))
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                mel_spec_type=vocoder_name,
                vocab_file=self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    def f5tts_infer(self, ref_audio, ref_text, gen_text):
        """Synthesize ``gen_text`` in the voice of ``ref_audio``/``ref_text``.

        Returns a dict with keys ``audio``, ``sample_rate`` and
        ``spectragram``, as returned by ``infer_process``. Uses
        ``config.speed``, or the module-level ``speed`` when that is falsy.
        """
        audio, final_sample_rate, spectragram = \
            infer_process(ref_audio,
                ref_text,
                gen_text,
                self.model,
                self.vocoder,
                mel_spec_type=self.config.vocoder_name,
                speed=self.config.speed or speed)
        return {
            'audio': audio,
            'sample_rate': final_sample_rate,
            'spectragram': spectragram
        }

    def get_speakers(self):
        """Return ``{'value', 'text'}`` options for every speaker plus ``main``."""
        t = [{'value': s, 'text': s} for s in self.speakers.keys()]
        t.append({'value': 'main', 'text': 'main'})
        return t

    async def split_text(self, text_gen, speaker):
        """Split ``text_gen`` into synthesis jobs, one per non-empty chunk.

        The splitting regex is ``config.language[<lang>].sentence_splitter``
        when the detected language has an entry. Otherwise chunks start at
        each ``[[name]]`` speaker tag. Chinese (``zh``) text first goes through
        ``cn2an.transform(..., 'an2cn')``.

        A chunk beginning with ``[[name]]`` is voiced by ``name``; other chunks
        use ``speaker``. Voices missing from ``self.voices`` fall back to
        ``"main"``.

        Returns:
            list of ``{'text', 'ref_audio', 'ref_text'}`` dicts.
        """
        # Default splitter: break *before* each [[name]] tag, so every chunk
        # starts with a tag that reg2 below can recognise. The old
        # "(?=\[\w+\])" matched the inner "[name]" and split "[[name]]" in half.
        reg1 = r"(?=\[\[\w+\]\])"
        lang = await awaitify(detect_language)(text_gen)
        lang_cfg = self.config.language.get(lang)
        if lang_cfg:
            reg1 = lang_cfg.sentence_splitter
        if lang == 'zh':
            text_gen = await awaitify(cn2an.transform)(text_gen, 'an2cn')
        chunks = re.split(reg1, text_gen)
        # NOTE: the speaker-tag pattern is hard-coded (it was once read from
        # config.speaker_match).
        reg2 = r"\[\[(\w+)\]\]"
        ret = []
        for text in chunks:
            voice = speaker
            match = re.match(reg2, text)
            if match:
                debug(f'{text=}, match {reg2=}')
                voice = match[1]
            # Also covers an unknown/None ``speaker`` argument, which would
            # otherwise raise KeyError on the lookups below.
            if voice not in self.voices:
                voice = "main"
            gen_text = re.sub(reg2, "", text).strip()
            if not gen_text:
                # e.g. '' from a split at position 0, a stray '\r', or a bare
                # tag. The old `text == ['\r', '']` test compared a str to a
                # list and never skipped anything.
                continue
            debug(f'{text} inferences with speaker({voice})..{reg2=}')
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
        return ret

    async def infer_stream(self, prompt, speaker):
        """Yield the web path of one wav file per synthesized chunk."""
        async for a in self._inference_stream(prompt, speaker):
            wavdata = a['audio']
            samplerate = a['sample_rate']
            b = await async_write_wav_buffer(wavdata, 1, samplerate)
            yield b

    async def _inference_stream(self, prompt, speaker):
        """Yield one ``f5tts_infer`` result dict per chunk of ``prompt``.

        A chunk whose synthesis raises is logged and skipped, so one bad
        chunk does not abort the whole stream.
        """
        chunks = await self.split_text(prompt, speaker)
        infer = awaitify(self.f5tts_infer)
        for chunk in chunks:
            try:
                d = await infer(chunk['ref_audio'], chunk['ref_text'], chunk['text'])
            except Exception as e:
                info(f'_inference_stream(): chunk {chunk["text"]!r} failed: {e!r}')
                continue
            # Yield outside the try block. The previous bare `except: pass`
            # around the yield swallowed GeneratorExit/CancelledError when
            # the consumer stopped early.
            yield d

    def setup_voices(self):
        """Build ``self.speakers`` (raw JSON) and ``self.voices`` (preprocessed).

        ``self.voices`` maps ``"main"`` (from ``config.ref_audio``/``ref_text``)
        and each speaker in ``config.speakers_file`` to the
        ``{'ref_text', 'ref_audio'}`` pair returned by
        ``preprocess_ref_audio_text``.
        """
        config = getConfig()
        with codecs.open(config.speakers_file, 'r', 'utf-8') as f:
            self.speakers = json.loads(f.read())
        ref_audio, ref_text = preprocess_ref_audio_text(config.ref_audio, config.ref_text)
        self.voices = {
            "main": {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
        }
        for k, v in self.speakers.items():
            ref_audio, ref_text = preprocess_ref_audio_text(v['ref_audio'], v['ref_text'])
            self.voices[k] = {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }

    async def add_voice(self, speaker, ref_audio, ref_text):
        """Register (or replace) voice ``speaker`` and persist it.

        ``ref_audio`` is resolved with ``FileStorage().realPath``. The raw pair
        is stored in ``self.speakers`` and the preprocessed pair in
        ``self.voices``. ``self.speakers`` is then rewritten to
        ``config.speakers_file``.
        """
        debug(f'{speaker=}, {ref_audio=}, {ref_text=}')
        config = getConfig()
        ref_audio = FileStorage().realPath(ref_audio)
        raw_entry = {
            'ref_text': ref_text,
            'ref_audio': ref_audio
        }
        preprocess = awaitify(preprocess_ref_audio_text)
        # Preprocess before touching any state. A failure then leaves both
        # maps unchanged, instead of listing a speaker with no usable voice.
        proc_audio, proc_text = await preprocess(ref_audio, ref_text)
        self.speakers[speaker] = raw_entry
        self.voices[speaker] = {
            'ref_text': proc_text,
            'ref_audio': proc_audio
        }
        with codecs.open(config.speakers_file, 'w', 'utf-8') as f:
            f.write(json.dumps(self.speakers, indent=4))
        return None

    async def _inference(self, prompt, speaker):
        """Synthesize ``prompt`` into a single wav file and return its web path.

        Returns ``None`` when no chunk produced audio. The sample rate is that
        of the last chunk (24000 if none). NOTE: ``config.remove_silence`` is
        not applied here.
        """
        generated_audio_segments = []
        final_sample_rate = 24000
        async for d in self._inference_stream(prompt, speaker):
            generated_audio_segments.append(d['audio'])
            final_sample_rate = d['sample_rate']
        if not generated_audio_segments:
            return None
        final_wave = np.concatenate(generated_audio_segments)
        return await async_write_wav_buffer(final_wave, 1, final_sample_rate)
def _ui_popup(widgettype, title, message, timeout):
    """Build the popup widget description shared by UiError and UiMessage."""
    return {
        "widgettype": widgettype,
        "options": {
            "author": "tr",
            "timeout": timeout,
            "cwidth": 15,
            "cheight": 10,
            "title": title,
            "auto_open": True,
            "auto_dismiss": True,
            "auto_destroy": True,
            "message": message,
        },
    }


def UiError(title="出错", message="出错啦", timeout=5):
    """Return an ``Error`` widget description.

    The widget has ``auto_open``/``auto_dismiss``/``auto_destroy`` set.
    Default title "出错" means "Error"; default message "出错啦" means
    "Something went wrong".
    """
    return _ui_popup("Error", title, message, timeout)


def UiMessage(title="消息", message="后台消息", timeout=5):
    """Return a ``Message`` widget description.

    The widget has ``auto_open``/``auto_dismiss``/``auto_destroy`` set.
    Default title "消息" means "Message"; default message "后台消息" means
    "Background message".
    """
    return _ui_popup("Message", title, message, timeout)
def test1():
    """Debug helper: block for 36000 seconds (10 hours), then return ``{}``."""
    ten_hours = 10 * 60 * 60
    sleep(ten_hours)
    return dict()
def init():
    """Create the F5TTS engine and publish its entry points on ServerEnv."""
    env = ServerEnv()
    engine = F5TTS()
    exports = {
        'infer_stream': engine.infer_stream,
        'get_speakers': engine.get_speakers,
        'infer': engine._inference,
        'test1': awaitify(test1),
        'add_voice': engine.add_voice,
        'UiError': UiError,
        'UiMessage': UiMessage,
    }
    for attr_name, value in exports.items():
        setattr(env, attr_name, value)


if __name__ == '__main__':
    # Hand ``init`` to ahserver's webapp entry point.
    webapp(init)