This commit is contained in:
yumoqing 2024-10-23 18:38:56 +08:00
parent ffaa06c8fc
commit 0bc047429f
6 changed files with 51 additions and 33 deletions

View File

@ -5,8 +5,8 @@
"remove_silence":false,
"modelname":"F5-TTS",
"device":"cuda:0",
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
"ref_audio_fn":"$[workdir]$/samples/ttt.wav",
"ref_text":"快点吃饭,上课要迟到了。",
"cross_fade_duration":0
}

View File

@ -72,30 +72,22 @@ class F5TTS:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
self.model = load_model(model_cls, model_cfg, ckpt_file,
self.config.vocab_file).to(self.config.device)
self.config.vocab_file)
self.model = self.model.to(self.config.device)
def generate(self, d):
msg= d.decode('utf-8')
data = DictObject(**json.loads(msg))
print(data)
t1 = time()
for wav in self.inference(data.prompt, stream=data.stream):
if data.stream:
for wav in self.inference_stream(data.prompt, stream=data.stream):
d = {
"reqid":data.reqid,
"b64wave":b64str(wav),
"finish":False
}
self.replier.send(json.dumps(d))
else:
t2 = time()
d = {
"reqid":data.reqid,
"audio_file":wav,
"time_cost":t2 - t1
}
print(f'{d}')
return json.dumps(d)
t2 = time()
d = {
"reqid":data.reqid,
@ -103,6 +95,16 @@ class F5TTS:
"finish":True
}
return json.dumps(d)
else:
audio_fn = self.inference(data.prompt)
t2 = time()
d = {
"reqid":data.reqid,
"audio_file":audio_fn,
"time_cost":t2 - t1
}
print(f'{d}')
return json.dumps(d)
def setup_voice(self):
main_voice = {"ref_audio": self.config.ref_audio_fn,
@ -121,8 +123,7 @@ class F5TTS:
print("Ref_text:", voices[voice]["ref_text"])
self.voices = voices
def inference(self, prompt):
def inference_stream(self, prompt):
text_gen = prompt
remove_silence = False
generated_audio_segments = []
@ -146,6 +147,24 @@ class F5TTS:
print(f"Voice: {voice}, {self.model}")
audio, final_sample_rate, spectragram = \
infer_process(ref_audio, ref_text, gen_text, self.model)
yield {
'audio':audio,
'sample_rate':final_sample_rate,
'spectragram':spectragram,
'finish':False
}
yield {
'finish':True
}
def inference(self, prompt):
generated_audio_segments = []
remove_silence = self.config.remove_silence or False
final_sample_rate = 24000
for d in self.inference_stream(prompt):
if not d['finish']:
audio = d['audio']
final_sample_rate = d['sample_rate']
generated_audio_segments.append(audio)
if generated_audio_segments:
@ -158,13 +177,14 @@ class F5TTS:
remove_silence_for_generated_wav(f.name)
return fn
if __name__ == '__main__':
# workdir = os.getcwd()
# config = getConfig(workdir)
workdir = os.getcwd()
config = getConfig(workdir, {'workdir':workdir})
print(config.ref_audio_fn)
tts = F5TTS()
print('here')
# tts.run()
tts.run()
"""
while True:
print('prompt:')
p = input()
@ -173,7 +193,4 @@ if __name__ == '__main__':
f = tts.inference(p)
t2 = time()
print(f'{f}, cost {t2-t1} seconds')
"""

BIN
samples/test2.m4a Normal file

Binary file not shown.

BIN
samples/ttt.m4a Normal file

Binary file not shown.

BIN
samples/ttt.wav Normal file

Binary file not shown.

View File

@ -1,3 +1,4 @@
import sys
import json
import os
@ -9,7 +10,7 @@ from appPublic.uniqueID import getID
zmq_url = "tcp://127.0.0.1:9999"
from time import time
class ASRClient:
class F5TTSClient:
def __init__(self, zmq_url):
self.zmq_url = zmq_url
self.requester = ZmqRequester(self.zmq_url)