This commit is contained in:
yumoqing 2024-10-22 11:41:08 +08:00
parent 60fab4e937
commit 16265edbf4
6 changed files with 89 additions and 23 deletions

View File

@ -3,6 +3,7 @@
"sample_rate":16000,
"remove_silence":false,
"modelname":"F5-TTS",
"device":"cuda:0",
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
"cross_fade_duration":0

View File

@ -1,4 +1,5 @@
import time
from time import time
import torch
from pathlib import Path
import codecs
import re
@ -30,11 +31,6 @@ ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
F5TTS_model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
def chunk_text(text, max_chars=135):
"""
Splits the input text into chunks, each with a maximum number of characters.
@ -69,8 +65,9 @@ class F5TTS:
self.remove_silence = config.remove_silence
self.modelname = config.modelname
self.ref_audio_fn = config.ref_audio_fn
self.zmq_url = config.zmq_url
self.ref_text = config.ref_text
self.model= self.load_model(self.modelname)
self.device = config.device
self.cross_fade_duration = config.cross_fade_duration
self.gen_ref_audio()
self.gen_ref_text()
@ -78,7 +75,7 @@ class F5TTS:
try:
print(f"Load vocos from local path {vocos_local_path}")
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
vocos.load_state_dict(state_dict)
vocos.eval()
self.vocos = vocos
@ -87,19 +84,25 @@ class F5TTS:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
self.vocos = vocos
self.F5TTS_model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
self.E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
self.model= self.load_model(self.modelname)
def gen_ref_audio(self):
"""
gen ref_audio
"""
audio, sr = torchaudio.load(ref_audio)
audio, sr = torchaudio.load(self.ref_audio_fn)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
if sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
audio = resampler(audio)
self.ref_audio = audio
@ -138,29 +141,29 @@ class F5TTS:
method="euler"
),
vocab_char_map=vocab_char_map,
).to(device)
).to(self.device)
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
return model
def load_model(self, model):
if model == 'F5-TTS':
ret = self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
return ret
if model == 'E2-TTS':
return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
def split_text(self, text):
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
print('ref_text', ref_text)
def inference(self, prmpt, stream=False):
def inference(self, prompt, stream=False):
generated_waves = []
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
for gen_text in gen_text_Batches:
for gen_text in gen_text_batches:
# Prepare the text
text_list = [self.ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
@ -173,6 +176,7 @@ class F5TTS:
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
# inference
print(f'{self.device=}....')
with torch.inference_mode():
generated, _ = self.model.sample(
cond=self.ref_audio,
@ -194,6 +198,7 @@ class F5TTS:
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
if stream:
print(f'here ........{stream}')
return
if self.cross_fade_duration <= 0:
# Simply concatenate
@ -201,7 +206,9 @@ class F5TTS:
else:
final_wave = self.cross_fade_wave(generated_waves)
fn = self.write_wave(final_wave)
return fn
print(f'here ........{stream}, {fn=}')
yield fn
return
def cross_fade_wave(self, waves):
final_wave = generated_waves[0]
@ -249,20 +256,35 @@ class F5TTS:
data = DictObject(**json.loads(msg))
print(data)
t1 = time()
f = self.inference(data.prompt)
for wav in self.inference(data.prompt, stream=data.stream):
if data.stream:
d = {
"reqid":data.reqid,
"b64wave":b64str(wav),
"finish":False
}
self.replier.send(json.dumps(d))
else:
t2 = time()
d = {
"reqid":data.reqid,
"audio_file":wav,
"time_cost":t2 - t1
}
print(f'{d}')
return json.dumps(d)
t2 = time()
d = {
"audio_file":f,
"time_cost":t2 - t1
"reqid":data.reqid,
"time_cost":t2 - t1,
"finish":True
}
print(f'{d}')
return json.dumps(d)
if __name__ == '__main__':
workdir = os.getcwd()
config = getConfig(workdir)
print(f'{config=}')
tts = F5TTS()
print('here')
tts.run()

BIN
samples/test.wav Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

43
zmq_client.py Normal file
View File

@ -0,0 +1,43 @@
import json
import os
from appPublic.dictObject import DictObject
from appPublic.zmq_reqrep import ZmqRequester
from appPublic.jsonConfig import getConfig
from appPublic.uniqueID import getID
zmq_url = "tcp://127.0.0.1:9999"
from time import time
class ASRClient:
def __init__(self, zmq_url):
self.zmq_url = zmq_url
self.requester = ZmqRequester(self.zmq_url)
def generate(self, prompt):
d = {
"prompt":prompt,
"reqid":getID()
}
msg = json.dumps(d)
resp = self.requester.send(msg)
if resp != None:
ret = json.loads(resp)
print(f'response={ret}')
else:
print(f'response is None')
def run(self):
print(f'running {self.zmq_url}')
while True:
print('input audio_file:')
af = input()
if len(af) > 0:
self.generate(af)
print('ended ...')
if __name__ == '__main__':
workdir = os.getcwd()
config = getConfig(workdir)
asr = ASRClient(config.zmq_url or zmq_url)
asr.run()