bugfix
This commit is contained in:
parent
ffaa06c8fc
commit
0bc047429f
@ -5,8 +5,8 @@
|
|||||||
"remove_silence":false,
|
"remove_silence":false,
|
||||||
"modelname":"F5-TTS",
|
"modelname":"F5-TTS",
|
||||||
"device":"cuda:0",
|
"device":"cuda:0",
|
||||||
"ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav",
|
"ref_audio_fn":"$[workdir]$/samples/ttt.wav",
|
||||||
"ref_text":"对,这就是我,万人敬仰的太乙真人。",
|
"ref_text":"快点吃饭,上课要迟到了。",
|
||||||
"cross_fade_duration":0
|
"cross_fade_duration":0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
77
f5tts.py
77
f5tts.py
@ -72,37 +72,39 @@ class F5TTS:
|
|||||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||||
|
|
||||||
self.model = load_model(model_cls, model_cfg, ckpt_file,
|
self.model = load_model(model_cls, model_cfg, ckpt_file,
|
||||||
self.config.vocab_file).to(self.config.device)
|
self.config.vocab_file)
|
||||||
|
self.model = self.model.to(self.config.device)
|
||||||
|
|
||||||
def generate(self, d):
|
def generate(self, d):
|
||||||
msg= d.decode('utf-8')
|
msg= d.decode('utf-8')
|
||||||
data = DictObject(**json.loads(msg))
|
data = DictObject(**json.loads(msg))
|
||||||
print(data)
|
print(data)
|
||||||
t1 = time()
|
t1 = time()
|
||||||
for wav in self.inference(data.prompt, stream=data.stream):
|
if data.stream:
|
||||||
if data.stream:
|
for wav in self.inference_stream(data.prompt, stream=data.stream):
|
||||||
d = {
|
d = {
|
||||||
"reqid":data.reqid,
|
"reqid":data.reqid,
|
||||||
"b64wave":b64str(wav),
|
"b64wave":b64str(wav),
|
||||||
"finish":False
|
"finish":False
|
||||||
}
|
}
|
||||||
self.replier.send(json.dumps(d))
|
self.replier.send(json.dumps(d))
|
||||||
else:
|
t2 = time()
|
||||||
t2 = time()
|
d = {
|
||||||
d = {
|
"reqid":data.reqid,
|
||||||
"reqid":data.reqid,
|
"time_cost":t2 - t1,
|
||||||
"audio_file":wav,
|
"finish":True
|
||||||
"time_cost":t2 - t1
|
}
|
||||||
}
|
return json.dumps(d)
|
||||||
print(f'{d}')
|
else:
|
||||||
return json.dumps(d)
|
audio_fn = self.inference(data.prompt)
|
||||||
t2 = time()
|
t2 = time()
|
||||||
d = {
|
d = {
|
||||||
"reqid":data.reqid,
|
"reqid":data.reqid,
|
||||||
"time_cost":t2 - t1,
|
"audio_file":audio_fn,
|
||||||
"finish":True
|
"time_cost":t2 - t1
|
||||||
}
|
}
|
||||||
return json.dumps(d)
|
print(f'{d}')
|
||||||
|
return json.dumps(d)
|
||||||
|
|
||||||
def setup_voice(self):
|
def setup_voice(self):
|
||||||
main_voice = {"ref_audio": self.config.ref_audio_fn,
|
main_voice = {"ref_audio": self.config.ref_audio_fn,
|
||||||
@ -121,8 +123,7 @@ class F5TTS:
|
|||||||
print("Ref_text:", voices[voice]["ref_text"])
|
print("Ref_text:", voices[voice]["ref_text"])
|
||||||
self.voices = voices
|
self.voices = voices
|
||||||
|
|
||||||
|
def inference_stream(self, prompt):
|
||||||
def inference(self, prompt):
|
|
||||||
text_gen = prompt
|
text_gen = prompt
|
||||||
remove_silence = False
|
remove_silence = False
|
||||||
generated_audio_segments = []
|
generated_audio_segments = []
|
||||||
@ -146,7 +147,25 @@ class F5TTS:
|
|||||||
print(f"Voice: {voice}, {self.model}")
|
print(f"Voice: {voice}, {self.model}")
|
||||||
audio, final_sample_rate, spectragram = \
|
audio, final_sample_rate, spectragram = \
|
||||||
infer_process(ref_audio, ref_text, gen_text, self.model)
|
infer_process(ref_audio, ref_text, gen_text, self.model)
|
||||||
generated_audio_segments.append(audio)
|
yield {
|
||||||
|
'audio':audio,
|
||||||
|
'sample_rate':final_sample_rate,
|
||||||
|
'spectragram':spectragram,
|
||||||
|
'finish':False
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
'finish':True
|
||||||
|
}
|
||||||
|
|
||||||
|
def inference(self, prompt):
|
||||||
|
generated_audio_segments = []
|
||||||
|
remove_silence = self.config.remove_silence or False
|
||||||
|
final_sample_rate = 24000
|
||||||
|
for d in self.inference_stream(prompt):
|
||||||
|
if not d['finish']:
|
||||||
|
audio = d['audio']
|
||||||
|
final_sample_rate = d['sample_rate']
|
||||||
|
generated_audio_segments.append(audio)
|
||||||
|
|
||||||
if generated_audio_segments:
|
if generated_audio_segments:
|
||||||
final_wave = np.concatenate(generated_audio_segments)
|
final_wave = np.concatenate(generated_audio_segments)
|
||||||
@ -158,13 +177,14 @@ class F5TTS:
|
|||||||
remove_silence_for_generated_wav(f.name)
|
remove_silence_for_generated_wav(f.name)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# workdir = os.getcwd()
|
workdir = os.getcwd()
|
||||||
# config = getConfig(workdir)
|
config = getConfig(workdir, {'workdir':workdir})
|
||||||
|
print(config.ref_audio_fn)
|
||||||
tts = F5TTS()
|
tts = F5TTS()
|
||||||
print('here')
|
print('here')
|
||||||
# tts.run()
|
tts.run()
|
||||||
|
"""
|
||||||
while True:
|
while True:
|
||||||
print('prompt:')
|
print('prompt:')
|
||||||
p = input()
|
p = input()
|
||||||
@ -173,7 +193,4 @@ if __name__ == '__main__':
|
|||||||
f = tts.inference(p)
|
f = tts.inference(p)
|
||||||
t2 = time()
|
t2 = time()
|
||||||
print(f'{f}, cost {t2-t1} seconds')
|
print(f'{f}, cost {t2-t1} seconds')
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
BIN
samples/test2.m4a
Normal file
BIN
samples/test2.m4a
Normal file
Binary file not shown.
BIN
samples/ttt.m4a
Normal file
BIN
samples/ttt.m4a
Normal file
Binary file not shown.
BIN
samples/ttt.wav
Normal file
BIN
samples/ttt.wav
Normal file
Binary file not shown.
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -9,7 +10,7 @@ from appPublic.uniqueID import getID
|
|||||||
zmq_url = "tcp://127.0.0.1:9999"
|
zmq_url = "tcp://127.0.0.1:9999"
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
class ASRClient:
|
class F5TTSClient:
|
||||||
def __init__(self, zmq_url):
|
def __init__(self, zmq_url):
|
||||||
self.zmq_url = zmq_url
|
self.zmq_url = zmq_url
|
||||||
self.requester = ZmqRequester(self.zmq_url)
|
self.requester = ZmqRequester(self.zmq_url)
|
||||||
|
Loading…
Reference in New Issue
Block a user