# f5tts/app/f5tts.py
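"""F5-TTS inference web service.

Loads an F5-TTS DiT checkpoint and vocoder from the app config, keeps a
registry of reference voices, and exposes streaming and one-shot
text-to-speech entry points through an ahserver web app (see init() below).
"""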
import os
import io
import base64
import sys
import asyncio
import codecs
from traceback import format_exc
import re
import numpy as np
import soundfile as sf
# import tomli
from cached_path import cached_path
from appPublic.textsplit import split_text_with_dialog_preserved
from appPublic.uniqueID import getID
from ahserver.serverenv import get_serverenv
from filetxt.loader import fileloader
import pycld2 as cld
import cn2an
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    mel_spec_type,
    target_rms,
    cross_fade_duration,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    infer_process,
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
)
import json
from time import time, sleep
from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from appPublic.log import debug, info
from appPublic.background import Background
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from ahserver.filestorage import FileStorage
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
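
# Module-level inference defaults. Note that target_rms, nfe_step, cfg_strength,
# sway_sampling_coef and speed shadow the values imported from
# f5_tts.infer.utils_infer above.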
def audio_ndarray_to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> str:
    # For mono input, make sure the shape is (samples, 1)
    if waveform.ndim == 1:
        waveform = waveform.reshape(-1, 1)
    # Write WAV-format data into an in-memory buffer
    buffer = io.BytesIO()
    sf.write(buffer, waveform, samplerate=sample_rate, format='WAV')
    buffer.seek(0)
    # base64-encode the WAV bytes
    b64_audio = base64.b64encode(buffer.read()).decode('utf-8')
    return b64_audio
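
# Usage sketch for audio_ndarray_to_base64 (hypothetical values):
#   tone = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype(np.float32)
#   b64 = audio_ndarray_to_base64(tone, 16000)
#   data, sr = sf.read(io.BytesIO(base64.b64decode(b64)))  # sr == 16000
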
def write_wav_buffer(wav, nchannels, framerate):
    # Store the generated wave under the 'tmp' user area of the file storage;
    # nchannels is accepted for API symmetry, sf infers channels from the array shape
    fs = FileStorage()
    fn = fs._name2path(f'{getID()}.wav', userid='tmp')
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    debug(fn)
    with open(fn, "wb") as f:
        sf.write(f, wav, framerate, format='WAV')
    return fs.webpath(fn)
async_write_wav_buffer = awaitify(write_wav_buffer)
def detect_language(txt):
    isReliable, textBytesFound, details = cld.detect(txt)
    debug(f' detect_language():{isReliable=}, {textBytesFound=}, {details=} ')
    # details[0][1] is the language code of the best match, e.g. 'zh' or 'en'
    return details[0][1]
class F5TTS:
    def __init__(self):
        self.config = getConfig()
        # self.vocos = load_vocoder(is_local=True, local_path="../checkpoints/charactr/vocos-mel-24khz")
        self.load_model()
        self.setup_voices()
    def load_model(self):
        self.vocoder = load_vocoder(vocoder_name=self.config.vocoder_name,
                                    is_local=True,
                                    local_path=self.config.vocoder_local_path)
        # load the DiT checkpoint configured in ckpts_path
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16,
                         ff_mult=2, text_dim=512, conv_layers=4)
        ckpt_file = self.config.ckpts_path
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                                mel_spec_type=self.config.vocoder_name,
                                vocab_file=self.config.vocab_file)
        self.model = self.model.to(self.config.device)
        self.lock = asyncio.Lock()
    def f5tts_infer(self, ref_audio, ref_text, gen_text, speed_factor):
        audio, final_sample_rate, spectrogram = \
            infer_process(ref_audio,
                          ref_text,
                          gen_text,
                          self.model,
                          self.vocoder,
                          mel_spec_type=self.config.vocoder_name,
                          speed=self.config.speed or speed)
        if audio is None:
            return None
        audio = self.speed_convert(audio, speed_factor)
        debug(f'audio shape {audio.shape}, {gen_text=}')
        return {
            'text': gen_text,
            'audio': audio,
            'sample_rate': final_sample_rate
        }
    def speed_convert(self, output_audio_np, speed_factor):
        original_len = len(output_audio_np)
        speed_factor = max(0.1, min(speed_factor, 5.0))
        # Target length based on speed_factor: >1.0 shortens (faster), <1.0 stretches (slower)
        target_len = int(original_len / speed_factor)
        # Only interpolate if the length changes and is valid
        if target_len != original_len and target_len > 0:
            x_original = np.arange(original_len)
            x_resampled = np.linspace(0, original_len - 1, target_len)
            output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
            output_audio_np = output_audio_np.astype(np.float32)
        return output_audio_np
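
    # Example for speed_convert (hypothetical numbers): with speed_factor=2.0 a
    # 32000-sample clip is resampled down to 16000 samples, so it plays in half
    # the time at the same sample rate; pitch rises with tempo, since this is
    # plain linear interpolation rather than a pitch-preserving time-stretch.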
    def get_speakers(self):
        t = [{'value': s, 'text': s} for s in self.speakers.keys()]
        t.append({'value': 'main', 'text': 'main'})
        return t
    async def split_text(self, text_gen, speaker):
        chunks = split_text_with_dialog_preserved(text_gen)
        debug(f'{len(chunks)=}')
        # reg2 = self.config.speaker_match
        reg2 = r"\[\[(\w+)\]\]"
        ret = []
        for text in chunks:
            # skip empty or CR-only chunks
            if text in ('\r', ''):
                continue
            lang = await awaitify(detect_language)(text)
            if lang == 'zh':
                # normalize Arabic numerals to Chinese numerals for zh text
                text = await awaitify(cn2an.transform)(text, 'an2cn')
            voice = speaker
            match = re.match(reg2, text)
            if match:
                voice = match[1]
                if voice not in self.voices:
                    voice = speaker
                text = re.sub(reg2, "", text)
            gen_text = text.strip()
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
        return ret
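
    # Example for split_text (hypothetical speaker name): for the prompt
    #   '[[alice]]你好 hello'
    # the chunk is routed to the 'alice' voice if it exists in the speakers
    # file, the [[alice]] tag is stripped before synthesis, and the caller's
    # speaker is used as the fallback otherwise.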
    async def infer_stream(self, prompt, speaker, speed_factor=1.0):
        async for a in self._inference_stream(prompt, speaker, speed_factor=speed_factor):
            wavdata = a['audio']
            samplerate = a['sample_rate']
            b = await async_write_wav_buffer(wavdata, 1, samplerate)
            yield b
    async def _inference_stream(self, prompt, speaker, speed_factor=1.0):
        chunks = await self.split_text(prompt, speaker)
        debug(f'{len(chunks)=}')
        for chunk in chunks:
            gen_text = chunk['text']
            ref_audio = chunk['ref_audio']
            ref_text = chunk['ref_text']
            infer = awaitify(self.f5tts_infer)
            try:
                d = await infer(ref_audio, ref_text, gen_text, speed_factor)
                if d is not None:
                    yield d
            except Exception:
                debug(f'{gen_text=} inference error\n{format_exc()}')
    async def inference_stream(self, prompt, speaker, speed_factor=1.0):
        total_duration = 0
        async for d in self._inference_stream(prompt, speaker, speed_factor=speed_factor):
            samples = d['audio'].shape[0]
            duration = samples / d['sample_rate']
            total_duration += duration
            audio_b64 = audio_ndarray_to_base64(d['audio'], d['sample_rate'])
            d['audio'] = audio_b64
            d['duration'] = duration
            d['done'] = False
            txt = json.dumps(d, ensure_ascii=False)
            yield txt + '\n'
        d = {
            'done': True,
            'duration': total_duration
        }
        txt = json.dumps(d, ensure_ascii=False)
        yield txt + '\n'
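
    # Each line yielded by inference_stream is a standalone JSON object (NDJSON).
    # A hypothetical two-line stream:
    #   {"text": "...", "audio": "<base64 WAV>", "sample_rate": 24000, "duration": 1.2, "done": false}
    #   {"done": true, "duration": 1.2}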
    def setup_voices(self):
        config = getConfig()
        workdir = config.workdir
        debug(f'workdir={workdir}')
        with codecs.open(config.speakers_file, 'r', 'utf-8') as f:
            b = f.read()
        self.speakers = json.loads(b)
        # the 'main' voice is the default reference sample from the config
        fn = f'{workdir}/samples/{config.ref_audio}'
        ref_audio, ref_text = preprocess_ref_audio_text(fn, config.ref_text)
        self.voices = {
            "main": {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
        }
        for k, v in self.speakers.items():
            fn = f'{workdir}/samples/{v["ref_audio"]}'
            ref_audio, ref_text = preprocess_ref_audio_text(fn, v['ref_text'])
            self.voices[k] = {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
    def copyfile(self, src, dest):
        with open(src, 'rb') as f:
            b = f.read()
        with open(dest, 'wb') as f1:
            f1.write(b)
    async def add_voice(self, speaker, ref_audio, ref_text):
        config = getConfig()
        src = FileStorage().realPath(ref_audio)
        workdir = config.workdir
        filename = f'{getID()}.wav'
        fn = f'{workdir}/samples/{filename}'
        await awaitify(self.copyfile)(src, fn)
        os.unlink(src)
        self.speakers[speaker] = {
            'ref_text': ref_text,
            'ref_audio': filename
        }
        # preprocess the copied sample (the uploaded original has been removed)
        f = awaitify(preprocess_ref_audio_text)
        ref_audio, ref_text = await f(fn, ref_text)
        self.voices[speaker] = {
            'ref_text': ref_text,
            'ref_audio': ref_audio
        }
        with codecs.open(config.speakers_file, 'w', 'utf-8') as f:
            f.write(json.dumps(self.speakers, indent=4, ensure_ascii=False))
        return None
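
    # A voice added here is selectable immediately (get_speakers() reads
    # self.speakers) and is persisted to config.speakers_file, so
    # setup_voices() restores it on the next service start.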
    async def _inference(self, prompt, speaker, speed_factor=1.0):
        generated_audio_segments = []
        remove_silence = self.config.remove_silence or False  # currently unused
        final_sample_rate = 16000
        async for d in self._inference_stream(prompt,
                                              speaker,
                                              speed_factor=speed_factor):
            audio = d.get('audio', None)
            if audio is None:
                debug(f'audio is none, {d=}')
                continue
            final_sample_rate = d['sample_rate']
            generated_audio_segments.append(audio)
        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)
            debug(f'{prompt=}, {final_sample_rate=}')
            return await async_write_wav_buffer(final_wave, 1, final_sample_rate)
        else:
            debug(f'{prompt=}: no audio generated')
def UiError(title="Error", message="Something went wrong", timeout=5):
    return {
        "widgettype": "Error",
        "options": {
            "author": "tr",
            "timeout": timeout,
            "cwidth": 15,
            "cheight": 10,
            "title": title,
            "auto_open": True,
            "auto_dismiss": True,
            "auto_destroy": True,
            "message": message
        }
    }
def UiMessage(title="Message", message="Background message", timeout=5):
    return {
        "widgettype": "Message",
        "options": {
            "author": "tr",
            "timeout": timeout,
            "cwidth": 15,
            "cheight": 10,
            "title": title,
            "auto_open": True,
            "auto_dismiss": True,
            "auto_destroy": True,
            "message": message
        }
    }
def test1():
    # blocks for 10 hours (36000 s); presumably a long-running-call test helper
    sleep(36000)
    return {}
f5 = None
def init():
    global f5
    g = ServerEnv()
    f5 = F5TTS()
    # expose the engine's entry points on the server environment
    g.tts_engine = f5
    g.infer_stream = f5.infer_stream
    g.inference_stream = f5.inference_stream
    g.get_speakers = f5.get_speakers
    g.infer = f5._inference
    g.test1 = awaitify(test1)
    g.add_voice = f5.add_voice
    g.UiError = UiError
    g.fileloader = fileloader
    g.UiMessage = UiMessage
if __name__ == '__main__':
    webapp(init)