bugfix
This commit is contained in:
parent
d8de885311
commit
e80c8a3321
202
f5tts.py
202
f5tts.py
@ -1,20 +1,23 @@
|
|||||||
from time import time
|
import sys
|
||||||
import torch
|
sys.path.append('./F5TTS')
|
||||||
from pathlib import Path
|
import argparse
|
||||||
import codecs
|
import codecs
|
||||||
import re
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
import torch
|
import tomli
|
||||||
import torchaudio
|
|
||||||
from cached_path import cached_path
|
from cached_path import cached_path
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
from vocos import Vocos
|
from model import DiT, UNetT
|
||||||
from transformers import pipeline
|
from model.utils_infer import (
|
||||||
from F5_TTS.model import CFM, DiT, MMDiT, UNetT
|
load_vocoder,
|
||||||
from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
|
load_model,
|
||||||
load_checkpoint, save_spectrogram)
|
preprocess_ref_audio_text,
|
||||||
|
infer_process,
|
||||||
|
remove_silence_for_generated_wav,
|
||||||
|
)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
@ -31,32 +34,51 @@ ode_method = "euler"
|
|||||||
sway_sampling_coef = -1.0
|
sway_sampling_coef = -1.0
|
||||||
speed = 1.0
|
speed = 1.0
|
||||||
|
|
||||||
def chunk_text(text, max_chars=135):
|
def main_process(ref_audio, ref_text, text_gen, model_obj, remove_silence):
|
||||||
"""
|
main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
|
||||||
Splits the input text into chunks, each with a maximum number of characters.
|
if "voices" not in config:
|
||||||
Args:
|
voices = {"main": main_voice}
|
||||||
text (str): The text to be split.
|
|
||||||
max_chars (int): The maximum number of characters per chunk.
|
|
||||||
Returns:
|
|
||||||
List[str]: A list of text chunks.
|
|
||||||
"""
|
|
||||||
chunks = []
|
|
||||||
current_chunk = ""
|
|
||||||
# Split the text into sentences based on punctuation followed by whitespace
|
|
||||||
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
|
|
||||||
|
|
||||||
for sentence in sentences:
|
|
||||||
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
|
||||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
|
||||||
else:
|
else:
|
||||||
if current_chunk:
|
voices = config["voices"]
|
||||||
chunks.append(current_chunk.strip())
|
voices["main"] = main_voice
|
||||||
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
for voice in voices:
|
||||||
|
voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
|
||||||
|
voices[voice]["ref_audio"], voices[voice]["ref_text"]
|
||||||
|
)
|
||||||
|
print("Voice:", voice)
|
||||||
|
print("Ref_audio:", voices[voice]["ref_audio"])
|
||||||
|
print("Ref_text:", voices[voice]["ref_text"])
|
||||||
|
|
||||||
if current_chunk:
|
generated_audio_segments = []
|
||||||
chunks.append(current_chunk.strip())
|
reg1 = r"(?=\[\w+\])"
|
||||||
|
chunks = re.split(reg1, text_gen)
|
||||||
|
reg2 = r"\[(\w+)\]"
|
||||||
|
for text in chunks:
|
||||||
|
match = re.match(reg2, text)
|
||||||
|
if match:
|
||||||
|
voice = match[1]
|
||||||
|
else:
|
||||||
|
print("No voice tag found, using main.")
|
||||||
|
voice = "main"
|
||||||
|
if voice not in voices:
|
||||||
|
print(f"Voice {voice} not found, using main.")
|
||||||
|
voice = "main"
|
||||||
|
text = re.sub(reg2, "", text)
|
||||||
|
gen_text = text.strip()
|
||||||
|
ref_audio = voices[voice]["ref_audio"]
|
||||||
|
ref_text = voices[voice]["ref_text"]
|
||||||
|
print(f"Voice: {voice}")
|
||||||
|
audio, final_sample_rate, spectragram = infer_process(ref_audio, ref_text, gen_text, model_obj)
|
||||||
|
generated_audio_segments.append(audio)
|
||||||
|
|
||||||
return chunks
|
if generated_audio_segments:
|
||||||
|
final_wave = np.concatenate(generated_audio_segments)
|
||||||
|
with open(wave_path, "wb") as f:
|
||||||
|
sf.write(f.name, final_wave, final_sample_rate)
|
||||||
|
# Remove silence
|
||||||
|
if remove_silence:
|
||||||
|
remove_silence_for_generated_wav(f.name)
|
||||||
|
print(f.name)
|
||||||
|
|
||||||
class F5TTS:
|
class F5TTS:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -65,6 +87,7 @@ class F5TTS:
|
|||||||
self.remove_silence = config.remove_silence
|
self.remove_silence = config.remove_silence
|
||||||
self.modelname = config.modelname
|
self.modelname = config.modelname
|
||||||
self.ref_audio_fn = config.ref_audio_fn
|
self.ref_audio_fn = config.ref_audio_fn
|
||||||
|
self.load_vocoder_from_local = config.is_local or True
|
||||||
self.zmq_url = config.zmq_url
|
self.zmq_url = config.zmq_url
|
||||||
self.ref_text = config.ref_text
|
self.ref_text = config.ref_text
|
||||||
self.device = config.device
|
self.device = config.device
|
||||||
@ -72,18 +95,7 @@ class F5TTS:
|
|||||||
self.gen_ref_audio()
|
self.gen_ref_audio()
|
||||||
self.gen_ref_text()
|
self.gen_ref_text()
|
||||||
self.replier = ZmqReplier(self.zmq_url, self.generate)
|
self.replier = ZmqReplier(self.zmq_url, self.generate)
|
||||||
try:
|
self.vocos = load_vocoder(is_local=is_local, local_path="../checkpoints/charactr/vocos-mel-24khz")
|
||||||
print(f"Load vocos from local path {vocos_local_path}")
|
|
||||||
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
|
||||||
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=self.device)
|
|
||||||
vocos.load_state_dict(state_dict)
|
|
||||||
vocos.eval()
|
|
||||||
self.vocos = vocos
|
|
||||||
except:
|
|
||||||
print("Donwload Vocos from huggingface charactr/vocos-mel-24khz")
|
|
||||||
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
|
||||||
self.vocos = vocos
|
|
||||||
|
|
||||||
self.F5TTS_model_cfg = dict(
|
self.F5TTS_model_cfg = dict(
|
||||||
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||||
)
|
)
|
||||||
@ -123,93 +135,33 @@ class F5TTS:
|
|||||||
ref_text += ". "
|
ref_text += ". "
|
||||||
self.ref_text = ref_text
|
self.ref_text = ref_text
|
||||||
|
|
||||||
def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
|
def load_model(self):
|
||||||
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
|
# load models
|
||||||
if not Path(ckpt_path).exists():
|
if self.modelname == "F5-TTS":
|
||||||
ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
model_cls = DiT
|
||||||
vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
|
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||||
model = CFM(
|
if ckpt_file == "":
|
||||||
transformer=model_cls(
|
repo_name = "F5-TTS"
|
||||||
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
exp_name = "F5TTS_Base"
|
||||||
),
|
ckpt_step = 1200000
|
||||||
mel_spec_kwargs=dict(
|
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||||
target_sample_rate=self.sample_rate,
|
|
||||||
n_mel_channels=n_mel_channels,
|
|
||||||
hop_length=hop_length,
|
|
||||||
),
|
|
||||||
odeint_kwargs=dict(
|
|
||||||
method="euler"
|
|
||||||
),
|
|
||||||
vocab_char_map=vocab_char_map,
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
model = load_checkpoint(model, ckpt_path, self.device, use_ema = True)
|
elif self.modelname == "E2-TTS":
|
||||||
|
model_cls = UNetT
|
||||||
|
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||||
|
if ckpt_file == "":
|
||||||
|
repo_name = "E2-TTS"
|
||||||
|
exp_name = "E2TTS_Base"
|
||||||
|
ckpt_step = 1200000
|
||||||
|
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||||
|
|
||||||
return model
|
self.model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
|
||||||
|
|
||||||
def load_model(self, model):
|
|
||||||
if model == 'F5-TTS':
|
|
||||||
ret = self._load_model(model, "F5TTS_Base", DiT, self.F5TTS_model_cfg, 1200000)
|
|
||||||
return ret
|
|
||||||
if model == 'E2-TTS':
|
|
||||||
return self._load_model(model, "E2TTS_Base", UNetT, self.E2TTS_model_cfg, 1200000)
|
|
||||||
|
|
||||||
def split_text(self, text):
|
def split_text(self, text):
|
||||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
||||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||||
print('ref_text', ref_text)
|
print('ref_text', ref_text)
|
||||||
|
|
||||||
def inference(self, prompt, stream=False):
|
|
||||||
generated_waves = []
|
|
||||||
max_chars = int(len(self.ref_text.encode('utf-8')) / (self.ref_audio.shape[-1] / self.sample_rate) * (25 - self.ref_audio.shape[-1] / self.sample_rate))
|
|
||||||
gen_text_batches = chunk_text(prompt, max_chars=max_chars)
|
|
||||||
for gen_text in gen_text_batches:
|
|
||||||
# Prepare the text
|
|
||||||
text_list = [self.ref_text + gen_text]
|
|
||||||
final_text_list = convert_char_to_pinyin(text_list)
|
|
||||||
|
|
||||||
# Calculate duration
|
|
||||||
ref_audio_len = self.ref_audio.shape[-1] // hop_length
|
|
||||||
zh_pause_punc = r"。,、;:?!"
|
|
||||||
ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
|
|
||||||
gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
|
|
||||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
|
||||||
|
|
||||||
# inference
|
|
||||||
print(f'{self.device=}....')
|
|
||||||
with torch.inference_mode():
|
|
||||||
generated, _ = self.model.sample(
|
|
||||||
cond=self.ref_audio,
|
|
||||||
text=final_text_list,
|
|
||||||
duration=duration,
|
|
||||||
steps=nfe_step,
|
|
||||||
cfg_strength=cfg_strength,
|
|
||||||
sway_sampling_coef=sway_sampling_coef,
|
|
||||||
)
|
|
||||||
generated = generated[:, ref_audio_len:, :]
|
|
||||||
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
|
||||||
generated_wave = vocos.decode(generated_mel_spec.cpu())
|
|
||||||
if rms < target_rms:
|
|
||||||
generated_wave = generated_wave * rms / target_rms
|
|
||||||
if stream:
|
|
||||||
yield genreated
|
|
||||||
else:
|
|
||||||
# wav -> numpy
|
|
||||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
|
||||||
generated_waves.append(generated_wave)
|
|
||||||
if stream:
|
|
||||||
print(f'here ........{stream}')
|
|
||||||
return
|
|
||||||
if self.cross_fade_duration <= 0:
|
|
||||||
# Simply concatenate
|
|
||||||
final_wave = np.concatenate(generated_waves)
|
|
||||||
else:
|
|
||||||
final_wave = self.cross_fade_wave(generated_waves)
|
|
||||||
fn = self.write_wave(final_wave)
|
|
||||||
print(f'here ........{stream}, {fn=}')
|
|
||||||
yield fn
|
|
||||||
return
|
|
||||||
|
|
||||||
def cross_fade_wave(self, waves):
|
def cross_fade_wave(self, waves):
|
||||||
final_wave = generated_waves[0]
|
final_wave = generated_waves[0]
|
||||||
for i in range(1, len(generated_waves)):
|
for i in range(1, len(generated_waves)):
|
||||||
|
Loading…
Reference in New Issue
Block a user