yumoqing 2024-10-22 19:13:23 +08:00
parent d8de885311
commit e80c8a3321
2 changed files with 79 additions and 127 deletions

f5tts.py

@@ -1,20 +1,23 @@
-from time import time
-import torch
-from pathlib import Path
+import sys
+sys.path.append('./F5TTS')
+import argparse
 import codecs
 import re
+from pathlib import Path
 import numpy as np
 import soundfile as sf
-import torch
-import torchaudio
+import tomli
 from cached_path import cached_path
-from einops import rearrange
-from vocos import Vocos
-from transformers import pipeline
-from F5_TTS.model import CFM, DiT, MMDiT, UNetT
-from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
-                                load_checkpoint, save_spectrogram)
+from model import DiT, UNetT
+from model.utils_infer import (
+    load_vocoder,
+    load_model,
+    preprocess_ref_audio_text,
+    infer_process,
+    remove_silence_for_generated_wav,
+)
 import os
 import json
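
The new tomli import suggests the generation settings consulted in main_process below, including the optional voices table, are read from a TOML file. A minimal sketch of that pattern, with a hypothetical file name since the actual path is not part of this diff:

import tomli

# Hypothetical config path; the real location is not shown in this diff.
with open("config.toml", "rb") as f:  # tomli.load() requires a binary file handle
    config = tomli.load(f)

# main_process() falls back to a single "main" voice when no [voices] table exists.
voices = config.get("voices", {})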
@@ -31,32 +34,51 @@ ode_method = "euler"
 sway_sampling_coef = -1.0
 speed = 1.0
 
-def chunk_text(text, max_chars=135):
-    """
-    Splits the input text into chunks, each with a maximum number of characters.
-    Args:
-        text (str): The text to be split.
-        max_chars (int): The maximum number of characters per chunk.
-    Returns:
-        List[str]: A list of text chunks.
-    """
-    chunks = []
-    current_chunk = ""
-    # Split the text into sentences based on punctuation followed by whitespace
-    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
-    for sentence in sentences:
-        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
-            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-        else:
-            if current_chunk:
-                chunks.append(current_chunk.strip())
-            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
-    if current_chunk:
-        chunks.append(current_chunk.strip())
-    return chunks
+def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
+    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+    if "voices" not in config:
+        voices = {"main": main_voice}
+    else:
+        voices = config["voices"]
+        voices["main"] = main_voice
+    for voice in voices:
+        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
+            voices[voice]["ref_audio"], voices[voice]["ref_text"]
+        )
+        print("Voice:", voice)
+        print("Ref_audio:", voices[voice]["ref_audio"])
+        print("Ref_text:", voices[voice]["ref_text"])
+
+    generated_audio_segments = []
+    reg1 = r"(?=\[\w+\])"
+    chunks = re.split(reg1, text_gen)
+    reg2 = r"\[(\w+)\]"
+    for text in chunks:
+        match = re.match(reg2, text)
+        if match:
+            voice = match[1]
+        else:
+            print("No voice tag found, using main.")
+            voice = "main"
+        if voice not in voices:
+            print(f"Voice {voice} not found, using main.")
+            voice = "main"
+        text = re.sub(reg2, "", text)
+        gen_text = text.strip()
+        ref_audio = voices[voice]["ref_audio"]
+        ref_text = voices[voice]["ref_text"]
+        print(f"Voice: {voice}")
+        audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
+        generated_audio_segments.append(audio)
+
+    if generated_audio_segments:
+        final_wave = np.concatenate(generated_audio_segments)
+        with open(wave_path, "wb") as f:
+            sf.write(f.name, final_wave, final_sample_rate)
+            # Remove silence
+            if remove_silence:
+                remove_silence_for_generated_wav(f.name)
+            print(f.name)
 
 class F5TTS:
     def __init__(self):
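
main_process splits the generation text on [voice] tags with the two regexes above and falls back to the main voice when a tag is missing or unknown. A small self-contained illustration of that splitting, using a made-up prompt:

import re

reg1 = r"(?=\[\w+\])"   # zero-width lookahead: split right before each [voice] tag
reg2 = r"\[(\w+)\]"     # capture the voice name inside the brackets

# Hypothetical multi-voice prompt, not taken from the repo.
text_gen = "[main] Hello there. [town] Not much, just hanging out."
for chunk in re.split(reg1, text_gen):
    if not chunk.strip():
        continue                      # re.split() emits a leading empty chunk
    m = re.match(reg2, chunk)
    voice = m[1] if m else "main"     # same fallback as main_process()
    gen_text = re.sub(reg2, "", chunk).strip()
    print(voice, "->", gen_text)
# main -> Hello there.
# town -> Not much, just hanging out.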
@@ -65,6 +87,7 @@ class F5TTS:
         self.remove_silence = config.remove_silence
         self.modelname = config.modelname
         self.ref_audio_fn = config.ref_audio_fn
+        self.load_vocoder_from_local = config.is_local or True
         self.zmq_url = config.zmq_url
         self.ref_text = config.ref_text
         self.device = config.device
@@ -72,18 +95,7 @@ class F5TTS:
         self.gen_ref_audio()
         self.gen_ref_text()
         self.replier = ZmqReplier(self.zmq_url, self.generate)
-        try:
-            print(f"Load vocos from local path {vocos_local_path}")
-            vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
-            state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
-            vocos.load_state_dict(state_dict)
-            vocos.eval()
-            self.vocos = vocos
-        except:
-            print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
-            vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-            self.vocos = vocos
+        self.vocos = load_vocoder(is_local=is_local, local_path="../checkpoints/charactr/vocos-mel-24khz")
         self.F5TTS_model_cfg = dict(
             dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
         )
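
The hand-rolled Vocos loading, with its local/Hugging Face fallback, is replaced by a single load_vocoder call from model.utils_infer. For orientation, the vocoder is only consumed later to turn mel spectrograms into waveforms, as the removed inference code further down does with vocos.decode; a rough sketch of that step, assuming the charactr/vocos-mel-24khz checkpoint and its 100 mel bins:

import torch
from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
mel = torch.randn(1, 100, 256)        # placeholder mel; real mels come from model.sample()
with torch.inference_mode():
    wav = vocos.decode(mel)           # (1, samples) waveform at 24 kHz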
@@ -123,93 +135,33 @@ class F5TTS:
             ref_text += ". "
         self.ref_text = ref_text
 
-    def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-        ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
-        if not Path(ckpt_path).exists():
-            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-        vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
-        model = CFM(
-            transformer=model_cls(
-                **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-            ),
-            mel_spec_kwargs=dict(
-                target_sample_rate=self.sample_rate,
-                n_mel_channels=n_mel_channels,
-                hop_length=hop_length,
-            ),
-            odeint_kwargs=dict(
-                method="euler"
-            ),
-            vocab_char_map=vocab_char_map,
-        ).to(self.device)
-        model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
-        return model
-
-    def load_model(self, model):
-        if model == 'F5-TTS':
-            ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
-            return ret
-        if model == 'E2-TTS':
-            return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
+    def load_model(self):
+        # load models
+        if self.modelname == "F5-TTS":
+            model_cls = DiT
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            if ckpt_file == "":
+                repo_name = "F5-TTS"
+                exp_name = "F5TTS_Base"
+                ckpt_step = 1200000
+                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+        elif self.modelname == "E2-TTS":
+            model_cls = UNetT
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            if ckpt_file == "":
+                repo_name = "E2-TTS"
+                exp_name = "E2TTS_Base"
+                ckpt_step = 1200000
+                ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+        self.model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
 
     def split_text(self, text):
         max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
         gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
         print('ref_text', ref_text)
 
-    def inference(self, prompt, stream=False):
-        generated_waves = []
-        max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
-        gen_text_batches = chunk_text(prompt, max_chars=max_chars)
-        for gen_text in gen_text_batches:
-            # Prepare the text
-            text_list = [self.ref_text + gen_text]
-            final_text_list = convert_char_to_pinyin(text_list)
-            # Calculate duration
-            ref_audio_len = self.ref_audio.shape[-1] // hop_length
-            zh_pause_punc = r"。,、;:?!"
-            ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
-            gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-            # inference
-            print(f'{self.device=}....')
-            with torch.inference_mode():
-                generated, _ = self.model.sample(
-                    cond=self.ref_audio,
-                    text=final_text_list,
-                    duration=duration,
-                    steps=nfe_step,
-                    cfg_strength=cfg_strength,
-                    sway_sampling_coef=sway_sampling_coef,
-                )
-            generated = generated[:, ref_audio_len:, :]
-            generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
-            generated_wave = vocos.decode(generated_mel_spec.cpu())
-            if rms < target_rms:
-                generated_wave = generated_wave * rms / target_rms
-            if stream:
-                yield genreated
-            else:
-                # wav -> numpy
-                generated_wave = generated_wave.squeeze().cpu().numpy()
-                generated_waves.append(generated_wave)
-        if stream:
-            print(f'here ........{stream}')
-            return
-        if self.cross_fade_duration <= 0:
-            # Simply concatenate
-            final_wave = np.concatenate(generated_waves)
-        else:
-            final_wave = self.cross_fade_wave(generated_waves)
-        fn = self.write_wave(final_wave)
-        print(f'here ........{stream}, {fn=}')
-        yield fn
-        return
 
     def cross_fade_wave(self, waves):
         final_wave = generated_waves[0]
         for i in range(1, len(generated_waves)):
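
A note on the unchanged max_chars expression in split_text and the removed inference: it estimates the reference speaker's pace as UTF-8 bytes of ref_text per second of ref_audio and multiplies that by the seconds remaining in a roughly 25-second clip budget. A worked example with assumed numbers, not taken from the repo:

ref_text_bytes = 120      # assumed: len(ref_text.encode('utf-8'))
ref_audio_seconds = 6.0   # assumed: ref_audio.shape[-1] / sample_rate

bytes_per_second = ref_text_bytes / ref_audio_seconds          # 20.0
max_chars = int(bytes_per_second * (25 - ref_audio_seconds))   # int(20.0 * 19) = 380
print(max_chars)  # each chunk passed to chunk_text() is capped at ~380 UTF-8 bytes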

run.sh

@@ -1,4 +1,4 @@
 #!/bin/sh
-r=$HOME/ve/f5/bin/python
+r=$HOME/ve/f5tts/bin/python
 $r $*