commit 60fab4e93772d0bce3cdfaa220b104cfdcc27582 Author: yumoqing Date: Mon Oct 21 18:29:57 2024 +0800 first commit diff --git a/conf/config.json b/conf/config.json new file mode 100644 index 0000000..c59ec99 --- /dev/null +++ b/conf/config.json @@ -0,0 +1,10 @@ +{ + "zmq_url" : "tcp://127.0.0.1:10003", + "sample_rate":16000, + "remove_silence":false, + "modelname":"F5-TTS", + "ref_audio_fn":"$[workdir]$/samples/test_zh_1_ref_short.wav", + "ref_text":"对,这就是我,万人敬仰的太乙真人。", + "cross_fade_duration":0 +} + diff --git a/f5tts.py b/f5tts.py new file mode 100644 index 0000000..fb16a09 --- /dev/null +++ b/f5tts.py @@ -0,0 +1,272 @@ +import time +from pathlib import Path +import codecs +import re +import numpy as np +import soundfile as sf +import torch +import torchaudio +from cached_path import cached_path +from einops import rearrange + +from vocos import Vocos +from transformers import pipeline +from F5_TTS.model import CFM, DiT, MMDiT, UNetT +from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer, + load_checkpoint, save_spectrogram) + +import os +import json +from appPublic.dictObject import DictObject +from appPublic.zmq_reqrep import ZmqReplier +from appPublic.jsonConfig import getConfig + +n_mel_channels = 100 +hop_length = 256 +target_rms = 0.1 +nfe_step = 32 # 16, 32 +cfg_strength = 2.0 +ode_method = "euler" +sway_sampling_coef = -1.0 +speed = 1.0 + +F5TTS_model_cfg = dict( + dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4 +) +E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + +def chunk_text(text, max_chars=135): + """ + Splits the input text into chunks, each with a maximum number of characters. + Args: + text (str): The text to be split. + max_chars (int): The maximum number of characters per chunk. + Returns: + List[str]: A list of text chunks. + """ + chunks = [] + current_chunk = "" + # Split the text into sentences based on punctuation followed by whitespace + sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text) + + for sentence in sentences: + if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars: + current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence + else: + if current_chunk: + chunks.append(current_chunk.strip()) + current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence + + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + +class F5TTS: + def __init__(self): + config = getConfig() + self.sample_rate = config.sample_rate + self.remove_silence = config.remove_silence + self.modelname = config.modelname + self.ref_audio_fn = config.ref_audio_fn + self.ref_text = config.ref_text + self.model= self.load_model(self.modelname) + self.cross_fade_duration = config.cross_fade_duration + self.gen_ref_audio() + self.gen_ref_text() + self.replier = ZmqReplier(self.zmq_url, self.generate) + try: + print(f"Load vocos from local path {vocos_local_path}") + vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") + state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) + vocos.load_state_dict(state_dict) + vocos.eval() + self.vocos = vocos + except: + print("Donwload Vocos from huggingface charactr/vocos-mel-24khz") + vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") + self.vocos = vocos + + + def gen_ref_audio(self): + """ + gen ref_audio + """ + audio, sr = torchaudio.load(ref_audio) + if audio.shape[0] > 1: + audio = torch.mean(audio, dim=0, keepdim=True) + rms = torch.sqrt(torch.mean(torch.square(audio))) + if rms < target_rms: + audio = audio * target_rms / rms + if sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(sr, target_sample_rate) + audio = resampler(audio) + self.ref_audio = audio + + def run(self): + print(f'running {self.zmq_url}') + self.replier._run() + print('ended ...') + + def gen_ref_text(self): + """ + """ + # Add the functionality to ensure it ends with ". " + ref_text = self.ref_text + if not ref_text.endswith(". ") and not ref_text.endswith("。"): + if ref_text.endswith("."): + ref_text += " " + else: + ref_text += ". " + self.ref_text = ref_text + + def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step): + ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors + if not Path(ckpt_path).exists(): + ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) + vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin") + model = CFM( + transformer=model_cls( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), + mel_spec_kwargs=dict( + target_sample_rate=self.sample_rate, + n_mel_channels=n_mel_channels, + hop_length=hop_length, + ), + odeint_kwargs=dict( + method="euler" + ), + vocab_char_map=vocab_char_map, + ).to(device) + + model = load_checkpoint(model, ckpt_path, device, use_ema = True) + + return model + + def load_model(self, model): + if model == 'F5-TTS': + ret = self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000) + return ret + if model == 'E2-TTS': + return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000) + + def split_text(self, text): + max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate)) + gen_text_batches = chunk_text(gen_text, max_chars=max_chars) + print('ref_text', ref_text) + + def inference(self, prmpt, stream=False): + generated_waves = [] + max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate)) + gen_text_batches = chunk_text(prompt, max_chars=max_chars) + for gen_text in gen_text_Batches: + # Prepare the text + text_list = [self.ref_text + gen_text] + final_text_list = convert_char_to_pinyin(text_list) + + # Calculate duration + ref_audio_len = self.ref_audio.shape[-1] // hop_length + zh_pause_punc = r"。,、;:?!" + ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text)) + gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text)) + duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) + + # inference + with torch.inference_mode(): + generated, _ = self.model.sample( + cond=self.ref_audio, + text=final_text_list, + duration=duration, + steps=nfe_step, + cfg_strength=cfg_strength, + sway_sampling_coef=sway_sampling_coef, + ) + generated = generated[:, ref_audio_len:, :] + generated_mel_spec = rearrange(generated, "1 n d -> 1 d n") + generated_wave = vocos.decode(generated_mel_spec.cpu()) + if rms < target_rms: + generated_wave = generated_wave * rms / target_rms + if stream: + yield genreated + else: + # wav -> numpy + generated_wave = generated_wave.squeeze().cpu().numpy() + generated_waves.append(generated_wave) + if stream: + return + if self.cross_fade_duration <= 0: + # Simply concatenate + final_wave = np.concatenate(generated_waves) + else: + final_wave = self.cross_fade_wave(generated_waves) + fn = self.write_wave(final_wave) + return fn + + def cross_fade_wave(self, waves): + final_wave = generated_waves[0] + for i in range(1, len(generated_waves)): + prev_wave = final_wave + next_wave = generated_waves[i] + + # Calculate cross-fade samples, ensuring it does not exceed wave lengths + cross_fade_samples = int(self.cross_fade_duration * self.sample_rate) + cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave)) + + if cross_fade_samples <= 0: + # No overlap possible, concatenate + final_wave = np.concatenate([prev_wave, next_wave]) + continue + + # Overlapping parts + prev_overlap = prev_wave[-cross_fade_samples:] + next_overlap = next_wave[:cross_fade_samples] + + # Fade out and fade in + fade_out = np.linspace(1, 0, cross_fade_samples) + fade_in = np.linspace(0, 1, cross_fade_samples) + + # Cross-faded overlap + cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in + + # Combine + new_wave = np.concatenate([ + prev_wave[:-cross_fade_samples], + cross_faded_overlap, + next_wave[cross_fade_samples:] + ]) + + final_wave = new_wave + return final_wave + + def write_wave(wave): + fn = temp_file(suffix='.wav') + sf.write(fn, wave, self.sample_rate) + return fn + + def generate(self, d): + msg= d.decode('utf-8') + data = DictObject(**json.loads(msg)) + print(data) + t1 = time() + f = self.inference(data.prompt) + t2 = time() + d = { + "audio_file":f, + "time_cost":t2 - t1 + } + print(f'{d}') + return json.dumps(d) + + +if __name__ == '__main__': + workdir = os.getcwd() + config = getConfig(workdir) + print(f'{config=}') + tts = F5TTS() + print('here') + tts.run() + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2409858 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +git+https://github.com/SWivid/F5-TTS.git +git+https://git.kaiyuancloud.cn/yumoqing/apppublic diff --git a/run.sh b/run.sh new file mode 100755 index 0000000..418acfa --- /dev/null +++ b/run.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +r=$HOME/ve/f5/bin/python +$r $*