This commit is contained in:
yumoqing 2024-10-22 19:13:23 +08:00
parent d8de885311
commit e80c8a3321
2 changed files with 79 additions and 127 deletions

202
f5tts.py
View File

@ -1,20 +1,23 @@
from time import time
import torch
from pathlib import Path
import sys
sys.path.append('./F5TTS')
import argparse
import codecs
import re
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import torchaudio
import tomli
from cached_path import cached_path
from einops import rearrange
from vocos import Vocos
from transformers import pipeline
from F5_TTS.model import CFM, DiT, MMDiT, UNetT
from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
load_checkpoint, save_spectrogram)
from model import DiT, UNetT
from model.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
)
import os
import json
@ -31,32 +34,51 @@ ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
def chunk_text(text, max_chars=135):
"""
Splits the input text into chunks, each with a maximum number of characters.
Args:
text (str): The text to be split.
max_chars (int): The maximum number of characters per chunk.
Returns:
List[str]: A list of text chunks.
"""
chunks = []
current_chunk = ""
# Split the text into sentences based on punctuation followed by whitespace
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
for sentence in sentences:
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
if "voices" not in config:
voices = {"main": main_voice}
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
voices = config["voices"]
voices["main"] = main_voice
for voice in voices:
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
voices[voice]["ref_audio"], voices[voice]["ref_text"]
)
print("Voice:", voice)
print("Ref_audio:", voices[voice]["ref_audio"])
print("Ref_text:", voices[voice]["ref_text"])
if current_chunk:
chunks.append(current_chunk.strip())
generated_audio_segments = []
reg1 = r"(?=\[\w+\])"
chunks = re.split(reg1, text_gen)
reg2 = r"\[(\w+)\]"
for text in chunks:
match = re.match(reg2, text)
if match:
voice = match[1]
else:
print("No voice tag found, using main.")
voice = "main"
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
text = re.sub(reg2, "", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
print(f"Voice: {voice}")
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
generated_audio_segments.append(audio)
return chunks
if generated_audio_segments:
final_wave = np.concatenate(generated_audio_segments)
with open(wave_path, "wb") as f:
sf.write(f.name, final_wave, final_sample_rate)
# Remove silence
if remove_silence:
remove_silence_for_generated_wav(f.name)
print(f.name)
class F5TTS:
def __init__(self):
@ -65,6 +87,7 @@ class F5TTS:
self.remove_silence = config.remove_silence
self.modelname = config.modelname
self.ref_audio_fn = config.ref_audio_fn
self.load_vocoder_from_local = config.is_local or True
self.zmq_url = config.zmq_url
self.ref_text = config.ref_text
self.device = config.device
@ -72,18 +95,7 @@ class F5TTS:
self.gen_ref_audio()
self.gen_ref_text()
self.replier = ZmqReplier(self.zmq_url, self.generate)
try:
print(f"Load vocos from local path {vocos_local_path}")
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
vocos.load_state_dict(state_dict)
vocos.eval()
self.vocos = vocos
except:
print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
self.vocos = vocos
self.vocos = load_vocoder(is_local=is_local, local_path="../checkpoints/charactr/vocos-mel-24khz")
self.F5TTS_model_cfg = dict(
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
@ -123,93 +135,33 @@ class F5TTS:
ref_text += ". "
self.ref_text = ref_text
def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
if not Path(ckpt_path).exists():
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
mel_spec_kwargs=dict(
target_sample_rate=self.sample_rate,
n_mel_channels=n_mel_channels,
hop_length=hop_length,
),
odeint_kwargs=dict(
method="euler"
),
vocab_char_map=vocab_char_map,
).to(self.device)
def load_model(self):
# load models
if self.modelname == "F5-TTS":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
if ckpt_file == "":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
elif self.modelname == "E2-TTS":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if ckpt_file == "":
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
return model
def load_model(self, model):
if model == 'F5-TTS':
ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
return ret
if model == 'E2-TTS':
return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
self.model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
def split_text(self, text):
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
print('ref_text', ref_text)
def inference(self, prompt, stream=False):
generated_waves = []
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
for gen_text in gen_text_batches:
# Prepare the text
text_list = [self.ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
# Calculate duration
ref_audio_len = self.ref_audio.shape[-1] // hop_length
zh_pause_punc = r"。,、;:?!"
ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
# inference
print(f'{self.device=}....')
with torch.inference_mode():
generated, _ = self.model.sample(
cond=self.ref_audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
if stream:
yield genreated
else:
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
if stream:
print(f'here ........{stream}')
return
if self.cross_fade_duration <= 0:
# Simply concatenate
final_wave = np.concatenate(generated_waves)
else:
final_wave = self.cross_fade_wave(generated_waves)
fn = self.write_wave(final_wave)
print(f'here ........{stream}, {fn=}')
yield fn
return
def cross_fade_wave(self, waves):
final_wave = generated_waves[0]
for i in range(1, len(generated_waves)):

2
run.sh
View File

@ -1,4 +1,4 @@
#!/bin/sh
r=$HOME/ve/f5/bin/python
r=$HOME/ve/f5tts/bin/python
$r $*