"""Text-to-speech web service wrapping F5-TTS / E2-TTS inference.

Prompts may contain ``[voice]`` tags; each tagged chunk is synthesized with
the reference audio/text registered for that voice (falling back to "main").
"""
import sys
import argparse
import codecs
import re
import os
import json
from pathlib import Path
from time import time

import numpy as np
import soundfile as sf
from cached_path import cached_path

from f5_tts.model import DiT, UNetT
from f5_tts.model.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
)

from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from ahserver.webapp import webapp
from ahserver.serverEnv import ServerEnv

n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0

# Zero-width split point in front of every "[voice]" tag, so each resulting
# chunk begins with its own tag (the first chunk may be untagged or empty).
_VOICE_SPLIT_RE = re.compile(r"(?=\[\w+\])")
# Matches a "[voice]" tag and captures the voice name.
_VOICE_TAG_RE = re.compile(r"\[(\w+)\]")

# config.modelname -> (model class, model kwargs, HF repo, experiment, step).
# The checkpoint is fetched from hf://SWivid/<repo>/<experiment>/model_<step>.safetensors.
_MODEL_SPECS = {
    "F5-TTS": (
        DiT,
        dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
        "F5-TTS",
        "F5TTS_Base",
        1200000,
    ),
    "E2-TTS": (
        UNetT,
        dict(dim=1024, depth=24, heads=16, ff_mult=4),
        "E2-TTS",
        "E2TTS_Base",
        1200000,
    ),
}


class F5TTS:
    """Load a TTS model and its voices from config and synthesize prompts.

    Reads from ``getConfig()``: ``modelname``, ``vocab_file``, ``device``,
    ``ref_audio_fn``, ``ref_text``, optional ``voices`` and ``remove_silence``.
    """

    def __init__(self):
        self.config = getConfig()
        self.load_model()
        self.setup_voice()

    def load_model(self):
        """Fetch (via ``cached_path``) and load the checkpoint for ``config.modelname``.

        Sets ``self.model`` to the loaded model moved to ``config.device``.

        Raises:
            ValueError: if ``config.modelname`` is not one of the supported models.
        """
        modelname = self.config.modelname
        try:
            model_cls, model_cfg, repo_name, exp_name, ckpt_step = _MODEL_SPECS[modelname]
        except KeyError:
            # Previously this fell through and failed later with an unbound local.
            raise ValueError(
                f"Unsupported modelname {modelname!r}; "
                f"expected one of {sorted(_MODEL_SPECS)}"
            ) from None
        ckpt_file = str(cached_path(
            f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
        ))
        # Pass a copy so the shared module-level spec can never be mutated.
        self.model = load_model(model_cls, dict(model_cfg), ckpt_file, self.config.vocab_file)
        self.model = self.model.to(self.config.device)

    def setup_voice(self):
        """Build ``self.voices``: voice name -> preprocessed ref_audio/ref_text.

        The "main" voice always comes from ``config.ref_audio_fn`` /
        ``config.ref_text``. Additional voices come from ``config["voices"]``
        when present; that mapping is updated in place.
        """
        main_voice = {"ref_audio": self.config.ref_audio_fn, "ref_text": self.config.ref_text}
        if "voices" not in self.config:
            voices = {"main": main_voice}
        else:
            voices = self.config["voices"]
            voices["main"] = main_voice
        for voice in voices:
            voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
                voices[voice]["ref_audio"], voices[voice]["ref_text"]
            )
            print("Voice:", voice)
            print("Ref_audio:", voices[voice]["ref_audio"])
            print("Ref_text:", voices[voice]["ref_text"])
        self.voices = voices

    def _inference_stream(self, prompt):
        """Synthesize ``prompt`` chunk by chunk, one chunk per ``[voice]`` tag.

        Yields:
            For each non-empty chunk, a dict with keys ``audio``,
            ``sample_rate``, ``spectragram`` (as returned by
            ``infer_process``) and ``finish=False``. The last item is always
            ``{'finish': True}``.
        """
        for chunk in _VOICE_SPLIT_RE.split(prompt):
            match = _VOICE_TAG_RE.match(chunk)
            gen_text = _VOICE_TAG_RE.sub("", chunk).strip()
            if not gen_text:
                # The split yields an empty leading chunk when the prompt
                # starts with a tag, and a tag may carry no text; there is
                # nothing to synthesize for either.
                continue
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in self.voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            ref_audio = self.voices[voice]["ref_audio"]
            ref_text = self.voices[voice]["ref_text"]
            print(f"Voice: {voice}, {self.model}")
            audio, final_sample_rate, spectragram = infer_process(
                ref_audio, ref_text, gen_text, self.model
            )
            yield {
                'audio': audio,
                'sample_rate': final_sample_rate,
                'spectragram': spectragram,
                'finish': False,
            }
        yield {'finish': True}

    def _inference(self, prompt):
        """Synthesize ``prompt`` into a single temporary WAV file.

        Concatenates the audio of every chunk from ``_inference_stream``. If
        ``config.remove_silence`` is truthy, silence is then stripped from the
        file via ``remove_silence_for_generated_wav``.

        Returns:
            Path of the WAV file (from ``temp_file``), or ``None`` when the
            prompt produced no audio.
        """
        remove_silence = self.config.remove_silence or False
        final_sample_rate = 24000  # fallback; overwritten by each chunk's rate
        segments = []
        for d in self._inference_stream(prompt):
            if d['finish']:
                continue
            segments.append(d['audio'])
            final_sample_rate = d['sample_rate']
        if not segments:
            return None
        final_wave = np.concatenate(segments)
        fn = temp_file(suffix='.wav')
        # Write straight to the path; opening it in "wb" first and letting
        # soundfile reopen the same path kept two handles on one file.
        sf.write(fn, final_wave, final_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(fn)
        return fn


def init():
    """Create the TTS engine and expose it on the shared ServerEnv.

    NOTE(review): ``infer_stream`` wraps a generator function with
    ``awaitify``; confirm awaitify handles generators as intended.
    """
    g = ServerEnv()
    f5 = F5TTS()
    g.infer_stream = awaitify(f5._inference_stream)
    g.infer = awaitify(f5._inference)


if __name__ == '__main__':
    webapp(init)