f5tts/f5tts.py

197 lines
5.0 KiB
Python
Raw Permalink Normal View History

2024-10-22 19:13:23 +08:00
import sys
sys.path.append('./F5TTS')
import argparse
2024-10-21 18:29:57 +08:00
import codecs
import re
2024-10-22 19:13:23 +08:00
from pathlib import Path
2024-10-21 18:29:57 +08:00
import numpy as np
import soundfile as sf
2024-10-22 19:13:23 +08:00
import tomli
2024-10-21 18:29:57 +08:00
from cached_path import cached_path
2024-10-22 19:13:23 +08:00
from model import DiT, UNetT
from model.utils_infer import (
2024-10-22 22:59:29 +08:00
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
2024-10-22 19:13:23 +08:00
)
2024-10-21 18:29:57 +08:00
import os
import json
2024-10-22 22:59:29 +08:00
from time import time
2024-10-21 18:29:57 +08:00
from appPublic.dictObject import DictObject
from appPublic.zmq_reqrep import ZmqReplier
2024-10-22 22:59:29 +08:00
from appPublic.folderUtils import temp_file
2024-10-21 18:29:57 +08:00
from appPublic.jsonConfig import getConfig
# Synthesis defaults. NOTE(review): none of these names is referenced by the
# F5TTS class below; presumably they mirror upstream F5-TTS inference
# defaults — confirm before relying on them.

# Spectrogram / audio level settings.
n_mel_channels: int = 100
hop_length: int = 256
target_rms: float = 0.1

# Sampler settings.
nfe_step: int = 32  # original note: 16 or 32
cfg_strength: float = 2.0
ode_method: str = "euler"
sway_sampling_coef: float = -1.0
speed: float = 1.0
class F5TTS:
    """Text-to-speech service around the F5-TTS / E2-TTS models, served over ZMQ.

    Construction reads the application config via ``getConfig()``, creates a
    ``ZmqReplier`` on ``config.zmq_url`` with :meth:`generate` as its handler,
    loads the model (:meth:`load_model`) and preprocesses the reference
    voices (:meth:`setup_voice`).  :meth:`run` enters the replier loop.

    Prompts may switch voice with ``[name]`` tags, e.g.
    ``"[main] Hello. [other] Hi!"``; untagged text and unknown tags fall back
    to the ``"main"`` voice.
    """

    # Split *before* every "[name]" tag (zero-width lookahead) so each chunk
    # starts with its own tag; any text before the first tag is its own chunk.
    _VOICE_SPLIT_RE = re.compile(r"(?=\[\w+\])")
    # A "[name]" voice tag; group 1 is the voice name.
    _VOICE_TAG_RE = re.compile(r"\[(\w+)\]")

    def __init__(self):
        self.config = getConfig()
        self.zmq_url = self.config.zmq_url
        # generate() is handed to the replier as its request callback.
        self.replier = ZmqReplier(self.config.zmq_url, self.generate)
        # NOTE: a local vocoder was once loaded here with
        # load_vocoder(is_local=True, local_path="../checkpoints/charactr/vocos-mel-24khz").
        self.load_model()
        self.setup_voice()

    def run(self):
        """Enter the replier loop (``ZmqReplier._run``) and report when it exits."""
        print(f'running {self.zmq_url}')
        self.replier._run()
        print('ended ...')

    def load_model(self):
        """Load the model named by ``config.modelname`` onto ``config.device``.

        ``"F5-TTS"`` uses the DiT backbone and ``"E2-TTS"`` the UNetT backbone;
        the 1200000-step checkpoint is fetched from the ``SWivid`` Hugging Face
        repo through ``cached_path``.  The result is stored in ``self.model``.

        Raises:
            ValueError: if ``config.modelname`` is not one of the names above
                (previously this surfaced as an UnboundLocalError).
        """
        modelname = self.config.modelname
        if modelname == "F5-TTS":
            model_cls = DiT
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            repo_name, exp_name = "F5-TTS", "F5TTS_Base"
        elif modelname == "E2-TTS":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            repo_name, exp_name = "E2-TTS", "E2TTS_Base"
        else:
            raise ValueError(
                f"unsupported modelname {modelname!r}; expected 'F5-TTS' or 'E2-TTS'"
            )
        ckpt_step = 1200000
        ckpt_file = str(cached_path(
            f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
        ))
        # This is the module-level load_model imported from model.utils_infer,
        # not this method (class scope is not visible inside method bodies).
        self.model = load_model(model_cls, model_cfg, ckpt_file,
                                self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    @staticmethod
    def _wav_b64(audio, sample_rate):
        """Return *audio* encoded as a complete WAV file, as base64 ASCII text."""
        import base64
        import io

        buf = io.BytesIO()
        # soundfile cannot infer the container from a file object, so name it.
        sf.write(buf, audio, sample_rate, format='WAV')
        return base64.b64encode(buf.getvalue()).decode('ascii')

    def generate(self, d):
        """Handle one request and return the JSON reply string.

        Args:
            d: raw UTF-8 encoded JSON request with fields ``prompt``,
                ``reqid`` and ``stream``.

        Returns:
            Non-streaming: ``{"reqid", "audio_file", "time_cost"}`` where
            ``audio_file`` is the temp WAV path from :meth:`inference`.
            Streaming: after sending one ``{"reqid", "b64wave", "finish": False}``
            message per synthesized chunk through ``self.replier.send``
            (``b64wave`` is a base64-encoded WAV file), the final
            ``{"reqid", "time_cost", "finish": True}``.
        """
        msg = d.decode('utf-8')
        data = DictObject(**json.loads(msg))
        print(data)
        t1 = time()
        if data.stream:
            for chunk in self.inference_stream(data.prompt):
                # The trailing {'finish': True} sentinel carries no audio.
                if chunk['finish']:
                    continue
                part = {
                    "reqid": data.reqid,
                    "b64wave": self._wav_b64(chunk['audio'], chunk['sample_rate']),
                    "finish": False,
                }
                # NOTE(review): several sends per request assume ZmqReplier
                # allows more than one reply per request; a plain REP socket
                # does not — confirm against appPublic.zmq_reqrep.
                self.replier.send(json.dumps(part))
            t2 = time()
            reply = {
                "reqid": data.reqid,
                "time_cost": t2 - t1,
                "finish": True,
            }
            return json.dumps(reply)
        audio_fn = self.inference(data.prompt)
        t2 = time()
        reply = {
            "reqid": data.reqid,
            "audio_file": audio_fn,
            "time_cost": t2 - t1,
        }
        print(f'{reply}')
        return json.dumps(reply)

    def setup_voice(self):
        """Build ``self.voices``: voice name -> preprocessed reference audio/text.

        The ``"main"`` voice always comes from ``config.ref_audio_fn`` /
        ``config.ref_text`` and replaces any ``"main"`` entry of
        ``config["voices"]``.  Each entry's ``ref_audio``/``ref_text`` pair is
        replaced by the output of ``preprocess_ref_audio_text``.
        """
        main_voice = {"ref_audio": self.config.ref_audio_fn,
                      "ref_text": self.config.ref_text}
        if "voices" not in self.config:
            voices = {"main": main_voice}
        else:
            # Writes into the object returned by self.config["voices"].
            voices = self.config["voices"]
            voices["main"] = main_voice
        for voice in voices:
            voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
                voices[voice]["ref_audio"], voices[voice]["ref_text"]
            )
            print("Voice:", voice)
            print("Ref_audio:", voices[voice]["ref_audio"])
            print("Ref_text:", voices[voice]["ref_text"])
        self.voices = voices

    @classmethod
    def _split_prompt(cls, prompt):
        """Split *prompt* into ``(voice_tag_or_None, text)`` pairs.

        A new chunk starts at every ``[name]`` tag; all tags are removed from
        the chunk text, which is then stripped.  Chunks whose text ends up
        empty (e.g. the empty piece ``re.split`` produces before a leading
        tag) are dropped so no empty text reaches the synthesizer.
        """
        pairs = []
        for chunk in cls._VOICE_SPLIT_RE.split(prompt):
            text = cls._VOICE_TAG_RE.sub("", chunk).strip()
            if not text:
                continue
            match = cls._VOICE_TAG_RE.match(chunk)
            pairs.append((match[1] if match else None, text))
        return pairs

    def inference_stream(self, prompt):
        """Synthesize *prompt* chunk by chunk, yielding each chunk when ready.

        Yields:
            ``{'audio', 'sample_rate', 'spectragram', 'finish': False}`` for
            every non-empty chunk (values as returned by ``infer_process``),
            then a final ``{'finish': True}``.
        """
        for tag, gen_text in self._split_prompt(prompt):
            if tag is None:
                print("No voice tag found, using main.")
                voice = "main"
            elif tag not in self.voices:
                print(f"Voice {tag} not found, using main.")
                voice = "main"
            else:
                voice = tag
            ref = self.voices[voice]
            # Only the voice name: printing the model repr per chunk floods the log.
            print(f"Voice: {voice}")
            audio, final_sample_rate, spectragram = infer_process(
                ref["ref_audio"], ref["ref_text"], gen_text, self.model
            )
            yield {
                'audio': audio,
                'sample_rate': final_sample_rate,
                'spectragram': spectragram,
                'finish': False,
            }
        yield {
            'finish': True,
        }

    def inference(self, prompt):
        """Synthesize *prompt* into one temporary WAV file and return its path.

        Chunk audio is concatenated with ``np.concatenate``; the sample rate of
        the last chunk is used (24000 if none reported).  When
        ``config.remove_silence`` is truthy, silence is then removed from the
        file.  Returns None if the prompt produced no audio (e.g. empty prompt).
        """
        remove_silence = self.config.remove_silence or False
        final_sample_rate = 24000
        segments = []
        for chunk in self.inference_stream(prompt):
            if chunk['finish']:
                continue
            segments.append(chunk['audio'])
            final_sample_rate = chunk['sample_rate']
        if not segments:
            return None
        final_wave = np.concatenate(segments)
        fn = temp_file(suffix='.wav')
        # Write by path directly; no extra open handle on the same file.
        sf.write(fn, final_wave, final_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(fn)
        return fn
if __name__ == '__main__':
    # --interactive restores the stdin prompt loop that used to sit here as
    # dead code inside a string literal; without it the ZMQ service runs as
    # before. parse_known_args tolerates any other arguments a launcher passes.
    parser = argparse.ArgumentParser(description='F5-TTS ZMQ service')
    parser.add_argument('--interactive', action='store_true',
                        help='read prompts from stdin instead of serving ZMQ requests')
    args, _ = parser.parse_known_args()

    workdir = os.getcwd()
    # NOTE(review): F5TTS() later calls getConfig() with no arguments and
    # presumably gets the config initialised here — confirm in appPublic.
    config = getConfig(workdir, {'workdir': workdir})
    print(config.ref_audio_fn)
    tts = F5TTS()
    if args.interactive:
        while True:
            print('prompt:')
            try:
                p = input()
            except (EOFError, KeyboardInterrupt):
                # End of input or Ctrl-C ends the session instead of a traceback.
                break
            if p != '':
                t1 = time()
                f = tts.inference(p)
                t2 = time()
                print(f'{f}, cost {t2-t1} seconds')
    else:
        tts.run()