first commit

yumoqing 2024-10-21 18:29:57 +08:00
commit 60fab4e937
4 changed files with 288 additions and 0 deletions

conf/config.json (new file, +10 lines)

@@ -0,0 +1,10 @@
{
    "zmq_url": "tcp://127.0.0.1:10003",
    "sample_rate": 16000,
    "remove_silence": false,
    "modelname": "F5-TTS",
    "ref_audio_fn": "$[workdir]$/samples/test_zh_1_ref_short.wav",
    "ref_text": "对,这就是我,万人敬仰的太乙真人。",
    "cross_fade_duration": 0
}
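
The $[workdir]$ placeholder in ref_audio_fn is resolved against the working directory that f5tts.py passes to getConfig(). As a rough sketch of that resolution (a hypothetical stand-in for illustration; appPublic's getConfig performs this for the service itself):

import json
import os

# load conf/config.json and expand $[workdir]$ the way the service expects
with open('conf/config.json', encoding='utf-8') as f:
    cfg = json.load(f)
cfg['ref_audio_fn'] = cfg['ref_audio_fn'].replace('$[workdir]$', os.getcwd())
print(cfg['ref_audio_fn'])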

f5tts.py (new file, +272 lines)

@@ -0,0 +1,272 @@
import time
from pathlib import Path
import codecs
import re
import os
import json
import tempfile

import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from einops import rearrange
from vocos import Vocos
from transformers import pipeline
from F5_TTS.model import CFM, DiT, MMDiT, UNetT
from F5_TTS.model.utils import (convert_char_to_pinyin, get_tokenizer,
                                load_checkpoint, save_spectrogram)
from appPublic.dictObject import DictObject
from appPublic.zmq_reqrep import ZmqReplier
from appPublic.jsonConfig import getConfig

# run on GPU when available; the original file used `device` without defining it
device = "cuda" if torch.cuda.is_available() else "cpu"
# local Vocos checkpoint directory (assumed path -- adjust to your checkout);
# when it is missing, __init__() falls back to the HuggingFace download
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0

F5TTS_model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
def chunk_text(text, max_chars=135):
    """
    Splits the input text into chunks, each with a maximum number of characters.

    Args:
        text (str): The text to be split.
        max_chars (int): The maximum number of characters per chunk.

    Returns:
        List[str]: A list of text chunks.
    """
    chunks = []
    current_chunk = ""
    # Split the text into sentences based on punctuation followed by whitespace
    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)

    for sentence in sentences:
        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
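# Usage sketch (illustrative; the service does not call chunk_text directly --
# split_text() below wraps it with a computed byte budget):
#
#   for piece in chunk_text("对,这就是我。Nice to meet you! 我是太乙真人。", max_chars=30):
#       print(len(piece.encode('utf-8')), piece)
#
# Each piece ends on sentence punctuation and stays at or under 30 UTF-8
# bytes, except when a single sentence is itself longer than the budget.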
class F5TTS:
    def __init__(self):
        config = getConfig()
        self.zmq_url = config.zmq_url  # was never set; run() and the replier need it
        self.sample_rate = config.sample_rate
        self.remove_silence = config.remove_silence
        self.modelname = config.modelname
        self.ref_audio_fn = config.ref_audio_fn
        self.ref_text = config.ref_text
        self.model = self.load_model(self.modelname)
        self.cross_fade_duration = config.cross_fade_duration
        self.gen_ref_audio()
        self.gen_ref_text()
        self.replier = ZmqReplier(self.zmq_url, self.generate)
        try:
            print(f"Load vocos from local path {vocos_local_path}")
            vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
            state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
            vocos.load_state_dict(state_dict)
            vocos.eval()
            self.vocos = vocos
        except Exception:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    def gen_ref_audio(self):
        """
        Load the reference audio: downmix to mono, boost quiet audio up to
        target_rms, and resample to the configured sample rate.
        """
        audio, sr = torchaudio.load(self.ref_audio_fn)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            audio = resampler(audio)
        self.rms = rms  # kept so inference can undo the gain on generated audio
        self.ref_audio = audio.to(device)
    def run(self):
        print(f'running {self.zmq_url}')
        self.replier._run()
        print('ended ...')
    def gen_ref_text(self):
        """
        Ensure the reference text ends with ". " (or a Chinese full stop).
        """
        ref_text = self.ref_text
        if not ref_text.endswith(". ") and not ref_text.endswith("。"):
            if ref_text.endswith("."):
                ref_text += " "
            else:
                ref_text += ". "
        self.ref_text = ref_text
    def _load_model(self, repo_name, exp_name, model_cls, model_cfg, ckpt_step):
        ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors
        if not Path(ckpt_path).exists():
            ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
        model = CFM(
            transformer=model_cls(
                **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
            ),
            mel_spec_kwargs=dict(
                target_sample_rate=self.sample_rate,
                n_mel_channels=n_mel_channels,
                hop_length=hop_length,
            ),
            odeint_kwargs=dict(
                method=ode_method
            ),
            vocab_char_map=vocab_char_map,
        ).to(device)
        model = load_checkpoint(model, ckpt_path, device, use_ema=True)
        return model
    def load_model(self, model):
        if model == 'F5-TTS':
            return self._load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
        if model == 'E2-TTS':
            return self._load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
    def split_text(self, text):
        # Budget bytes per chunk so reference + generated audio fit the
        # model's ~25 s window, scaling by the reference's bytes-per-second.
        ref_secs = self.ref_audio.shape[-1] / self.sample_rate
        max_chars = int(len(self.ref_text.encode('utf-8')) / ref_secs * (25 - ref_secs))
        return chunk_text(text, max_chars=max_chars)
    def inference(self, prompt, stream=False):
        # stream=True: return a generator of waveform tensors, one per chunk;
        # stream=False: decode every chunk, join them and write a wav file.
        if stream:
            return self._batch_waves(prompt)
        generated_waves = []
        for generated_wave in self._batch_waves(prompt):
            # wave -> numpy
            generated_waves.append(generated_wave.squeeze().cpu().numpy())
        if self.cross_fade_duration <= 0:
            # Simply concatenate
            final_wave = np.concatenate(generated_waves)
        else:
            final_wave = self.cross_fade_wave(generated_waves)
        fn = self.write_wave(final_wave)
        return fn

    def _batch_waves(self, prompt):
        # Yield one decoded waveform per text chunk. Split out of inference()
        # because a single function cannot both yield chunks and return a
        # filename: the yield would turn all of inference() into a generator.
        gen_text_batches = self.split_text(prompt)
        for gen_text in gen_text_batches:
            # Prepare the text
            text_list = [self.ref_text + gen_text]
            final_text_list = convert_char_to_pinyin(text_list)

            # Calculate duration: pause punctuation counts 3 extra bytes, and
            # the generated length scales from the reference's bytes-per-frame
            ref_audio_len = self.ref_audio.shape[-1] // hop_length
            zh_pause_punc = r"。,、;:?!"
            ref_text_len = len(self.ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, self.ref_text))
            gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

            # inference
            with torch.inference_mode():
                generated, _ = self.model.sample(
                    cond=self.ref_audio,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )

            generated = generated[:, ref_audio_len:, :]
            generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
            generated_wave = self.vocos.decode(generated_mel_spec.cpu())
            if self.rms < target_rms:
                # undo the gain gen_ref_audio() applied to a quiet reference
                generated_wave = generated_wave * self.rms / target_rms
            yield generated_wave
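    # Worked example of the duration formula above (illustrative numbers):
    # a 5 s reference at 16 kHz is 80000 samples, so ref_audio_len =
    # 80000 // 256 = 312 mel frames. With ref_text_len = 45 bytes and
    # gen_text_len = 90 bytes at speed 1.0, duration = 312 +
    # int(312 / 45 * 90) = 312 + 624 = 936 frames: twice the text is
    # budgeted twice the reference's audio length.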
    def cross_fade_wave(self, waves):
        final_wave = waves[0]
        for i in range(1, len(waves)):
            prev_wave = final_wave
            next_wave = waves[i]

            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
            cross_fade_samples = int(self.cross_fade_duration * self.sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Overlapping parts
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            # Fade out and fade in
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            # Cross-faded overlap
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            # Combine
            final_wave = np.concatenate([
                prev_wave[:-cross_fade_samples],
                cross_faded_overlap,
                next_wave[cross_fade_samples:]
            ])
        return final_wave
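    # The cross-fade in numbers (illustrative): cross_fade_duration = 0.2 s at
    # 16 kHz overlaps 3200 samples; the previous chunk's tail is ramped 1 -> 0,
    # the next chunk's head 0 -> 1, and the two are summed, so the seam keeps
    # unit gain instead of clicking at the join.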
    def write_wave(self, wave):
        # the original called an undefined temp_file() helper; the stdlib
        # tempfile module is used here instead (assumed equivalent behavior)
        fd, fn = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        sf.write(fn, wave, self.sample_rate)
        return fn
    def generate(self, d):
        # ZmqReplier callback: the request is a JSON object with a "prompt"
        # field; the reply carries the generated wav path and the time spent
        msg = d.decode('utf-8')
        data = DictObject(**json.loads(msg))
        print(data)
        t1 = time.time()
        f = self.inference(data.prompt)
        t2 = time.time()
        d = {
            "audio_file": f,
            "time_cost": t2 - t1
        }
        print(f'{d}')
        return json.dumps(d)
if __name__ == '__main__':
    workdir = os.getcwd()
    config = getConfig(workdir)
    print(f'{config=}')
    tts = F5TTS()
    print('here')
    tts.run()
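
For a quick smoke test of the service above, a minimal client might look like the sketch below. It assumes ZmqReplier answers as a plain ZeroMQ REP socket and that pyzmq is installed; neither is part of this commit.

import json
import zmq

# connect to the zmq_url from conf/config.json
ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("tcp://127.0.0.1:10003")
# the service expects a JSON object with a "prompt" field ...
sock.send_string(json.dumps({"prompt": "你好,世界。"}))
# ... and replies with the wav path and the time cost
reply = json.loads(sock.recv_string())
print(reply["audio_file"], reply["time_cost"])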

requirements.txt (new file, +2 lines)

@@ -0,0 +1,2 @@
git+https://github.com/SWivid/F5-TTS.git
git+https://git.kaiyuancloud.cn/yumoqing/apppublic

run.sh (new executable file, +4 lines)

@@ -0,0 +1,4 @@
#!/bin/sh
# run the given script with the f5 virtualenv's python, e.g.: ./run.sh f5tts.py
r=$HOME/ve/f5/bin/python
$r "$@"