From e80c8a33213b57a382a324e062e7e0ee99f64ce3 Mon Sep 17 00:00:00 2001
From: yumoqing
Date: Tue, 22 Oct 2024 19:13:23 +0800
Subject: [PATCH] bugfix

---
 f5tts.py | 204 +++++++++++++++++++++----------------------------------
 run.sh   |   2 +-
 2 files changed, 79 insertions(+), 127 deletions(-)

diff --git a/f5tts.py b/f5tts.py
index 0c97051..32b8cf6 100644
--- a/f5tts.py
+++ b/f5tts.py
@@ -1,20 +1,23 @@
-from time import time
-import torch
-from pathlib import Path
+import sys
+sys.path.append('./F5TTS')
+import argparse
 import codecs
 import re
+from pathlib import Path
+
 import numpy as np
 import soundfile as sf
-import torch
-import torchaudio
+import tomli
 from cached_path import cached_path
 
-from einops import rearrange
-from vocos import Vocos
-from transformers import pipeline
-from F5_TTS.model import CFM, DiT, MMDiT, UNetT
-from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
-                                load_checkpoint, save_spectrogram)
+from model import DiT, UNetT
+from model.utils_infer import (
+    load_vocoder,
+    load_model,
+    preprocess_ref_audio_text,
+    infer_process,
+    remove_silence_for_generated_wav,
+)
 import os
 import json
 
@@ -31,32 +34,51 @@
 ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
-def chunk_text(text, max_chars=135):
-    """
-    Splits the input text into chunks, each with a maximum number of characters.
-    Args:
-        text (str): The text to be split.
-        max_chars (int): The maximum number of characters per chunk.
-    Returns:
-        List[str]: A list of text chunks.
-    """
-    chunks = []
-    current_chunk = ""
-    # Split the text into sentences based on punctuation followed by whitespace
-    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
+def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
+    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+    if "voices" not in config:
+        voices = {"main": main_voice}
+    else:
+        voices = config["voices"]
+        voices["main"] = main_voice
+    for voice in voices:
+        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
+            voices[voice]["ref_audio"], voices[voice]["ref_text"]
+        )
+        print("Voice:", voice)
+        print("Ref_audio:", voices[voice]["ref_audio"])
+        print("Ref_text:", voices[voice]["ref_text"])
 
-    for sentence in sentences:
-        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-        else:
-            if current_chunk:
-                chunks.append(current_chunk.strip())
-            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
+    generated_audio_segments = []
+    reg1 = r"(?=\[\w+\])"
+    chunks = re.split(reg1, text_gen)
+    reg2 = r"\[(\w+)\]"
+    for text in chunks:
+        match = re.match(reg2, text)
+        if match:
+            voice = match[1]
+        else:
+            print("No voice tag found, using main.")
+            voice = "main"
+        if voice not in voices:
+            print(f"Voice {voice} not found, using main.")
+            voice = "main"
+        text = re.sub(reg2, "", text)
+        gen_text = text.strip()
+        ref_audio = voices[voice]["ref_audio"]
+        ref_text = voices[voice]["ref_text"]
+        print(f"Voice: {voice}")
+        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
+        generated_audio_segments.append(audio)
 
-        if current_chunk:
-            chunks.append(current_chunk.strip())
-
-    return chunks
+    if generated_audio_segments:
+        final_wave = np.concatenate(generated_audio_segments)
+        with open(wave_path, "wb") as f:
+            sf.write(f.name, final_wave, final_sample_rate)
+            # Remove silence
+            if remove_silence:
+                remove_silence_for_generated_wav(f.name)
+            print(f.name)
 
 class F5TTS:
     def __init__(self):
@@ -65,6 +87,7 @@ class F5TTS:
         self.remove_silence = config.remove_silence
         self.modelname = config.modelname
         self.ref_audio_fn = config.ref_audio_fn
+        self.load_vocoder_from_local = config.is_local or True
         self.zmq_url = config.zmq_url
         self.ref_text = config.ref_text
         self.device = config.device
@@ -72,18 +95,7 @@
         self.gen_ref_audio()
         self.gen_ref_text()
         self.replier = ZmqReplier(self.zmq_url, self.generate)
-        try:
-            print(f"Load vocos from local path {vocos_local_path}")
-            vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-            state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
-            vocos.load_state_dict(state_dict)
-            vocos.eval()
-            self.vocos = vocos
-        except:
-            print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
-            vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-            self.vocos = vocos
-
+        self.vocos = load_vocoder(is_local=is_local, local_path="../checkpoints/charactr/vocos-mel-24khz")
         self.F5TTS_model_cfg = dict(
             dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
         )
@@ -123,93 +135,33 @@
             ref_text += ". "
         self.ref_text = ref_text
 
-    def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-        ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
-        if not Path(ckpt_path).exists():
-            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-        vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
-        model = CFM(
-            transformer=model_cls(
-                **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-            ),
-            mel_spec_kwargs=dict(
-                target_sample_rate=self.sample_rate,
-                n_mel_channels=n_mel_channels,
-                hop_length=hop_length,
-            ),
-            odeint_kwargs=dict(
-                method="euler"
-            ),
-            vocab_char_map=vocab_char_map,
-        ).to(self.device)
+    def load_model(self):
+        # load models
+        if self.modelname == "F5-TTS":
+            model_cls = DiT
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            if ckpt_file == "":
+                repo_name = "F5-TTS"
+                exp_name = "F5TTS_Base"
+                ckpt_step = 1200000
+                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-        model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
+        elif self.modelname == "E2-TTS":
+            model_cls = UNetT
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            if ckpt_file == "":
+                repo_name = "E2-TTS"
+                exp_name = "E2TTS_Base"
+                ckpt_step = 1200000
+                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-        return model
-
-    def load_model(self, model):
-        if model == 'F5-TTS':
-            ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
-            return ret
-        if model == 'E2-TTS':
-            return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
+        self.model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
 
     def split_text(self, text):
         max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
         gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
         print('ref_text', ref_text)
 
-    def inference(self, prompt, stream=False):
-        generated_waves = []
-        max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
-        gen_text_batches = chunk_text(prompt, max_chars=max_chars)
-        for gen_text in gen_text_batches:
-            # Prepare the text
-            text_list = [self.ref_text + gen_text]
-            final_text_list = convert_char_to_pinyin(text_list)
-
-            # Calculate duration
-            ref_audio_len = self.ref_audio.shape[-1] // hop_length
-            zh_pause_punc = r"。,、;:?!"
-            ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
-            gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-            # inference
-            print(f'{self.device=}....')
-            with torch.inference_mode():
-                generated, _ = self.model.sample(
-                    cond=self.ref_audio,
-                    text=final_text_list,
-                    duration=duration,
-                    steps=nfe_step,
-                    cfg_strength=cfg_strength,
-                    sway_sampling_coef=sway_sampling_coef,
-                )
-            generated = generated[:, ref_audio_len:, :]
-            generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
-            generated_wave = vocos.decode(generated_mel_spec.cpu())
-            if rms < target_rms:
-                generated_wave = generated_wave * rms / target_rms
-            if stream:
-                yield genreated
-            else:
-                # wav -> numpy
-                generated_wave = generated_wave.squeeze().cpu().numpy()
-                generated_waves.append(generated_wave)
-        if stream:
-            print(f'here ........{stream}')
-            return
-        if self.cross_fade_duration <= 0:
-            # Simply concatenate
-            final_wave = np.concatenate(generated_waves)
-        else:
-            final_wave = self.cross_fade_wave(generated_waves)
-        fn = self.write_wave(final_wave)
-        print(f'here ........{stream}, {fn=}')
-        yield fn
-        return
-
     def cross_fade_wave(self, waves):
         final_wave = generated_waves[0]
         for i in range(1, len(generated_waves)):
diff --git a/run.sh b/run.sh
index 418acfa..da0d3d4 100755
--- a/run.sh
+++ b/run.sh
@@ -1,4 +1,4 @@
 #!/bin/sh
 
-r=$HOME/ve/f5/bin/python
+r=$HOME/ve/f5tts/bin/python
 $r $*
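
Note on the new inference path: after this patch, f5tts.py defers model construction, vocoder loading, and audio post-processing to the helpers exported by F5-TTS's model.utils_infer (load_vocoder, load_model, preprocess_ref_audio_text, infer_process, remove_silence_for_generated_wav) instead of assembling CFM and Vocos by hand. main_process() additionally splits the generation text on [voice] tags before synthesis; for example, "[main] Hello there. [town] Hi!" is split by the lookahead r"(?=\[\w+\])" and each chunk is rendered with the tagged voice's reference audio and text.

The sketch below exercises that same helper API outside the ZMQ service, using only the call signatures that appear in the patch. It is a minimal illustration, not part of the commit: the checkpoint URL and model config match the values hard-coded in F5TTS.load_model(), while "ref.wav", its transcript, and "out.wav" are hypothetical placeholders, and the empty vocab_file argument is assumed to select the default vocabulary.

    # Standalone sketch of the utils_infer-based flow introduced by this patch.
    # Assumes the F5-TTS sources are checked out in ./F5TTS, as f5tts.py expects.
    import sys
    sys.path.append('./F5TTS')

    import soundfile as sf
    from cached_path import cached_path

    from model import DiT
    from model.utils_infer import (
        load_vocoder,
        load_model,
        preprocess_ref_audio_text,
        infer_process,
    )

    # Same F5TTS_Base configuration and checkpoint that F5TTS.load_model() hard-codes.
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))

    # Mirrors the load_vocoder call in __init__; note that infer_process() below is
    # invoked without the vocoder, exactly as main_process() does in the patch.
    vocoder = load_vocoder(is_local=False, local_path="../checkpoints/charactr/vocos-mel-24khz")
    tts_model = load_model(DiT, model_cfg, ckpt_file, "")  # empty vocab_file: default vocabulary (assumption)

    # Placeholder reference clip and transcript; replace with real ones.
    ref_audio, ref_text = preprocess_ref_audio_text("ref.wav", "This is the reference transcript.")
    wave, sample_rate, _ = infer_process(ref_audio, ref_text, "Hello from F5-TTS.", tts_model)
    sf.write("out.wav", wave, sample_rate)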