diff --git a/conf/config.json b/conf/config.json index c59ec99..cad1a9e 100644 --- a/conf/config.json +++ b/conf/config.json @@ -3,6 +3,7 @@ "sample_rate":16000, "remove_silence":false, "modelname":"F5-TTS", + "device":"cuda:0", "ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav", "ref_text":"对,这就是我,万人敬仰的太乙真人。", "cross_fade_duration":0 diff --git a/f5tts.py b/f5tts.py index fb16a09..f961d45 100644 --- a/f5tts.py +++ b/f5tts.py @@ -1,4 +1,5 @@ -import time +from time import time +import torch from pathlib import Path import codecs import re @@ -30,11 +31,6 @@ ode_method = "euler" sway_sampling_coef = -1.0 speed = 1.0 -F5TTS_model_cfg = dict( - dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4 -) -E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - def chunk_text(text, max_chars=135): """ Splits the input text into chunks, each with a maximum number of characters. @@ -69,8 +65,9 @@ class F5TTS: self.remove_silence = config.remove_silence self.modelname = config.modelname self.ref_audio_fn = config.ref_audio_fn + self.zmq_url = config.zmq_url self.ref_text = config.ref_text - self.model= self.load_model(self.modelname) + self.device = config.device self.cross_fade_duration = config.cross_fade_duration self.gen_ref_audio() self.gen_ref_text() @@ -78,7 +75,7 @@ class F5TTS: try: print(f"Load vocos from local path {vocos_local_path}") vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) + state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device) vocos.load_state_dict(state_dict) vocos.eval() self.vocos = vocos @@ -87,19 +84,25 @@ class F5TTS: vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") self.vocos = vocos + self.F5TTS_model_cfg = dict( + dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4 + ) + self.E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + self.model= self.load_model(self.modelname) + def gen_ref_audio(self): """ gen ref_audio """ - audio, sr = torchaudio.load(ref_audio) + audio, sr = torchaudio.load(self.ref_audio_fn) if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms - if sr != target_sample_rate: - resampler = torchaudio.transforms.Resample(sr, target_sample_rate) + if sr != self.sample_rate: + resampler = torchaudio.transforms.Resample(sr, self.sample_rate) audio = resampler(audio) self.ref_audio = audio @@ -138,29 +141,29 @@ class F5TTS: method="euler" ), vocab_char_map=vocab_char_map, - ).to(device) + ).to(self.device) - model = load_checkpoint(model, ckpt_path, device, use_ema = True) + model = load_checkpoint(model, ckpt_path, self.device, use_ema = True) return model def load_model(self, model): if model == 'F5-TTS': - ret = self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000) + ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000) return ret if model == 'E2-TTS': - return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000) + return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000) def split_text(self, text): max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate)) gen_text_batches = chunk_text(gen_text, max_chars=max_chars) print('ref_text', ref_text) - def inference(self, prmpt, stream=False): + def inference(self, prompt, stream=False): generated_waves = [] max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate)) gen_text_batches = chunk_text(prompt, max_chars=max_chars) - for gen_text in gen_text_Batches: + for gen_text in gen_text_batches: # Prepare the text text_list = [self.ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) @@ -173,6 +176,7 @@ class F5TTS: duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) # inference + print(f'{self.device=}....') with torch.inference_mode(): generated, _ = self.model.sample( cond=self.ref_audio, @@ -194,6 +198,7 @@ class F5TTS: generated_wave = generated_wave.squeeze().cpu().numpy() generated_waves.append(generated_wave) if stream: + print(f'here ........{stream}') return if self.cross_fade_duration <= 0: # Simply concatenate @@ -201,7 +206,9 @@ class F5TTS: else: final_wave = self.cross_fade_wave(generated_waves) fn = self.write_wave(final_wave) - return fn + print(f'here ........{stream}, {fn=}') + yield fn + return def cross_fade_wave(self, waves): final_wave = generated_waves[0] @@ -249,20 +256,35 @@ class F5TTS: data = DictObject(**json.loads(msg)) print(data) t1 = time() - f = self.inference(data.prompt) + for wav in self.inference(data.prompt, stream=data.stream): + if data.stream: + d = { + "reqid":data.reqid, + "b64wave":b64str(wav), + "finish":False + } + self.replier.send(json.dumps(d)) + else: + t2 = time() + d = { + "reqid":data.reqid, + "audio_file":wav, + "time_cost":t2 - t1 + } + print(f'{d}') + return json.dumps(d) t2 = time() d = { - "audio_file":f, - "time_cost":t2 - t1 + "reqid":data.reqid, + "time_cost":t2 - t1, + "finish":True } - print(f'{d}') return json.dumps(d) if __name__ == '__main__': workdir = os.getcwd() config = getConfig(workdir) - print(f'{config=}') tts = F5TTS() print('here') tts.run() diff --git a/samples/test.wav b/samples/test.wav new file mode 100644 index 0000000..4027401 Binary files /dev/null and b/samples/test.wav differ diff --git a/samples/test_en_1_ref_short.wav b/samples/test_en_1_ref_short.wav new file mode 100644 index 0000000..3c593c3 Binary files /dev/null and b/samples/test_en_1_ref_short.wav differ diff --git a/samples/test_zh_1_ref_short.wav b/samples/test_zh_1_ref_short.wav new file mode 100644 index 0000000..8cc055e Binary files /dev/null and b/samples/test_zh_1_ref_short.wav differ diff --git a/zmq_client.py b/zmq_client.py new file mode 100644 index 0000000..ed0179d --- /dev/null +++ b/zmq_client.py @@ -0,0 +1,43 @@ +import json +import os + +from appPublic.dictObject import DictObject +from appPublic.zmq_reqrep import ZmqRequester +from appPublic.jsonConfig import getConfig +from appPublic.uniqueID import getID + +zmq_url = "tcp://127.0.0.1:9999" +from time import time + +class ASRClient: + def __init__(self, zmq_url): + self.zmq_url = zmq_url + self.requester = ZmqRequester(self.zmq_url) + + def generate(self, prompt): + d = { + "prompt":prompt, + "reqid":getID() + } + msg = json.dumps(d) + resp = self.requester.send(msg) + if resp != None: + ret = json.loads(resp) + print(f'response={ret}') + else: + print(f'response is None') + + def run(self): + print(f'running {self.zmq_url}') + while True: + print('input audio_file:') + af = input() + if len(af) > 0: + self.generate(af) + print('ended ...') + +if __name__ == '__main__': + workdir = os.getcwd() + config = getConfig(workdir) + asr = ASRClient(config.zmq_url or zmq_url) + asr.run()