# f5tts/app/f5tts.py
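
"""F5-TTS inference service for the ahserver web framework.

Loads an F5-TTS model and vocoder from the application config, splits
incoming text into language-aware chunks (with optional inline
``[[speaker]]`` tags), and exposes streaming and batch text-to-speech
entry points plus runtime speaker registration via init().
"""
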
import os
import sys
import asyncio
import codecs
from traceback import format_exc
import re
import numpy as np
import soundfile as sf
# import tomli
from cached_path import cached_path
import pycld2 as cld
import cn2an
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    mel_spec_type,
    target_rms,
    cross_fade_duration,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
import json
from time import time, sleep
from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from appPublic.uniqueID import getID
from appPublic.log import debug, info
from appPublic.background import Background
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from ahserver.filestorage import FileStorage
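
# Inference defaults; they override the values of the same names imported
# from f5_tts.infer.utils_infer above.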
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0

def write_wav_buffer(wav, nchannels, framerate):
    fs = FileStorage()
    fn = fs._name2path(f'{getID()}.wav', userid='tmp')
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    debug(fn)
    sf.write(fn, wav, framerate)
    return fs.webpath(fn)

async_write_wav_buffer = awaitify(write_wav_buffer)

def detect_language(txt):
    isReliable, textBytesFound, details = cld.detect(txt)
    debug(f'detect_language(): {isReliable=}, {textBytesFound=}, {details=}')
    return details[0][1]

class F5TTS:
    def __init__(self):
        self.config = getConfig()
        # self.vocos = load_vocoder(is_local=True, local_path="../checkpoints/charactr/vocos-mel-24khz")
        self.load_model()
        self.setup_voices()

    def load_model(self):
        self.vocoder = load_vocoder(vocoder_name=self.config.vocoder_name,
                                    is_local=True,
                                    local_path=self.config.vocoder_local_path)
        # load the DiT model checkpoint
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16,
                         ff_mult=2, text_dim=512, conv_layers=4)
        ckpt_file = self.config.ckpts_path
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                                mel_spec_type=self.config.vocoder_name,
                                vocab_file=self.config.vab_file if False else self.config.vocab_file)
        self.model = self.model.to(self.config.device)
        self.lock = asyncio.Lock()
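
    # Blocking F5-TTS inference for one text chunk; wrapped with awaitify()
    # in _inference_stream() so it can be awaited.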
    def f5tts_infer(self, ref_audio, ref_text, gen_text, speed_factor):
        audio, final_sample_rate, spectrogram = \
            infer_process(ref_audio,
                          ref_text,
                          gen_text,
                          self.model,
                          self.vocoder,
                          mel_spec_type=self.config.vocoder_name,
                          speed=self.config.speed or speed)
        if audio is None:
            return None
        audio = self.speed_convert(audio, speed_factor)
        debug(f'audio shape {audio.shape}, {gen_text=}')
        return {
            'text': gen_text,
            'audio': audio,
            'sample_rate': final_sample_rate,
            'spectrogram': spectrogram
        }
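
    # Change playback speed by linear interpolation: speed_factor is clamped
    # to [0.1, 5.0] and the waveform is resampled to len/speed_factor samples.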
    def speed_convert(self, output_audio_np, speed_factor):
        original_len = len(output_audio_np)
        speed_factor = max(0.1, min(speed_factor, 5.0))
        target_len = int(
            original_len / speed_factor
        )  # Target length based on speed_factor
        if (
            target_len != original_len and target_len > 0
        ):  # Only interpolate if length changes and is valid
            x_original = np.arange(original_len)
            x_resampled = np.linspace(0, original_len - 1, target_len)
            output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
            output_audio_np = output_audio_np.astype(np.float32)
        return output_audio_np

    def get_speakers(self):
        t = [{'value': s, 'text': s} for s in self.speakers.keys()]
        t.append({'value': 'main', 'text': 'main'})
        return t
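
    # Split the prompt into per-speaker chunks: detect the language, apply the
    # configured sentence splitter (converting Arabic numerals to Chinese via
    # cn2an for zh text), then honor inline [[speaker]] tags, falling back to
    # the "main" voice.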
    async def split_text(self, text_gen, speaker):
        reg1 = r"(?=\[\w+\])"
        lang = await awaitify(detect_language)(text_gen)
        if self.config.language.get(lang):
            reg1 = r"{}".format(self.config.language.get(lang).sentence_splitter)
        if lang == 'zh':
            text_gen = await awaitify(cn2an.transform)(text_gen, 'an2cn')
        chunks = re.split(reg1, text_gen)
        # reg2 = self.config.speaker_match
        reg2 = r"\[\[(\w+)\]\]"
        ret = []
        for text in chunks:
            if text in ('\r', ''):
                continue
            voice = speaker
            match = re.match(reg2, text)
            if match:
                debug(f'{text=}, match {reg2=}')
                voice = match[1]
            if voice not in self.voices:
                voice = "main"
            debug(f'{text} inference with speaker({voice}), {reg2=}')
            text = re.sub(reg2, "", text)
            gen_text = text.strip()
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
        return ret
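
    # Streaming entry point: for each chunk produced by _inference_stream(),
    # write the audio to a temporary wav file and yield its web path.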
    async def infer_stream(self, prompt, speaker, speed_factor=1.0):
        async for a in self._inference_stream(prompt, speaker, speed_factor=speed_factor):
            wavdata = a['audio']
            samplerate = a['sample_rate']
            b = await async_write_wav_buffer(wavdata, 1, samplerate)
            yield b

    async def _inference_stream(self, prompt, speaker, speed_factor=1.0):
        chunks = await self.split_text(prompt, speaker)
        infer = awaitify(self.f5tts_infer)
        for chunk in chunks:
            gen_text = chunk['text']
            ref_audio = chunk['ref_audio']
            ref_text = chunk['ref_text']
            try:
                d = await infer(ref_audio, ref_text, gen_text, speed_factor)
                if d is not None:
                    yield d
            except Exception:
                debug(f'{gen_text=} inference error\n{format_exc()}')

    def setup_voices(self):
        config = getConfig()
        with codecs.open(config.speakers_file, 'r', 'utf-8') as f:
            self.speakers = json.loads(f.read())
        ref_audio, ref_text = preprocess_ref_audio_text(config.ref_audio, config.ref_text)
        self.voices = {
            "main": {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
        }
        for k, v in self.speakers.items():
            ref_audio, ref_text = preprocess_ref_audio_text(v['ref_audio'], v['ref_text'])
            self.voices[k] = {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
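
    # Register a new speaker at runtime: resolve the reference audio path via
    # FileStorage, preprocess it together with the reference text, and persist
    # the updated speaker table to config.speakers_file.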
    async def add_voice(self, speaker, ref_audio, ref_text):
        debug(f'{speaker=}, {ref_audio=}, {ref_text=}')
        config = getConfig()
        ref_audio = FileStorage().realPath(ref_audio)
        self.speakers[speaker] = {
            'ref_text': ref_text,
            'ref_audio': ref_audio
        }
        preprocess = awaitify(preprocess_ref_audio_text)
        ref_audio, ref_text = await preprocess(ref_audio, ref_text)
        self.voices[speaker] = {
            'ref_text': ref_text,
            'ref_audio': ref_audio
        }
        with codecs.open(config.speakers_file, 'w', 'utf-8') as f:
            f.write(json.dumps(self.speakers, indent=4))
        return None

    async def _inference(self, prompt, speaker, speed_factor=1.0):
        generated_audio_segments = []
        remove_silence = self.config.remove_silence or False
        final_sample_rate = 16000
        async for d in self._inference_stream(prompt,
                                              speaker,
                                              speed_factor=speed_factor):
            audio = d.get('audio', None)
            if audio is None:
                debug(f'audio is none, {d=}')
                continue
            final_sample_rate = d['sample_rate']
            generated_audio_segments.append(audio)
        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)
            debug(f'{prompt=}, {final_sample_rate=}')
            return await async_write_wav_buffer(final_wave, 1, final_sample_rate)
        debug(f'{prompt=}: no audio generated')
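
# Helpers returning popup-widget descriptors for the web front end.
# The default title/message strings are Chinese: "出错"/"出错啦" roughly mean
# "Error"/"Something went wrong"; "消息"/"后台消息" mean "Message"/"Background message".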
def UiError(title="出错", message="出错啦", timeout=5):
    return {
        "widgettype": "Error",
        "options": {
            "author": "tr",
            "timeout": timeout,
            "cwidth": 15,
            "cheight": 10,
            "title": title,
            "auto_open": True,
            "auto_dismiss": True,
            "auto_destroy": True,
            "message": message
        }
    }

def UiMessage(title="消息", message="后台消息", timeout=5):
    return {
        "widgettype": "Message",
        "options": {
            "author": "tr",
            "timeout": timeout,
            "cwidth": 15,
            "cheight": 10,
            "title": title,
            "auto_open": True,
            "auto_dismiss": True,
            "auto_destroy": True,
            "message": message
        }
    }

def test1():
    sleep(36000)
    return {}

def init():
    g = ServerEnv()
    f5 = F5TTS()
    g.infer_stream = f5.infer_stream
    g.inference_stream = f5._inference_stream
    g.get_speakers = f5.get_speakers
    g.infer = f5._inference
    g.test1 = awaitify(test1)
    g.add_voice = f5.add_voice
    g.UiError = UiError
    g.UiMessage = UiMessage

if __name__ == '__main__':
    webapp(init)