This commit is contained in:
yumoqing 2024-10-23 18:38:56 +08:00
parent ffaa06c8fc
commit 0bc047429f
6 changed files with 51 additions and 33 deletions

View File

@ -5,8 +5,8 @@
"remove_silence":false, "remove_silence":false,
"modelname":"F5-TTS", "modelname":"F5-TTS",
"device":"cuda:0", "device":"cuda:0",
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav", "ref_audio_fn":"$[workdir]$/samples/ttt.wav",
"ref_text":"对,这就是我,万人敬仰的太乙真人。", "ref_text":"快点吃饭,上课要迟到了。",
"cross_fade_duration":0 "cross_fade_duration":0
} }

View File

@ -72,30 +72,22 @@ class F5TTS:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
self.model = load_model(model_cls, model_cfg, ckpt_file, self.model = load_model(model_cls, model_cfg, ckpt_file,
self.config.vocab_file).to(self.config.device) self.config.vocab_file)
self.model = self.model.to(self.config.device)
def generate(self, d): def generate(self, d):
msg= d.decode('utf-8') msg= d.decode('utf-8')
data = DictObject(**json.loads(msg)) data = DictObject(**json.loads(msg))
print(data) print(data)
t1 = time() t1 = time()
for wav in self.inference(data.prompt, stream=data.stream):
if data.stream: if data.stream:
for wav in self.inference_stream(data.prompt, stream=data.stream):
d = { d = {
"reqid":data.reqid, "reqid":data.reqid,
"b64wave":b64str(wav), "b64wave":b64str(wav),
"finish":False "finish":False
} }
self.replier.send(json.dumps(d)) self.replier.send(json.dumps(d))
else:
t2 = time()
d = {
"reqid":data.reqid,
"audio_file":wav,
"time_cost":t2 - t1
}
print(f'{d}')
return json.dumps(d)
t2 = time() t2 = time()
d = { d = {
"reqid":data.reqid, "reqid":data.reqid,
@ -103,6 +95,16 @@ class F5TTS:
"finish":True "finish":True
} }
return json.dumps(d) return json.dumps(d)
else:
audio_fn = self.inference(data.prompt)
t2 = time()
d = {
"reqid":data.reqid,
"audio_file":audio_fn,
"time_cost":t2 - t1
}
print(f'{d}')
return json.dumps(d)
def setup_voice(self): def setup_voice(self):
main_voice = {"ref_audio": self.config.ref_audio_fn, main_voice = {"ref_audio": self.config.ref_audio_fn,
@ -121,8 +123,7 @@ class F5TTS:
print("Ref_text:", voices[voice]["ref_text"]) print("Ref_text:", voices[voice]["ref_text"])
self.voices = voices self.voices = voices
def inference_stream(self, prompt):
def inference(self, prompt):
text_gen = prompt text_gen = prompt
remove_silence = False remove_silence = False
generated_audio_segments = [] generated_audio_segments = []
@ -146,6 +147,24 @@ class F5TTS:
print(f"Voice: {voice}, {self.model}") print(f"Voice: {voice}, {self.model}")
audio, final_sample_rate, spectragram = \ audio, final_sample_rate, spectragram = \
infer_process(ref_audio, ref_text, gen_text, self.model) infer_process(ref_audio, ref_text, gen_text, self.model)
yield {
'audio':audio,
'sample_rate':final_sample_rate,
'spectragram':spectragram,
'finish':False
}
yield {
'finish':True
}
def inference(self, prompt):
generated_audio_segments = []
remove_silence = self.config.remove_silence or False
final_sample_rate = 24000
for d in self.inference_stream(prompt):
if not d['finish']:
audio = d['audio']
final_sample_rate = d['sample_rate']
generated_audio_segments.append(audio) generated_audio_segments.append(audio)
if generated_audio_segments: if generated_audio_segments:
@ -158,13 +177,14 @@ class F5TTS:
remove_silence_for_generated_wav(f.name) remove_silence_for_generated_wav(f.name)
return fn return fn
if __name__ == '__main__': if __name__ == '__main__':
# workdir = os.getcwd() workdir = os.getcwd()
# config = getConfig(workdir) config = getConfig(workdir, {'workdir':workdir})
print(config.ref_audio_fn)
tts = F5TTS() tts = F5TTS()
print('here') print('here')
# tts.run() tts.run()
"""
while True: while True:
print('prompt:') print('prompt:')
p = input() p = input()
@ -173,7 +193,4 @@ if __name__ == '__main__':
f = tts.inference(p) f = tts.inference(p)
t2 = time() t2 = time()
print(f'{f}, cost {t2-t1} seconds') print(f'{f}, cost {t2-t1} seconds')
"""

BIN
samples/test2.m4a Normal file

Binary file not shown.

BIN
samples/ttt.m4a Normal file

Binary file not shown.

BIN
samples/ttt.wav Normal file

Binary file not shown.

View File

@ -1,3 +1,4 @@
import sys
import json import json
import os import os
@ -9,7 +10,7 @@ from appPublic.uniqueID import getID
zmq_url = "tcp://127.0.0.1:9999" zmq_url = "tcp://127.0.0.1:9999"
from time import time from time import time
class ASRClient: class F5TTSClient:
def __init__(self, zmq_url): def __init__(self, zmq_url):
self.zmq_url = zmq_url self.zmq_url = zmq_url
self.requester = ZmqRequester(self.zmq_url) self.requester = ZmqRequester(self.zmq_url)