348 lines
9.5 KiB
Python
348 lines
9.5 KiB
Python
import os
|
||
import io
|
||
import base64
|
||
import sys
|
||
import asyncio
|
||
import codecs
|
||
from traceback import format_exc
|
||
import re
|
||
|
||
import numpy as np
|
||
import soundfile as sf
|
||
# import tomli
|
||
from cached_path import cached_path
|
||
from appPublic.textsplit import split_text_with_dialog_preserved
|
||
from appPublic.uniqueID import getID
|
||
from ahserver.serverenv import get_serverenv
|
||
from filetxt.loader import fileloader
|
||
import pycld2 as cld
|
||
import cn2an
|
||
|
||
from f5_tts.model import DiT, UNetT
|
||
from f5_tts.infer.utils_infer import (
|
||
mel_spec_type,
|
||
target_rms,
|
||
cross_fade_duration,
|
||
nfe_step,
|
||
cfg_strength,
|
||
sway_sampling_coef,
|
||
speed,
|
||
fix_duration,
|
||
infer_process,
|
||
load_model,
|
||
load_vocoder,
|
||
preprocess_ref_audio_text,
|
||
remove_silence_for_generated_wav,
|
||
)
|
||
|
||
import json
|
||
from time import time, sleep
|
||
from appPublic.dictObject import DictObject
|
||
from appPublic.folderUtils import temp_file
|
||
from appPublic.jsonConfig import getConfig
|
||
from appPublic.worker import awaitify
|
||
from appPublic.uniqueID import getID
|
||
from appPublic.log import debug, info
|
||
from appPublic.background import Background
|
||
from ahserver.webapp import webapp
|
||
from ahserver.serverenv import ServerEnv
|
||
from ahserver.filestorage import FileStorage
|
||
|
||
# Module-level synthesis settings.
#
# NOTE: target_rms, nfe_step, cfg_strength, sway_sampling_coef and speed are
# also imported from f5_tts.infer.utils_infer at the top of this file; the
# assignments below rebind those names for this module.

# Mel-spectrogram geometry.
n_mel_channels = 100
hop_length = 256

# Loudness / sampling settings.
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0

# Default speed handed to infer_process when the config does not set one.
speed = 1.0
|
||
|
||
def audio_ndarray_to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> str:
    """Serialize a waveform to WAV in memory and return it base64-encoded.

    Args:
        waveform: Audio samples. A 1-D array is reshaped to ``(samples, 1)``
            before writing; arrays of any other rank are written unchanged.
        sample_rate: Sample rate passed to ``soundfile.write``.

    Returns:
        The base64 text (decoded as UTF-8) of the WAV file bytes.
    """
    # Single-channel input: make sure the shape is (samples, 1).
    samples = waveform.reshape(-1, 1) if waveform.ndim == 1 else waveform

    # Write the WAV container into an in-memory buffer.
    wav_io = io.BytesIO()
    sf.write(wav_io, samples, samplerate=sample_rate, format='WAV')

    # Base64-encode the complete buffer contents.
    return base64.b64encode(wav_io.getvalue()).decode('utf-8')
|
||
|
||
def write_wav_buffer(wav, nchannels, framerate):
    """Write audio samples to a new ``.wav`` file and return its web path.

    The target path comes from ``FileStorage._name2path`` using a fresh
    ``getID()`` based name and ``userid='tmp'``.

    Args:
        wav: Audio samples accepted by ``soundfile.write``.
        nchannels: Unused; kept so existing callers keep working.
        framerate: Sample rate of ``wav``.

    Returns:
        The value of ``FileStorage.webpath`` for the written file.
    """
    fs = FileStorage()
    fn = fs._name2path(f'{getID()}.wav', userid='tmp')
    # exist_ok=True: the target directory may already exist from an earlier
    # call; without it every such call raised FileExistsError.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    debug(fn)
    # soundfile creates the file itself, so no separate open() is needed.
    sf.write(fn, wav, framerate)
    return fs.webpath(fn)


# Awaitable wrapper around write_wav_buffer for use from coroutines.
async_write_wav_buffer = awaitify(write_wav_buffer)
|
||
|
||
def detect_language(txt):
    """Return the language code pycld2 ranks first for ``txt``.

    The full detection result is written to the debug log.
    """
    # The local names are part of the log line produced by the `{name=}`
    # f-string fields below, so they are kept as-is.
    isReliable, textBytesFound, details = cld.detect(txt)
    debug(f' detect_language():{isReliable=}, {textBytesFound=}, {details=} ')
    top_guess = details[0]
    # Index 1 of a detail entry is what the caller compares against 'zh'.
    return top_guess[1]
|
||
|
||
class F5TTS:
    """F5-TTS inference service used by the web application.

    Construction reads the application config, loads the vocoder and the
    DiT model, and preprocesses the reference audio/text of the ``main``
    voice plus every voice listed in the speakers file.

    Attributes set by the methods below:
        config: object returned by ``getConfig()``.
        vocoder, model: loaded by ``load_model``.
        speakers: dict loaded from ``config.speakers_file``
            (name -> ``{'ref_text', 'ref_audio'}`` with the sample filename).
        voices: dict name -> ``{'ref_text', 'ref_audio'}`` holding the values
            returned by ``preprocess_ref_audio_text``.
    """

    # Voice-switch marker such as ``[[alice]]``; matched at the start of a
    # chunk and stripped from the text that gets synthesized.
    _VOICE_TAG = re.compile(r"\[\[(\w+)\]\]")

    def __init__(self):
        self.config = getConfig()
        self.load_model()
        self.setup_voices()

    def load_model(self):
        """Load the vocoder and the DiT model described by the config."""
        self.vocoder = load_vocoder(vocoder_name=self.config.vocoder_name,
                                    is_local=True,
                                    local_path=self.config.vocoder_local_path)

        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16,
                         ff_mult=2, text_dim=512, conv_layers=4)
        ckpt_file = self.config.ckpts_path
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                                mel_spec_type=self.config.vocoder_name,
                                vocab_file=self.config.vocab_file)
        self.model = self.model.to(self.config.device)
        self.lock = asyncio.Lock()

    def f5tts_infer(self, ref_audio, ref_text, gen_text, speed_factor):
        """Synthesize ``gen_text`` with the given reference voice.

        Returns:
            ``None`` when ``infer_process`` produced no audio, otherwise a
            dict with keys ``text``, ``audio`` (speed-adjusted samples) and
            ``sample_rate``.
        """
        audio, final_sample_rate, _spectrogram = infer_process(
            ref_audio,
            ref_text,
            gen_text,
            self.model,
            self.vocoder,
            mel_spec_type=self.config.vocoder_name,
            # Falls back to the module-level ``speed`` when the config value
            # is missing or falsy.
            speed=self.config.speed or speed)
        if audio is None:
            return None
        audio = self.speed_convert(audio, speed_factor)
        debug(f'audio shape {audio.shape}, {gen_text=}')
        return {
            'text': gen_text,
            'audio': audio,
            'sample_rate': final_sample_rate
        }

    def speed_convert(self, output_audio_np, speed_factor):
        """Resample audio by linear interpolation to change its speed.

        Args:
            output_audio_np: Audio samples; ``np.interp`` is applied to them
                directly, so a 1-D array is expected.
            speed_factor: Requested speed, clamped to ``[0.1, 5.0]``.

        Returns:
            A float32 array of ``int(len / speed_factor)`` samples, or the
            input unchanged when that length equals the input length or is
            not positive.
        """
        original_len = len(output_audio_np)
        speed_factor = max(0.1, min(speed_factor, 5.0))
        target_len = int(original_len / speed_factor)
        # Only interpolate if the length changes and is valid.
        if target_len != original_len and target_len > 0:
            x_original = np.arange(original_len)
            x_resampled = np.linspace(0, original_len - 1, target_len)
            output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
            output_audio_np = output_audio_np.astype(np.float32)
        return output_audio_np

    def get_speakers(self):
        """Return selectable voices as ``{'value', 'text'}`` dicts.

        Every key of ``self.speakers`` is listed, followed by ``main``.
        """
        t = [{'value': s, 'text': s} for s in self.speakers]
        t.append({'value': 'main', 'text': 'main'})
        return t

    async def split_text(self, text_gen, speaker):
        """Split text into chunks and attach a reference voice to each.

        A chunk that starts with ``[[name]]`` uses voice ``name`` when that
        voice is known, otherwise ``speaker``; the marker is removed from the
        text. Chunks detected as ``zh`` are passed through
        ``cn2an.transform(text, 'an2cn')`` first.

        Returns:
            A list of dicts with keys ``text``, ``ref_audio`` and ``ref_text``.
        """
        chunks = split_text_with_dialog_preserved(text_gen)
        debug(f'{len(chunks)=}')
        detect = awaitify(detect_language)
        to_chinese_numerals = awaitify(cn2an.transform)
        ret = []
        for text in chunks:
            # The original compared the string with the list ['\r', ''],
            # which is never equal, so blank chunks were never skipped.
            if not text.strip():
                continue
            lang = await detect(text)
            if lang == 'zh':
                text = await to_chinese_numerals(text, 'an2cn')
            voice = speaker
            match = self._VOICE_TAG.match(text)
            if match:
                voice = match[1]
                if voice not in self.voices:
                    voice = speaker
            text = self._VOICE_TAG.sub("", text)
            gen_text = text.strip()
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
        return ret

    async def infer_stream(self, prompt, speaker, speed_factor=1.0):
        """Yield the web path of one WAV file per synthesized chunk."""
        async for a in self._inference_stream(prompt, speaker, speed_factor=speed_factor):
            wavdata = a['audio']
            samplerate = a['sample_rate']
            b = await async_write_wav_buffer(wavdata, 1, samplerate)
            yield b

    async def _inference_stream(self, prompt, speaker, speed_factor=1.0):
        """Yield the ``f5tts_infer`` result dict for each chunk of ``prompt``.

        A chunk whose inference raises is logged and skipped; chunks that
        produce no audio are skipped as well.
        """
        chunks = await self.split_text(prompt, speaker)
        debug(f'{len(chunks)=}')
        infer = awaitify(self.f5tts_infer)
        for chunk in chunks:
            gen_text = chunk['text']
            ref_audio = chunk['ref_audio']
            ref_text = chunk['ref_text']
            try:
                d = await infer(ref_audio, ref_text, gen_text, speed_factor)
            except Exception:
                # ``except Exception`` (not a bare except) so task
                # cancellation is not swallowed.
                debug(f'{gen_text=} inference error\n{format_exc()}')
                continue
            # Yield outside the try block: GeneratorExit raised at the yield
            # when the consumer closes the generator must propagate.
            if d is not None:
                yield d

    async def inference_stream(self, prompt, speaker, speed_factor=1.0):
        """Yield newline-terminated JSON messages for a streaming client.

        One message is produced per chunk, containing ``text``, the base64
        WAV ``audio``, ``sample_rate``, ``duration`` in seconds and
        ``done: False``. A final message carries ``done: True`` and the
        summed ``duration`` of all chunks.
        """
        total_duration = 0
        async for d in self._inference_stream(prompt, speaker, speed_factor=speed_factor):
            samples = d['audio'].shape[0]
            duration = samples / d['sample_rate']
            total_duration += duration
            audio_b64 = audio_ndarray_to_base64(d['audio'], d['sample_rate'])
            d['audio'] = audio_b64
            d['duration'] = duration
            d['done'] = False
            txt = json.dumps(d, ensure_ascii=False)
            yield txt + '\n'
        d = {
            'done': True,
            'duration': total_duration
        }
        txt = json.dumps(d, ensure_ascii=False)
        yield txt + '\n'

    def setup_voices(self):
        """Load the speakers file and preprocess every reference voice.

        ``main`` is built from ``config.ref_audio`` / ``config.ref_text``;
        each entry of the speakers file is then added under its own name.
        Sample files are read from ``{workdir}/samples/``.
        """
        config = getConfig()
        workdir = get_serverenv('workdir')
        with codecs.open(config.speakers_file, 'r', 'utf-8') as f:
            b = f.read()
        self.speakers = json.loads(b)
        fn = f'{workdir}/samples/{config.ref_audio}'
        ref_audio, ref_text = preprocess_ref_audio_text(fn,
                                                        config.ref_text)
        self.voices = {
            "main": {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }
        }
        for k, v in self.speakers.items():
            fn = f'{workdir}/samples/{v["ref_audio"]}'
            ref_audio, ref_text = preprocess_ref_audio_text(fn,
                                                            v['ref_text'])
            self.voices[k] = {
                'ref_text': ref_text,
                'ref_audio': ref_audio
            }

    def copyfile(self, src, dest):
        """Copy the bytes of file ``src`` into file ``dest``."""
        with open(src, 'rb') as f:
            b = f.read()
        with open(dest, 'wb') as f1:
            f1.write(b)

    async def add_voice(self, speaker, ref_audio, ref_text):
        """Register a new voice and persist it to the speakers file.

        The uploaded file is copied to ``{workdir}/samples/<id>.wav`` and the
        upload is deleted. The copy is then preprocessed and stored in
        ``self.voices``, and the speakers file is rewritten.

        Args:
            speaker: Name of the new voice.
            ref_audio: Storage path of the uploaded audio, resolved with
                ``FileStorage().realPath``.
            ref_text: Transcript of the reference audio.

        Returns:
            None
        """
        config = getConfig()
        uploaded = FileStorage().realPath(ref_audio)
        workdir = get_serverenv('workdir')
        filename = f'{getID()}.wav'
        fn = f'{workdir}/samples/{filename}'
        await awaitify(self.copyfile)(uploaded, fn)
        os.unlink(uploaded)

        # Preprocess the stored copy: the upload no longer exists. Doing this
        # before touching self.speakers keeps state unchanged on failure.
        preprocess = awaitify(preprocess_ref_audio_text)
        voice_audio, voice_text = await preprocess(fn, ref_text)

        self.speakers[speaker] = {
            'ref_text': ref_text,
            'ref_audio': filename
        }
        # Same shape as the entries built by setup_voices: the preprocessed
        # audio reference, not the bare sample filename.
        self.voices[speaker] = {
            'ref_text': voice_text,
            'ref_audio': voice_audio
        }
        with codecs.open(config.speakers_file, 'w', 'utf-8') as f:
            f.write(json.dumps(self.speakers, indent=4, ensure_ascii=False))
        return None

    async def _inference(self, prompt, speaker, speed_factor=1.0):
        """Synthesize the whole prompt into a single WAV file.

        Returns:
            The web path returned by ``async_write_wav_buffer`` for the
            concatenated audio, or ``None`` when no chunk produced audio.
        """
        generated_audio_segments = []
        final_sample_rate = 16000
        async for d in self._inference_stream(prompt,
                                              speaker,
                                              speed_factor=speed_factor):
            audio = d.get('audio', None)
            if audio is None:
                debug(f'audio is none, {d=}')
                continue
            final_sample_rate = d['sample_rate']
            generated_audio_segments.append(audio)

        if not generated_audio_segments:
            debug(f'{prompt=} not audio generated')
            return None

        final_wave = np.concatenate(generated_audio_segments)
        debug(f'{prompt=}, {final_sample_rate=}')
        return await async_write_wav_buffer(final_wave, 1, final_sample_rate)
|
||
|
||
def UiError(title="出错", message="出错啦", timeout=5):
    """Build the descriptor dict for an ``Error`` widget.

    Args:
        title: Value stored under ``options['title']``.
        message: Value stored under ``options['message']``.
        timeout: Value stored under ``options['timeout']``.

    Returns:
        A dict with ``widgettype`` set to ``"Error"`` and an ``options``
        dict holding the arguments plus fixed size and auto-* flags.
    """
    options = dict(
        author="tr",
        timeout=timeout,
        cwidth=15,
        cheight=10,
        title=title,
        auto_open=True,
        auto_dismiss=True,
        auto_destroy=True,
        message=message,
    )
    return dict(widgettype="Error", options=options)
|
||
|
||
|
||
def UiMessage(title="消息", message="后台消息", timeout=5):
    """Build the descriptor dict for a ``Message`` widget.

    Args:
        title: Value stored under ``options['title']``.
        message: Value stored under ``options['message']``.
        timeout: Value stored under ``options['timeout']``.

    Returns:
        A dict with ``widgettype`` set to ``"Message"`` and an ``options``
        dict holding the arguments plus fixed size and auto-* flags.
    """
    options = dict(
        author="tr",
        timeout=timeout,
        cwidth=15,
        cheight=10,
        title=title,
        auto_open=True,
        auto_dismiss=True,
        auto_destroy=True,
        message=message,
    )
    return dict(widgettype="Message", options=options)
|
||
|
||
def test1():
    """Block the calling thread for 36000 seconds, then return an empty dict."""
    seconds = 36000
    sleep(seconds)
    empty = {}
    return empty
|
||
|
||
# Module-wide F5TTS instance; created by init().
f5 = None


def init():
    """Create the F5TTS service and register its entry points on ServerEnv.

    Sets the module-level ``f5`` and assigns the service methods and the UI
    helper functions as attributes of a ``ServerEnv()`` instance.
    """
    global f5
    g = ServerEnv()
    f5 = F5TTS()
    g.infer_stream = f5.infer_stream
    g.inference_stream = f5.inference_stream
    g.get_speakers = f5.get_speakers
    g.infer = f5._inference
    g.test1 = awaitify(test1)
    g.add_voice = f5.add_voice
    g.UiError = UiError
    # ``filelaoder`` is a misspelling; it is kept so existing users of that
    # name keep working, and the correctly spelled name is registered too.
    g.filelaoder = fileloader
    g.fileloader = fileloader
    g.UiMessage = UiMessage


if __name__ == '__main__':
    webapp(init)
|