import sys
sys.path.append('./F5TTS')

import argparse
import codecs
import re
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path

from model import DiT, UNetT
from model.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
)

import os
import json
from time import time
from base64 import b64encode

from appPublic.dictObject import DictObject
from appPublic.zmq_reqrep import ZmqReplier
from appPublic.folderUtils import temp_file
from appPublic.jsonConfig import getConfig

n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0


def b64str(wav):
    # Encode a waveform chunk as base64 text for the JSON reply.
    # Assumption: the client expects raw float32 PCM bytes; adapt if a
    # different wire format (e.g. WAV-encoded bytes) is required.
    return b64encode(np.asarray(wav, dtype=np.float32).tobytes()).decode('utf-8')


class F5TTS:
    def __init__(self):
        self.config = getConfig()
        self.zmq_url = self.config.zmq_url
        self.replier = ZmqReplier(self.config.zmq_url, self.generate)
        # self.vocos = load_vocoder(is_local=True, local_path="../checkpoints/charactr/vocos-mel-24khz")
        self.load_model()
        self.setup_voice()

    def run(self):
        print(f'running {self.zmq_url}')
        self.replier._run()
        print('ended ...')

    def load_model(self):
        # load models
        ckpt_file = ''
        if self.config.modelname == "F5-TTS":
            model_cls = DiT
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            if ckpt_file == "":
                repo_name = "F5-TTS"
                exp_name = "F5TTS_Base"
                ckpt_step = 1200000
                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        elif self.config.modelname == "E2-TTS":
            model_cls = UNetT
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            if ckpt_file == "":
                repo_name = "E2-TTS"
                exp_name = "E2TTS_Base"
                ckpt_step = 1200000
                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        self.model = load_model(model_cls, model_cfg, ckpt_file, self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    def generate(self, d):
        # ZMQ handler: decode the JSON request and answer either with a
        # stream of base64 audio chunks or with a single audio file path.
        msg = d.decode('utf-8')
        data = DictObject(**json.loads(msg))
        print(data)
        t1 = time()
        if data.stream:
            for rec in self.inference_stream(data.prompt):
                if rec['finish']:
                    break
                d = {
                    "reqid": data.reqid,
                    "b64wave": b64str(rec['audio']),
                    "finish": False
                }
                self.replier.send(json.dumps(d))
            t2 = time()
            d = {
                "reqid": data.reqid,
                "time_cost": t2 - t1,
                "finish": True
            }
            return json.dumps(d)
        else:
            audio_fn = self.inference(data.prompt)
            t2 = time()
            d = {
                "reqid": data.reqid,
                "audio_file": audio_fn,
                "time_cost": t2 - t1
            }
            print(f'{d}')
            return json.dumps(d)

    def setup_voice(self):
        main_voice = {"ref_audio": self.config.ref_audio_fn, "ref_text": self.config.ref_text}
        if "voices" not in self.config:
            voices = {"main": main_voice}
        else:
            voices = self.config["voices"]
            voices["main"] = main_voice
        for voice in voices:
            voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
                voices[voice]["ref_audio"], voices[voice]["ref_text"]
            )
            print("Voice:", voice)
            print("Ref_audio:", voices[voice]["ref_audio"])
            print("Ref_text:", voices[voice]["ref_text"])
        self.voices = voices

    def inference_stream(self, prompt):
        # Split the prompt on [voice] tags and synthesize each chunk with the
        # matching reference voice, yielding one result dict per chunk.
        text_gen = prompt
        remove_silence = False
        generated_audio_segments = []
        reg1 = r"(?=\[\w+\])"
        chunks = re.split(reg1, text_gen)
        reg2 = r"\[(\w+)\]"
        for text in chunks:
            match = re.match(reg2, text)
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in self.voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            text = re.sub(reg2, "", text)
            gen_text = text.strip()
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            print(f"Voice: {voice}, {self.model}")
            audio, final_sample_rate, spectrogram = \
                infer_process(ref_audio, ref_text, gen_text, self.model)
            yield {
                'audio': audio,
                'sample_rate': final_sample_rate,
                'spectrogram': spectrogram,
                'finish': False
            }
        # Signal end of stream after all chunks have been synthesized.
        yield {
            'finish': True
        }

    def inference(self, prompt):
        # Non-streaming path: collect every chunk from inference_stream and
        # write the concatenated waveform to a temporary wav file.
        generated_audio_segments = []
        remove_silence = self.config.remove_silence or False
        final_sample_rate = 24000
        for d in self.inference_stream(prompt):
            if not d['finish']:
                audio = d['audio']
                final_sample_rate = d['sample_rate']
                generated_audio_segments.append(audio)
        if generated_audio_segments:
            final_wave = np.concatenate(generated_audio_segments)
            fn = temp_file(suffix='.wav')
            with open(fn, "wb") as f:
                sf.write(f.name, final_wave, final_sample_rate)
                # Remove silence
                if remove_silence:
                    remove_silence_for_generated_wav(f.name)
            return fn


if __name__ == '__main__':
    workdir = os.getcwd()
    config = getConfig(workdir, {'workdir': workdir})
    print(config.ref_audio_fn)
    tts = F5TTS()
    print('here')
    tts.run()
    """
    while True:
        print('prompt:')
        p = input()
        if p != '':
            t1 = time()
            f = tts.inference(p)
            t2 = time()
            print(f'{f}, cost {t2-t1} seconds')
    """
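
# A minimal client-side sketch of the JSON request/reply protocol that
# generate() handles.  It assumes the replier is reachable as a plain ZeroMQ
# REP endpoint at the configured zmq_url; the endpoint address below is
# hypothetical and the actual appPublic.zmq_reqrep transport may frame
# messages differently, so treat this as an illustration only.
"""
import json
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect('tcp://127.0.0.1:9999')    # hypothetical zmq_url value

req = {
    'reqid': '1',
    'prompt': 'Hello from the F5-TTS service.',
    'stream': False,
}
sock.send_string(json.dumps(req))
reply = json.loads(sock.recv_string())
print(reply['audio_file'], reply['time_cost'])
"""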