This commit is contained in:
yumoqing 2024-08-07 13:06:53 +08:00
parent 8beb2dca7a
commit 329d7d0545
2 changed files with 20 additions and 12 deletions

View File

@ -1,11 +1,14 @@
import os, sys
import torch
import torchaudio
import argparse
from appPublic.log import MyLogger, info, debug, warning
from appPublic.folderUtils import ProgramPath
from appPublic.folderUtils import ProgramPath, temp_file
from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
from appPublic.worker import coroutinify
from ahserver.configuredServer import ConfiguredServer
from ahserver.globalEnv import get_definition
from ahserver.serverenv import ServerEnv
from io import BytesIO
@ -15,20 +18,25 @@ __version__ = '0.0.1'
class TTS:
def __init__(self):
self.engine = ChatTTS.Chat()
custom_path=get_definition('chattts_model_path')
print(custom_path)
self.engine.load(source='custom',
custom_path=get_definition(custom_path='chattts_model_path',
custom_path=get_definition('chattts_model_path'),
compile=True)
def _generate(request, **kw):
def _generate(self, request, kw):
params_kw = kw.get('params_kw')
text = params_kw.get('prompt')
wavs = self.engine.refer(text)
f = BytesIO()
torchaudio.save(f, touch.from_numpy(wavs[0]).unsqueeze(0), 24000)
f.seek(0,0)
return f.read()
generate = awaitify(_generate)
wavs = self.engine.infer(text)
fn = temp_file(suffix='.wav')
torchaudio.save(fn, torch.from_numpy(wavs[0]).unsqueeze(0), 24000, format='wav')
with open(fn, 'rb') as f:
b = f.read()
os.remove(fn)
return b
os.remove(fn)
generate = coroutinify(_generate)
if __name__ == '__main__':
parser = argparse.ArgumentParser(prog="tts")

View File

@ -5,7 +5,7 @@
"logfile":"$[workdir]$/logs/tts.log"
},
"definitions":{
"chattts_model_path":"/d/ymq/osc/models/ChaTTTS"
"chattts_model_path":"/Users/ymq/models/ChaTTTS"
},
"filesroot":"$[workdir]$/files",
"website":{