bugfix
This commit is contained in:
parent
d8de885311
commit
e80c8a3321
204
f5tts.py
204
f5tts.py
@ -1,20 +1,23 @@
|
||||
from time import time
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.append('./F5TTS')
|
||||
import argparse
|
||||
import codecs
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
import tomli
|
||||
from cached_path import cached_path
|
||||
from einops import rearrange
|
||||
|
||||
from vocos import Vocos
|
||||
from transformers import pipeline
|
||||
from F5_TTS.model import CFM, DiT, MMDiT, UNetT
|
||||
from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
|
||||
load_checkpoint, save_spectrogram)
|
||||
from model import DiT, UNetT
|
||||
from model.utils_infer import (
|
||||
load_vocoder,
|
||||
load_model,
|
||||
preprocess_ref_audio_text,
|
||||
infer_process,
|
||||
remove_silence_for_generated_wav,
|
||||
)
|
||||
|
||||
import os
|
||||
import json
|
||||
@ -31,32 +34,51 @@ ode_method = "euler"
|
||||
sway_sampling_coef = -1.0
|
||||
speed = 1.0
|
||||
|
||||
def chunk_text(text, max_chars=135):
|
||||
"""
|
||||
Splits the input text into chunks, each with a maximum number of characters.
|
||||
Args:
|
||||
text (str): The text to be split.
|
||||
max_chars (int): The maximum number of characters per chunk.
|
||||
Returns:
|
||||
List[str]: A list of text chunks.
|
||||
"""
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
# Split the text into sentences based on punctuation followed by whitespace
|
||||
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
|
||||
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
||||
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
||||
if "voices" not in config:
|
||||
voices = {"main": main_voice}
|
||||
else:
|
||||
voices = config["voices"]
|
||||
voices["main"] = main_voice
|
||||
for voice in voices:
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
||||
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
||||
)
|
||||
print("Voice:", voice)
|
||||
print("Ref_audio:", voices[voice]["ref_audio"])
|
||||
print("Ref_text:", voices[voice]["ref_text"])
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
||||
generated_audio_segments = []
|
||||
reg1 = r"(?=\[\w+\])"
|
||||
chunks = re.split(reg1, text_gen)
|
||||
reg2 = r"\[(\w+)\]"
|
||||
for text in chunks:
|
||||
match = re.match(reg2, text)
|
||||
if match:
|
||||
voice = match[1]
|
||||
else:
|
||||
print("No voice tag found, using main.")
|
||||
voice = "main"
|
||||
if voice not in voices:
|
||||
print(f"Voice {voice} not found, using main.")
|
||||
voice = "main"
|
||||
text = re.sub(reg2, "", text)
|
||||
gen_text = text.strip()
|
||||
ref_audio = voices[voice]["ref_audio"]
|
||||
ref_text = voices[voice]["ref_text"]
|
||||
print(f"Voice: {voice}")
|
||||
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
|
||||
generated_audio_segments.append(audio)
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
if generated_audio_segments:
|
||||
final_wave = np.concatenate(generated_audio_segments)
|
||||
with open(wave_path, "wb") as f:
|
||||
sf.write(f.name, final_wave, final_sample_rate)
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
print(f.name)
|
||||
|
||||
class F5TTS:
|
||||
def __init__(self):
|
||||
@ -65,6 +87,7 @@ class F5TTS:
|
||||
self.remove_silence = config.remove_silence
|
||||
self.modelname = config.modelname
|
||||
self.ref_audio_fn = config.ref_audio_fn
|
||||
self.load_vocoder_from_local = config.is_local or True
|
||||
self.zmq_url = config.zmq_url
|
||||
self.ref_text = config.ref_text
|
||||
self.device = config.device
|
||||
@ -72,18 +95,7 @@ class F5TTS:
|
||||
self.gen_ref_audio()
|
||||
self.gen_ref_text()
|
||||
self.replier = ZmqReplier(self.zmq_url, self.generate)
|
||||
try:
|
||||
print(f"Load vocos from local path {vocos_local_path}")
|
||||
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
|
||||
vocos.load_state_dict(state_dict)
|
||||
vocos.eval()
|
||||
self.vocos = vocos
|
||||
except:
|
||||
print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
|
||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
self.vocos = vocos
|
||||
|
||||
self.vocos = load_vocoder(is_local=is_local, local_path="../checkpoints/charactr/vocos-mel-24khz")
|
||||
self.F5TTS_model_cfg = dict(
|
||||
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||
)
|
||||
@ -123,93 +135,33 @@ class F5TTS:
|
||||
ref_text += ". "
|
||||
self.ref_text = ref_text
|
||||
|
||||
def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
||||
if not Path(ckpt_path).exists():
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
||||
model = CFM(
|
||||
transformer=model_cls(
|
||||
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
||||
),
|
||||
mel_spec_kwargs=dict(
|
||||
target_sample_rate=self.sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
hop_length=hop_length,
|
||||
),
|
||||
odeint_kwargs=dict(
|
||||
method="euler"
|
||||
),
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(self.device)
|
||||
def load_model(self):
|
||||
# load models
|
||||
if self.modelname == "F5-TTS":
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
if ckpt_file == "":
|
||||
repo_name = "F5-TTS"
|
||||
exp_name = "F5TTS_Base"
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
|
||||
model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
|
||||
elif self.modelname == "E2-TTS":
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
if ckpt_file == "":
|
||||
repo_name = "E2-TTS"
|
||||
exp_name = "E2TTS_Base"
|
||||
ckpt_step = 1200000
|
||||
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
|
||||
return model
|
||||
|
||||
def load_model(self, model):
|
||||
if model == 'F5-TTS':
|
||||
ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
|
||||
return ret
|
||||
if model == 'E2-TTS':
|
||||
return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
|
||||
self.model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
||||
|
||||
def split_text(self, text):
|
||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
print('ref_text', ref_text)
|
||||
|
||||
def inference(self, prompt, stream=False):
|
||||
generated_waves = []
|
||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
||||
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
|
||||
for gen_text in gen_text_batches:
|
||||
# Prepare the text
|
||||
text_list = [self.ref_text + gen_text]
|
||||
final_text_list = convert_char_to_pinyin(text_list)
|
||||
|
||||
# Calculate duration
|
||||
ref_audio_len = self.ref_audio.shape[-1] // hop_length
|
||||
zh_pause_punc = r"。,、;:?!"
|
||||
ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
|
||||
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
|
||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# inference
|
||||
print(f'{self.device=}....')
|
||||
with torch.inference_mode():
|
||||
generated, _ = self.model.sample(
|
||||
cond=self.ref_audio,
|
||||
text=final_text_list,
|
||||
duration=duration,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * rms / target_rms
|
||||
if stream:
|
||||
yield genreated
|
||||
else:
|
||||
# wav -> numpy
|
||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||
generated_waves.append(generated_wave)
|
||||
if stream:
|
||||
print(f'here ........{stream}')
|
||||
return
|
||||
if self.cross_fade_duration <= 0:
|
||||
# Simply concatenate
|
||||
final_wave = np.concatenate(generated_waves)
|
||||
else:
|
||||
final_wave = self.cross_fade_wave(generated_waves)
|
||||
fn = self.write_wave(final_wave)
|
||||
print(f'here ........{stream}, {fn=}')
|
||||
yield fn
|
||||
return
|
||||
|
||||
def cross_fade_wave(self, waves):
|
||||
final_wave = generated_waves[0]
|
||||
for i in range(1, len(generated_waves)):
|
||||
|
Loading…
Reference in New Issue
Block a user