bugfix

parent 60fab4e937
commit 16265edbf4
@@ -3,6 +3,7 @@
     "sample_rate":16000,
     "remove_silence":false,
     "modelname":"F5-TTS",
+    "device":"cuda:0",
     "ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
     "ref_text":"对,这就是我,万人敬仰的太乙真人。",
     "cross_fade_duration":0
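The new "device" key is read in F5TTS.__init__ as config.device. For readers unfamiliar with the appPublic config helpers used below, a minimal stand-in for that attribute-style access over a JSON config (SimpleNamespace here is an illustration; the getConfig/DictObject internals are not part of this diff):

import json
from types import SimpleNamespace

# Assumption: config keys map directly onto attributes, as the
# config.device / config.sample_rate usage in f5tts.py suggests.
raw = '{"sample_rate": 16000, "modelname": "F5-TTS", "device": "cuda:0"}'
config = json.loads(raw, object_hook=lambda d: SimpleNamespace(**d))
assert config.device == "cuda:0" and config.sample_rate == 16000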
f5tts.py (64 lines changed)
@@ -1,4 +1,5 @@
-import time
+from time import time
 import torch
 from pathlib import Path
 import codecs
 import re
@@ -30,11 +31,6 @@ ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
 
-F5TTS_model_cfg = dict(
-    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
-)
-E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
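chunk_text appears in this hunk only down to its docstring; here is an illustrative sketch of such a splitter, not the project's implementation (the punctuation set is an assumption):

import re

def chunk_text_sketch(text, max_chars=135):
    # split after sentence-ending punctuation, then pack pieces greedily
    pieces = re.split(r'(?<=[。.!?!?;;])', text)
    chunks, current = [], ''
    for piece in pieces:
        if current and len(current) + len(piece) > max_chars:
            chunks.append(current)
            current = piece
        else:
            current += piece
    if current:
        chunks.append(current)
    return chunks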
@@ -69,8 +65,9 @@ class F5TTS:
         self.remove_silence = config.remove_silence
         self.modelname = config.modelname
         self.ref_audio_fn = config.ref_audio_fn
+        self.zmq_url = config.zmq_url
         self.ref_text = config.ref_text
-        self.model= self.load_model(self.modelname)
+        self.device = config.device
         self.cross_fade_duration = config.cross_fade_duration
         self.gen_ref_audio()
         self.gen_ref_text()
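The substantive fix in this hunk is ordering: self.model = self.load_model(...) used to run before self.device and the model configs were assigned, and load_model depends on both. A minimal reproduction of that init-order bug and the fix:

# Minimal reproduction of the init-order bug this commit fixes:
# load_model() reads self.device, so it must run after device is set.
class Broken:
    def __init__(self):
        self.model = self.load_model()  # AttributeError: self.device not set yet
        self.device = "cuda:0"

    def load_model(self):
        return f"model on {self.device}"

class Fixed:
    def __init__(self):
        self.device = "cuda:0"
        self.model = self.load_model()  # ok: dependency assigned first

    def load_model(self):
        return f"model on {self.device}"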
@@ -78,7 +75,7 @@ class F5TTS:
         try:
             print(f"Load vocos from local path {vocos_local_path}")
             vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-            state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+            state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
             vocos.load_state_dict(state_dict)
             vocos.eval()
             self.vocos = vocos
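Putting the two vocoder paths together: local files first, hub fallback otherwise. Roughly (the helper name and the broad except are illustrative; from_hparams and from_pretrained are the vocos package's documented entry points):

import torch
from vocos import Vocos

def load_vocos(local_path, device):
    try:
        # local checkpoint, remapped onto the configured device
        vocos = Vocos.from_hparams(f'{local_path}/config.yaml')
        state = torch.load(f'{local_path}/pytorch_model.bin', map_location=device)
        vocos.load_state_dict(state)
    except Exception:
        # fall back to the published checkpoint
        vocos = Vocos.from_pretrained('charactr/vocos-mel-24khz')
    return vocos.eval()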
@@ -87,19 +84,25 @@ class F5TTS:
             vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
             self.vocos = vocos
 
+        self.F5TTS_model_cfg = dict(
+            dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
+        )
+        self.E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+        self.model= self.load_model(self.modelname)
+
 
     def gen_ref_audio(self):
         """
         gen ref_audio
         """
-        audio, sr = torchaudio.load(ref_audio)
+        audio, sr = torchaudio.load(self.ref_audio_fn)
         if audio.shape[0] > 1:
             audio = torch.mean(audio, dim=0, keepdim=True)
         rms = torch.sqrt(torch.mean(torch.square(audio)))
         if rms < target_rms:
             audio = audio * target_rms / rms
-        if sr != target_sample_rate:
-            resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+        if sr != self.sample_rate:
+            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
             audio = resampler(audio)
         self.ref_audio = audio
 
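The reference-audio preprocessing above as a self-contained sketch (the module-level target_rms constant is defined outside this diff, so the 0.1 default below is an assumption):

import torch
import torchaudio

def prep_ref_audio(path, target_rms=0.1, target_sr=16000):
    audio, sr = torchaudio.load(path)
    if audio.shape[0] > 1:  # downmix stereo to mono
        audio = torch.mean(audio, dim=0, keepdim=True)
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:  # boost only references that are too quiet
        audio = audio * target_rms / rms
    if sr != target_sr:  # match the configured sample rate
        audio = torchaudio.transforms.Resample(sr, target_sr)(audio)
    return audio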
@@ -138,29 +141,29 @@ class F5TTS:
                 method="euler"
             ),
             vocab_char_map=vocab_char_map,
-        ).to(device)
+        ).to(self.device)
 
-        model = load_checkpoint(model, ckpt_path, device, use_ema = True)
+        model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
 
         return model
 
     def load_model(self, model):
         if model == 'F5-TTS':
-            ret = self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+            ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
             return ret
         if model == 'E2-TTS':
-            return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+            return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
 
     def split_text(self, text):
         max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
         gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
         print('ref_text', ref_text)
 
-    def inference(self, prmpt, stream=False):
+    def inference(self, prompt, stream=False):
         generated_waves = []
         max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
         gen_text_batches = chunk_text(prompt, max_chars=max_chars)
-        for gen_text in gen_text_Batches:
+        for gen_text in gen_text_batches:
             # Prepare the text
             text_list = [self.ref_text + gen_text]
             final_text_list = convert_char_to_pinyin(text_list)
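The max_chars expression, duplicated in split_text and inference, is easier to read factored out: estimate the reference speech's bytes-per-second, then budget the remainder of a roughly 25-second window (the 25 comes straight from the code above):

def max_chars_budget(ref_text, ref_audio_samples, sample_rate):
    # seconds of reference speech
    ref_seconds = ref_audio_samples / sample_rate
    # UTF-8 bytes of text spoken per second in the reference
    bytes_per_second = len(ref_text.encode('utf-8')) / ref_seconds
    # characters allowed for the ~25 s window minus the reference itself
    return int(bytes_per_second * (25 - ref_seconds))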
@@ -173,6 +176,7 @@ class F5TTS:
             duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
             # inference
+            print(f'{self.device=}....')
             with torch.inference_mode():
                 generated, _ = self.model.sample(
                     cond=self.ref_audio,
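The duration heuristic in this hunk, spelled out: the generated segment is assumed to need frames in proportion to the reference's frames-per-text-unit, scaled by speed:

def estimate_duration(ref_audio_len, ref_text_len, gen_text_len, speed=1.0):
    # reference frames plus proportionally scaled frames for the new text
    return ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)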
@@ -194,6 +198,7 @@ class F5TTS:
             generated_wave = generated_wave.squeeze().cpu().numpy()
             generated_waves.append(generated_wave)
             if stream:
+                print(f'here ........{stream}')
                 return
         if self.cross_fade_duration <= 0:
             # Simply concatenate
@@ -201,7 +206,9 @@ class F5TTS:
         else:
             final_wave = self.cross_fade_wave(generated_waves)
         fn = self.write_wave(final_wave)
-        return fn
+        print(f'here ........{stream}, {fn=}')
+        yield fn
+        return
 
     def cross_fade_wave(self, waves):
         final_wave = generated_waves[0]
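cross_fade_wave's body sits mostly outside this diff, so the following is an illustrative linear cross-fade join of the kind it performs, not the project's implementation (fade length and numpy usage are assumptions):

import numpy as np

def cross_fade_join(waves, sample_rate, fade_seconds=0.15):
    out = waves[0]
    n = int(fade_seconds * sample_rate)
    for w in waves[1:]:
        k = min(n, len(out), len(w))
        if k <= 0:
            out = np.concatenate([out, w])  # nothing to blend
            continue
        fade = np.linspace(0.0, 1.0, k)
        blended = out[-k:] * (1.0 - fade) + w[:k] * fade
        out = np.concatenate([out[:-k], blended, w[k:]])
    return out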
@@ -249,20 +256,35 @@ class F5TTS:
             data = DictObject(**json.loads(msg))
             print(data)
             t1 = time()
-            f = self.inference(data.prompt)
+            for wav in self.inference(data.prompt, stream=data.stream):
+                if data.stream:
+                    d = {
+                        "reqid":data.reqid,
+                        "b64wave":b64str(wav),
+                        "finish":False
+                    }
+                    self.replier.send(json.dumps(d))
+                else:
                     t2 = time()
                     d = {
-                        "audio_file":f,
                         "reqid":data.reqid,
+                        "audio_file":wav,
                         "time_cost":t2 - t1
                     }
                     print(f'{d}')
                     return json.dumps(d)
+            t2 = time()
+            d = {
+                "reqid":data.reqid,
+                "time_cost":t2 - t1,
+                "finish":True
+            }
+            return json.dumps(d)
 
 
 if __name__ == '__main__':
     workdir = os.getcwd()
     config = getConfig(workdir)
+    print(f'{config=}')
     tts = F5TTS()
+    print('here')
     tts.run()
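The run loop now implements two reply shapes: streamed requests get one message per chunk ("b64wave", "finish": false) followed by a "finish": true summary, while non-streamed requests get a single reply carrying "audio_file". A hedged client-side sketch (it assumes b64str base64-encodes raw audio bytes, which this diff does not show):

import base64, json

def handle_reply(raw, on_chunk):
    d = json.loads(raw)
    if d.get("finish") is False and "b64wave" in d:
        on_chunk(base64.b64decode(d["b64wave"]))  # one streamed audio chunk
        return False  # keep reading
    return True  # final message: finish summary or audio_file reply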
BIN samples/test.wav (new file; binary not shown)
BIN samples/test_en_1_ref_short.wav (new file; binary not shown)
BIN samples/test_zh_1_ref_short.wav (new file; binary not shown)
zmq_client.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import json
+import os
+
+from appPublic.dictObject import DictObject
+from appPublic.zmq_reqrep import ZmqRequester
+from appPublic.jsonConfig import getConfig
+from appPublic.uniqueID import getID
+
+zmq_url = "tcp://127.0.0.1:9999"
+from time import time
+
+class ASRClient:
+    def __init__(self, zmq_url):
+        self.zmq_url = zmq_url
+        self.requester = ZmqRequester(self.zmq_url)
+
+    def generate(self, prompt):
+        d = {
+            "prompt":prompt,
+            "reqid":getID()
+        }
+        msg = json.dumps(d)
+        resp = self.requester.send(msg)
+        if resp != None:
+            ret = json.loads(resp)
+            print(f'response={ret}')
+        else:
+            print(f'response is None')
+
+    def run(self):
+        print(f'running {self.zmq_url}')
+        while True:
+            print('input audio_file:')
+            af = input()
+            if len(af) > 0:
+                self.generate(af)
+        print('ended ...')
+
+if __name__ == '__main__':
+    workdir = os.getcwd()
+    config = getConfig(workdir)
+    asr = ASRClient(config.zmq_url or zmq_url)
+    asr.run()
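A quick one-shot alternative to the interactive run() loop, assuming an f5tts.py server is listening on the default endpoint (hypothetical session):

from zmq_client import ASRClient

client = ASRClient('tcp://127.0.0.1:9999')
client.generate('对,这就是我,万人敬仰的太乙真人。')  # prints the server's JSON reply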