# (stripped non-code file-listing residue: "151 lines", "4.1 KiB", "Python")
import sys
|
|
import argparse
|
|
import codecs
|
|
import re
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import soundfile as sf
|
|
# import tomli
|
|
from cached_path import cached_path
|
|
|
|
from f5_tts.model import DiT, UNetT
|
|
from f5_tts.model.utils_infer import (
|
|
load_vocoder,
|
|
load_model,
|
|
preprocess_ref_audio_text,
|
|
infer_process,
|
|
remove_silence_for_generated_wav,
|
|
)
|
|
|
|
import os
|
|
import json
|
|
from time import time
|
|
from appPublic.dictObject import DictObject
|
|
from appPublic.folderUtils import temp_file
|
|
from appPublic.jsonConfig import getConfig
|
|
from appPublic.worker import awaitify
|
|
from ahserver.webapp import webapp
|
|
from ahserver.serverEnv import ServerEnv
|
|
|
|
# Inference hyper-parameters for F5-TTS / E2-TTS.
# NOTE(review): most of these (n_mel_channels, hop_length, target_rms, nfe_step,
# cfg_strength, ode_method, sway_sampling_coef, speed) are not referenced
# anywhere else in this file — presumably consumed as defaults inside
# f5_tts.model.utils_infer; confirm before removing any of them.
n_mel_channels = 100       # mel-spectrogram channel count
hop_length = 256           # STFT hop size in samples
target_rms = 0.1           # target RMS loudness for reference audio
nfe_step = 32  # 16, 32    # number of function evaluations for the ODE solver
cfg_strength = 2.0         # classifier-free guidance strength
ode_method = "euler"       # ODE integration method
sway_sampling_coef = -1.0  # sway sampling coefficient
speed = 1.0                # speech speed multiplier
class F5TTS:
    """F5-TTS / E2-TTS text-to-speech inference engine driven by app config.

    Reads its settings (modelname, vocab_file, device, ref_audio_fn, ref_text,
    optional "voices" mapping, remove_silence) from the global config returned
    by getConfig(), loads the requested model checkpoint from the Hugging Face
    hub via cached_path, and exposes streaming and one-shot inference.
    """

    def __init__(self):
        # getConfig() is the ahserver global config; it must be initialized
        # (by webapp startup) before this class is constructed.
        self.config = getConfig()
        self.load_model()
        self.setup_voice()

    def load_model(self):
        """Load the TTS model named by ``config.modelname`` onto ``config.device``.

        Raises:
            ValueError: if ``config.modelname`` is neither "F5-TTS" nor "E2-TTS".
                (The original code fell through with ``model_cls`` unbound and
                crashed later with a confusing NameError.)
        """
        modelname = self.config.modelname
        if modelname == "F5-TTS":
            model_cls = DiT
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            repo_name, exp_name, ckpt_step = "F5-TTS", "F5TTS_Base", 1200000
        elif modelname == "E2-TTS":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            repo_name, exp_name, ckpt_step = "E2-TTS", "E2TTS_Base", 1200000
        else:
            raise ValueError(
                f"Unknown modelname {modelname!r}; expected 'F5-TTS' or 'E2-TTS'"
            )
        # Download (or reuse the local cache of) the pretrained checkpoint.
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    def setup_voice(self):
        """Build ``self.voices``: a name -> {ref_audio, ref_text} mapping.

        The config's ``ref_audio_fn``/``ref_text`` always become the "main"
        voice; extra voices come from ``config["voices"]`` when present.
        Every voice's reference audio/text pair is normalized through
        preprocess_ref_audio_text before use.
        """
        main_voice = {"ref_audio": self.config.ref_audio_fn,
                "ref_text": self.config.ref_text}
        if "voices" not in self.config:
            voices = {"main": main_voice}
        else:
            voices = self.config["voices"]
            voices["main"] = main_voice
        for voice in voices:
            voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
                voices[voice]["ref_audio"], voices[voice]["ref_text"]
            )
            print("Voice:", voice)
            print("Ref_audio:", voices[voice]["ref_audio"])
            print("Ref_text:", voices[voice]["ref_text"])
        self.voices = voices

    def _inference_stream(self, prompt):
        """Yield synthesized audio chunks for *prompt*, one per voice-tagged span.

        The prompt may contain ``[voicename]`` tags; text before the first tag
        (or with an unknown tag) uses the "main" voice. Each yielded dict has
        keys ``audio``, ``sample_rate``, ``spectragram`` and ``finish=False``;
        a final ``{'finish': True}`` sentinel terminates the stream.
        """
        text_gen = prompt
        # Split *before* each [tag] so every chunk starts with its voice tag.
        reg1 = r"(?=\[\w+\])"
        chunks = re.split(reg1, text_gen)
        reg2 = r"\[(\w+)\]"
        for text in chunks:
            match = re.match(reg2, text)
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in self.voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            text = re.sub(reg2, "", text)
            gen_text = text.strip()
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            print(f"Voice: {voice}, {self.model}")
            audio, final_sample_rate, spectragram = \
                    infer_process(ref_audio, ref_text, gen_text, self.model)
            yield {
                'audio':audio,
                'sample_rate':final_sample_rate,
                'spectragram':spectragram,
                'finish':False
            }
        yield {
            'finish':True
        }

    def _inference(self, prompt):
        """Synthesize *prompt* to a temporary .wav file and return its path.

        Concatenates all streamed chunks, writes them with soundfile, and
        optionally strips silence when ``config.remove_silence`` is truthy.
        Returns None when the stream produced no audio (explicit now; the
        original fell off the end and returned None implicitly).
        """
        generated_audio_segments = []
        remove_silence = self.config.remove_silence or False
        final_sample_rate = 24000  # fallback; overwritten by the first chunk
        for d in self._inference_stream(prompt):
            if not d['finish']:
                generated_audio_segments.append(d['audio'])
                final_sample_rate = d['sample_rate']

        if not generated_audio_segments:
            return None
        final_wave = np.concatenate(generated_audio_segments)
        fn = temp_file(suffix='.wav')
        # sf.write opens the path itself; the original opened the file with
        # open(fn, "wb") and then wrote through f.name, leaving a redundant
        # handle open while remove_silence_for_generated_wav reprocessed it.
        sf.write(fn, final_wave, final_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(fn)
        return fn
|
def init():
    """Wire F5-TTS inference entry points into the global server environment.

    Builds one F5TTS engine and publishes awaitable wrappers of its streaming
    and one-shot inference methods as ``infer_stream`` and ``infer``.
    """
    env = ServerEnv()
    engine = F5TTS()
    env.infer_stream = awaitify(engine._inference_stream)
    env.infer = awaitify(engine._inference)
|
if __name__ == '__main__':
    # Start the ahserver web application; init() registers the TTS handlers.
    webapp(init)