# tts/app/tts.py
# 2024-08-07 13:06:53 +08:00
#
# 68 lines
# 2.0 KiB
# Python

import os, sys
import torch
import torchaudio
import argparse
from appPublic.log import MyLogger, info, debug, warning
from appPublic.folderUtils import ProgramPath, temp_file
from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import coroutinify
from ahserver.configuredServer import ConfiguredServer
from ahserver.globalEnv import get_definition
from ahserver.serverenv import ServerEnv
from io import BytesIO
import ChatTTS
# Version of this TTS service; written to the log in the startup banner.
__version__ = '0.0.1'
class TTS:
    """Text-to-speech service backed by a ChatTTS engine.

    The engine is created and its model loaded once, at construction time,
    from the location returned by ``get_definition('chattts_model_path')``.
    ``generate`` is ``_generate`` wrapped by ``coroutinify``.
    """

    # Sample rate (Hz) passed to torchaudio when encoding the WAV output.
    SAMPLE_RATE = 24000

    def __init__(self):
        """Create the ChatTTS engine and load the custom model."""
        self.engine = ChatTTS.Chat()
        # Look the path up once and reuse it for both the trace and the load.
        custom_path = get_definition('chattts_model_path')
        print(custom_path)
        self.engine.load(source='custom',
                         custom_path=custom_path,
                         compile=True)

    def _generate(self, request, kw):
        """Synthesize speech for the request's prompt and return WAV bytes.

        Args:
            request: The incoming request object; not used by this method.
            kw: Mapping that must contain ``'params_kw'``, itself a mapping
                whose ``'prompt'`` entry is the text to synthesize.
                NOTE(review): a missing ``'params_kw'`` raises
                ``AttributeError`` here; confirm callers always supply it.

        Returns:
            bytes: The content of the WAV file encoded from the first
            waveform returned by the engine.
        """
        params_kw = kw.get('params_kw')
        text = params_kw.get('prompt')
        wavs = self.engine.infer(text)
        fn = temp_file(suffix='.wav')
        try:
            # Only the first waveform is encoded; unsqueeze(0) adds the
            # leading dimension before saving (presumably the channel axis
            # torchaudio expects - confirm against the engine's output).
            waveform = torch.from_numpy(wavs[0]).unsqueeze(0)
            torchaudio.save(fn, waveform, self.SAMPLE_RATE, format='wav')
            with open(fn, 'rb') as f:
                return f.read()
        finally:
            # Always clean up the temp file, including when encoding or
            # reading fails; it may not exist if the save never got to
            # create it.
            if os.path.exists(fn):
                os.remove(fn)

    generate = coroutinify(_generate)
if __name__ == '__main__':
    # Script entry point: parse the command line, load the configuration
    # from the working directory, set up logging, register the TTS
    # 'generate' function and run the configured server.
    parser = argparse.ArgumentParser(prog="tts")
    parser.add_argument('-w', '--workdir',
                        help='working directory (default: current directory)')
    parser.add_argument('-p', '--port',
                        help='listening port (overrides the configured port)')
    args = parser.parse_args()
    print(args)

    workdir = args.workdir or os.getcwd()
    p = ProgramPath()
    config = getConfig(workdir, NS={'workdir': workdir, 'ProgramPath': p})

    if config.logger:
        logger = MyLogger(config.logger.name or 'tts',
                          levelname=config.logger.levelname or 'debug',
                          logfile=config.logger.logfile or None)
    else:
        # Use the same default name as the configured branch above.
        logger = MyLogger('tts', levelname='debug')
    info(f'========tts version={__version__}========')

    # NOTE(review): an `auth_klass=MyAuthAPI` variant was sketched here
    # earlier; no auth class is passed to the server at present.
    server = ConfiguredServer(workdir=workdir)
    tts = TTS()
    rf = RegisterFunction()
    rf.register('generate', tts.generate)
    info(f'{rf.registKW=}')

    # Port precedence: command line, then config.website.port, then 8080.
    # Guard the section lookup the same way config.logger is guarded, so a
    # configuration without a `website` section falls through to the default.
    website = config.website
    configured_port = website.port if website else None
    port = int(args.port or configured_port or 8080)
    server.run(port=port)