bugfix
This commit is contained in:
parent
ffaa06c8fc
commit
0bc047429f
@ -5,8 +5,8 @@
|
||||
"remove_silence":false,
|
||||
"modelname":"F5-TTS",
|
||||
"device":"cuda:0",
|
||||
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
|
||||
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
|
||||
"ref_audio_fn":"$[workdir]$/samples/ttt.wav",
|
||||
"ref_text":"快点吃饭,上课要迟到了。",
|
||||
"cross_fade_duration":0
|
||||
}
|
||||
|
||||
|
59
f5tts.py
59
f5tts.py
@ -72,30 +72,22 @@ class F5TTS:
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
|
||||
self.model = load_model(model_cls, model_cfg, ckpt_file,
|
||||
self.config.vocab_file).to(self.config.device)
|
||||
self.config.vocab_file)
|
||||
self.model = self.model.to(self.config.device)
|
||||
|
||||
def generate(self, d):
|
||||
msg= d.decode('utf-8')
|
||||
data = DictObject(**json.loads(msg))
|
||||
print(data)
|
||||
t1 = time()
|
||||
for wav in self.inference(data.prompt, stream=data.stream):
|
||||
if data.stream:
|
||||
for wav in self.inference_stream(data.prompt, stream=data.stream):
|
||||
d = {
|
||||
"reqid":data.reqid,
|
||||
"b64wave":b64str(wav),
|
||||
"finish":False
|
||||
}
|
||||
self.replier.send(json.dumps(d))
|
||||
else:
|
||||
t2 = time()
|
||||
d = {
|
||||
"reqid":data.reqid,
|
||||
"audio_file":wav,
|
||||
"time_cost":t2 - t1
|
||||
}
|
||||
print(f'{d}')
|
||||
return json.dumps(d)
|
||||
t2 = time()
|
||||
d = {
|
||||
"reqid":data.reqid,
|
||||
@ -103,6 +95,16 @@ class F5TTS:
|
||||
"finish":True
|
||||
}
|
||||
return json.dumps(d)
|
||||
else:
|
||||
audio_fn = self.inference(data.prompt)
|
||||
t2 = time()
|
||||
d = {
|
||||
"reqid":data.reqid,
|
||||
"audio_file":audio_fn,
|
||||
"time_cost":t2 - t1
|
||||
}
|
||||
print(f'{d}')
|
||||
return json.dumps(d)
|
||||
|
||||
def setup_voice(self):
|
||||
main_voice = {"ref_audio": self.config.ref_audio_fn,
|
||||
@ -121,8 +123,7 @@ class F5TTS:
|
||||
print("Ref_text:", voices[voice]["ref_text"])
|
||||
self.voices = voices
|
||||
|
||||
|
||||
def inference(self, prompt):
|
||||
def inference_stream(self, prompt):
|
||||
text_gen = prompt
|
||||
remove_silence = False
|
||||
generated_audio_segments = []
|
||||
@ -146,6 +147,24 @@ class F5TTS:
|
||||
print(f"Voice: {voice}, {self.model}")
|
||||
audio, final_sample_rate, spectragram = \
|
||||
infer_process(ref_audio, ref_text, gen_text, self.model)
|
||||
yield {
|
||||
'audio':audio,
|
||||
'sample_rate':final_sample_rate,
|
||||
'spectragram':spectragram,
|
||||
'finish':False
|
||||
}
|
||||
yield {
|
||||
'finish':True
|
||||
}
|
||||
|
||||
def inference(self, prompt):
|
||||
generated_audio_segments = []
|
||||
remove_silence = self.config.remove_silence or False
|
||||
final_sample_rate = 24000
|
||||
for d in self.inference_stream(prompt):
|
||||
if not d['finish']:
|
||||
audio = d['audio']
|
||||
final_sample_rate = d['sample_rate']
|
||||
generated_audio_segments.append(audio)
|
||||
|
||||
if generated_audio_segments:
|
||||
@ -158,13 +177,14 @@ class F5TTS:
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
return fn
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# workdir = os.getcwd()
|
||||
# config = getConfig(workdir)
|
||||
workdir = os.getcwd()
|
||||
config = getConfig(workdir, {'workdir':workdir})
|
||||
print(config.ref_audio_fn)
|
||||
tts = F5TTS()
|
||||
print('here')
|
||||
# tts.run()
|
||||
tts.run()
|
||||
"""
|
||||
while True:
|
||||
print('prompt:')
|
||||
p = input()
|
||||
@ -173,7 +193,4 @@ if __name__ == '__main__':
|
||||
f = tts.inference(p)
|
||||
t2 = time()
|
||||
print(f'{f}, cost {t2-t1} seconds')
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
BIN
samples/test2.m4a
Normal file
BIN
samples/test2.m4a
Normal file
Binary file not shown.
BIN
samples/ttt.m4a
Normal file
BIN
samples/ttt.m4a
Normal file
Binary file not shown.
BIN
samples/ttt.wav
Normal file
BIN
samples/ttt.wav
Normal file
Binary file not shown.
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
|
||||
@ -9,7 +10,7 @@ from appPublic.uniqueID import getID
|
||||
zmq_url = "tcp://127.0.0.1:9999"
|
||||
from time import time
|
||||
|
||||
class ASRClient:
|
||||
class F5TTSClient:
|
||||
def __init__(self, zmq_url):
|
||||
self.zmq_url = zmq_url
|
||||
self.requester = ZmqRequester(self.zmq_url)
|
||||
|
Loading…
Reference in New Issue
Block a user