bugfix
This commit is contained in:
parent
60fab4e937
commit
16265edbf4
@ -3,6 +3,7 @@
|
|||||||
"sample_rate":16000,
|
"sample_rate":16000,
|
||||||
"remove_silence":false,
|
"remove_silence":false,
|
||||||
"modelname":"F5-TTS",
|
"modelname":"F5-TTS",
|
||||||
|
"device":"cuda:0",
|
||||||
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
|
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
|
||||||
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
|
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
|
||||||
"cross_fade_duration":0
|
"cross_fade_duration":0
|
||||||
|
64
f5tts.py
64
f5tts.py
@ -1,4 +1,5 @@
|
|||||||
import time
|
from time import time
|
||||||
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import codecs
|
import codecs
|
||||||
import re
|
import re
|
||||||
@ -30,11 +31,6 @@ ode_method = "euler"
|
|||||||
sway_sampling_coef = -1.0
|
sway_sampling_coef = -1.0
|
||||||
speed = 1.0
|
speed = 1.0
|
||||||
|
|
||||||
F5TTS_model_cfg = dict(
|
|
||||||
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
|
||||||
)
|
|
||||||
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
|
||||||
|
|
||||||
def chunk_text(text, max_chars=135):
|
def chunk_text(text, max_chars=135):
|
||||||
"""
|
"""
|
||||||
Splits the input text into chunks, each with a maximum number of characters.
|
Splits the input text into chunks, each with a maximum number of characters.
|
||||||
@ -69,8 +65,9 @@ class F5TTS:
|
|||||||
self.remove_silence = config.remove_silence
|
self.remove_silence = config.remove_silence
|
||||||
self.modelname = config.modelname
|
self.modelname = config.modelname
|
||||||
self.ref_audio_fn = config.ref_audio_fn
|
self.ref_audio_fn = config.ref_audio_fn
|
||||||
|
self.zmq_url = config.zmq_url
|
||||||
self.ref_text = config.ref_text
|
self.ref_text = config.ref_text
|
||||||
self.model= self.load_model(self.modelname)
|
self.device = config.device
|
||||||
self.cross_fade_duration = config.cross_fade_duration
|
self.cross_fade_duration = config.cross_fade_duration
|
||||||
self.gen_ref_audio()
|
self.gen_ref_audio()
|
||||||
self.gen_ref_text()
|
self.gen_ref_text()
|
||||||
@ -78,7 +75,7 @@ class F5TTS:
|
|||||||
try:
|
try:
|
||||||
print(f"Load vocos from local path {vocos_local_path}")
|
print(f"Load vocos from local path {vocos_local_path}")
|
||||||
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||||
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
|
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
|
||||||
vocos.load_state_dict(state_dict)
|
vocos.load_state_dict(state_dict)
|
||||||
vocos.eval()
|
vocos.eval()
|
||||||
self.vocos = vocos
|
self.vocos = vocos
|
||||||
@ -87,19 +84,25 @@ class F5TTS:
|
|||||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||||
self.vocos = vocos
|
self.vocos = vocos
|
||||||
|
|
||||||
|
self.F5TTS_model_cfg = dict(
|
||||||
|
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||||
|
)
|
||||||
|
self.E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||||
|
self.model= self.load_model(self.modelname)
|
||||||
|
|
||||||
|
|
||||||
def gen_ref_audio(self):
|
def gen_ref_audio(self):
|
||||||
"""
|
"""
|
||||||
gen ref_audio
|
gen ref_audio
|
||||||
"""
|
"""
|
||||||
audio, sr = torchaudio.load(ref_audio)
|
audio, sr = torchaudio.load(self.ref_audio_fn)
|
||||||
if audio.shape[0] > 1:
|
if audio.shape[0] > 1:
|
||||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||||
if rms < target_rms:
|
if rms < target_rms:
|
||||||
audio = audio * target_rms / rms
|
audio = audio * target_rms / rms
|
||||||
if sr != target_sample_rate:
|
if sr != self.sample_rate:
|
||||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
|
||||||
audio = resampler(audio)
|
audio = resampler(audio)
|
||||||
self.ref_audio = audio
|
self.ref_audio = audio
|
||||||
|
|
||||||
@ -138,29 +141,29 @@ class F5TTS:
|
|||||||
method="euler"
|
method="euler"
|
||||||
),
|
),
|
||||||
vocab_char_map=vocab_char_map,
|
vocab_char_map=vocab_char_map,
|
||||||
).to(device)
|
).to(self.device)
|
||||||
|
|
||||||
model = load_checkpoint(model, ckpt_path, device, use_ema = True)
|
model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_model(self, model):
|
def load_model(self, model):
|
||||||
if model == 'F5-TTS':
|
if model == 'F5-TTS':
|
||||||
ret = self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
|
||||||
return ret
|
return ret
|
||||||
if model == 'E2-TTS':
|
if model == 'E2-TTS':
|
||||||
return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
|
return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
|
||||||
|
|
||||||
def split_text(self, text):
|
def split_text(self, text):
|
||||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
||||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||||
print('ref_text', ref_text)
|
print('ref_text', ref_text)
|
||||||
|
|
||||||
def inference(self, prmpt, stream=False):
|
def inference(self, prompt, stream=False):
|
||||||
generated_waves = []
|
generated_waves = []
|
||||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
||||||
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
|
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
|
||||||
for gen_text in gen_text_Batches:
|
for gen_text in gen_text_batches:
|
||||||
# Prepare the text
|
# Prepare the text
|
||||||
text_list = [self.ref_text + gen_text]
|
text_list = [self.ref_text + gen_text]
|
||||||
final_text_list = convert_char_to_pinyin(text_list)
|
final_text_list = convert_char_to_pinyin(text_list)
|
||||||
@ -173,6 +176,7 @@ class F5TTS:
|
|||||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
||||||
|
|
||||||
# inference
|
# inference
|
||||||
|
print(f'{self.device=}....')
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
generated, _ = self.model.sample(
|
generated, _ = self.model.sample(
|
||||||
cond=self.ref_audio,
|
cond=self.ref_audio,
|
||||||
@ -194,6 +198,7 @@ class F5TTS:
|
|||||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||||
generated_waves.append(generated_wave)
|
generated_waves.append(generated_wave)
|
||||||
if stream:
|
if stream:
|
||||||
|
print(f'here ........{stream}')
|
||||||
return
|
return
|
||||||
if self.cross_fade_duration <= 0:
|
if self.cross_fade_duration <= 0:
|
||||||
# Simply concatenate
|
# Simply concatenate
|
||||||
@ -201,7 +206,9 @@ class F5TTS:
|
|||||||
else:
|
else:
|
||||||
final_wave = self.cross_fade_wave(generated_waves)
|
final_wave = self.cross_fade_wave(generated_waves)
|
||||||
fn = self.write_wave(final_wave)
|
fn = self.write_wave(final_wave)
|
||||||
return fn
|
print(f'here ........{stream}, {fn=}')
|
||||||
|
yield fn
|
||||||
|
return
|
||||||
|
|
||||||
def cross_fade_wave(self, waves):
|
def cross_fade_wave(self, waves):
|
||||||
final_wave = generated_waves[0]
|
final_wave = generated_waves[0]
|
||||||
@ -249,20 +256,35 @@ class F5TTS:
|
|||||||
data = DictObject(**json.loads(msg))
|
data = DictObject(**json.loads(msg))
|
||||||
print(data)
|
print(data)
|
||||||
t1 = time()
|
t1 = time()
|
||||||
f = self.inference(data.prompt)
|
for wav in self.inference(data.prompt, stream=data.stream):
|
||||||
|
if data.stream:
|
||||||
|
d = {
|
||||||
|
"reqid":data.reqid,
|
||||||
|
"b64wave":b64str(wav),
|
||||||
|
"finish":False
|
||||||
|
}
|
||||||
|
self.replier.send(json.dumps(d))
|
||||||
|
else:
|
||||||
t2 = time()
|
t2 = time()
|
||||||
d = {
|
d = {
|
||||||
"audio_file":f,
|
"reqid":data.reqid,
|
||||||
|
"audio_file":wav,
|
||||||
"time_cost":t2 - t1
|
"time_cost":t2 - t1
|
||||||
}
|
}
|
||||||
print(f'{d}')
|
print(f'{d}')
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
t2 = time()
|
||||||
|
d = {
|
||||||
|
"reqid":data.reqid,
|
||||||
|
"time_cost":t2 - t1,
|
||||||
|
"finish":True
|
||||||
|
}
|
||||||
|
return json.dumps(d)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
workdir = os.getcwd()
|
workdir = os.getcwd()
|
||||||
config = getConfig(workdir)
|
config = getConfig(workdir)
|
||||||
print(f'{config=}')
|
|
||||||
tts = F5TTS()
|
tts = F5TTS()
|
||||||
print('here')
|
print('here')
|
||||||
tts.run()
|
tts.run()
|
||||||
|
BIN
samples/test.wav
Normal file
BIN
samples/test.wav
Normal file
Binary file not shown.
BIN
samples/test_en_1_ref_short.wav
Normal file
BIN
samples/test_en_1_ref_short.wav
Normal file
Binary file not shown.
BIN
samples/test_zh_1_ref_short.wav
Normal file
BIN
samples/test_zh_1_ref_short.wav
Normal file
Binary file not shown.
43
zmq_client.py
Normal file
43
zmq_client.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
from appPublic.zmq_reqrep import ZmqRequester
|
||||||
|
from appPublic.jsonConfig import getConfig
|
||||||
|
from appPublic.uniqueID import getID
|
||||||
|
|
||||||
|
zmq_url = "tcp://127.0.0.1:9999"
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
class ASRClient:
|
||||||
|
def __init__(self, zmq_url):
|
||||||
|
self.zmq_url = zmq_url
|
||||||
|
self.requester = ZmqRequester(self.zmq_url)
|
||||||
|
|
||||||
|
def generate(self, prompt):
|
||||||
|
d = {
|
||||||
|
"prompt":prompt,
|
||||||
|
"reqid":getID()
|
||||||
|
}
|
||||||
|
msg = json.dumps(d)
|
||||||
|
resp = self.requester.send(msg)
|
||||||
|
if resp != None:
|
||||||
|
ret = json.loads(resp)
|
||||||
|
print(f'response={ret}')
|
||||||
|
else:
|
||||||
|
print(f'response is None')
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
print(f'running {self.zmq_url}')
|
||||||
|
while True:
|
||||||
|
print('input audio_file:')
|
||||||
|
af = input()
|
||||||
|
if len(af) > 0:
|
||||||
|
self.generate(af)
|
||||||
|
print('ended ...')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
workdir = os.getcwd()
|
||||||
|
config = getConfig(workdir)
|
||||||
|
asr = ASRClient(config.zmq_url or zmq_url)
|
||||||
|
asr.run()
|
Loading…
Reference in New Issue
Block a user