# tts/app/tts.py
# 2024-08-06 15:11:07 +08:00
#
# 60 lines
# 1.8 KiB
# Python

import argparse
import os
import sys
from io import BytesIO

import ChatTTS
import torch
import torchaudio
from ahserver.configuredServer import ConfiguredServer
from ahserver.serverenv import ServerEnv
from appPublic.folderUtils import ProgramPath
from appPublic.jsonConfig import getConfig
from appPublic.log import MyLogger, info, debug, warning
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify
# Service version string, logged in the start-up banner.
__version__ = "0.0.1"
class TTS:
    """Text-to-speech service backed by a ChatTTS model.

    The model is loaded once at construction time. ``generate`` is an
    awaitable entry point that renders a text prompt into WAV-file bytes.
    """

    # Sample rate (Hz) written into the WAV header for ChatTTS output.
    SAMPLE_RATE = 24000

    def __init__(self, model_path=None):
        """Create the ChatTTS engine and load a custom checkpoint.

        Args:
            model_path: location of the custom ChatTTS model. When None, it is
                looked up with ``get_definition('chattts_model_path')``.
        """
        if model_path is None:
            # NOTE(review): get_definition is not imported in this file;
            # presumably it is provided by the ahserver/appPublic environment.
            # Confirm where it lives and import it explicitly.
            model_path = get_definition('chattts_model_path')
        self.engine = ChatTTS.Chat()
        # compile=True requests model compilation (recommended by ChatTTS
        # for better inference performance).
        self.engine.load(source='custom', custom_path=model_path, compile=True)

    def _generate(self, request, **kw):
        """Synthesize speech for a prompt and return it as WAV bytes.

        Args:
            request: passed through by the caller; not used by this method.
            **kw: must contain ``params_kw``, a mapping whose ``'prompt'``
                entry is the text to speak.

        Returns:
            bytes: a complete WAV file.

        Raises:
            ValueError: if no non-empty prompt text is supplied.
        """
        params_kw = kw.get('params_kw') or {}
        text = params_kw.get('prompt')
        if not text:
            raise ValueError('params_kw["prompt"] is required')
        # infer() takes a batch of texts and returns one waveform per text.
        wavs = self.engine.infer([text])
        audio = torch.from_numpy(wavs[0])
        # torchaudio.save expects (channels, frames). Depending on the ChatTTS
        # version the waveform may already be 2-D, so only add the channel
        # axis when it is missing.
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        buf = BytesIO()
        # format is mandatory when the destination is a file-like object.
        torchaudio.save(buf, audio, self.SAMPLE_RATE, format='wav')
        return buf.getvalue()

    # Awaitable wrapper around the blocking _generate (appPublic.worker.awaitify).
    # Presumably it runs the call off the event loop; confirm in appPublic.
    generate = awaitify(_generate)
if __name__ == '__main__':
    # Command-line entry point: load config, set up logging, register the
    # TTS "generate" function and start the configured HTTP server.
    parser = argparse.ArgumentParser(prog="tts")
    parser.add_argument('-w', '--workdir',
                        help='server working directory (default: current directory)')
    # type=int turns a malformed port into a clean argparse error.
    parser.add_argument('-p', '--port', type=int,
                        help='listen port (overrides website.port from config)')
    args = parser.parse_args()
    workdir = args.workdir or os.getcwd()
    p = ProgramPath()
    config = getConfig(workdir, NS={'workdir': workdir, 'ProgramPath': p})
    if config.logger:
        logger = MyLogger(config.logger.name or 'tts',
                          levelname=config.logger.levelname or 'debug',
                          logfile=config.logger.logfile or None)
    else:
        # Use the same default logger name as the configured branch above.
        logger = MyLogger('tts', levelname='debug')
    debug(f'{args=}')
    info(f'========tts version={__version__}========')
    server = ConfiguredServer(workdir=workdir)
    tts = TTS()
    rf = RegisterFunction()
    rf.register('generate', tts.generate)
    info(f'{rf.registKW=}')
    port = args.port
    if port is None:
        # Missing config sections are falsy (see config.logger above), so
        # guard before reading website.port.
        website = config.website
        port = (website.port if website else None) or 8080
    server.run(port=int(port))