# File listing metadata: 68 lines, 2.0 KiB, Python
import os, sys
|
|
import torch
|
|
import torchaudio
|
|
import argparse
|
|
from appPublic.log import MyLogger, info, debug, warning
|
|
from appPublic.folderUtils import ProgramPath, temp_file
|
|
from appPublic.jsonConfig import getConfig
|
|
from appPublic.registerfunction import RegisterFunction
|
|
from appPublic.worker import coroutinify
|
|
from ahserver.configuredServer import ConfiguredServer
|
|
from ahserver.globalEnv import get_definition
|
|
from ahserver.serverenv import ServerEnv
|
|
from io import BytesIO
|
|
|
|
import ChatTTS
|
|
__version__ = '0.0.1'
|
|
|
|
class TTS:
    """Text-to-speech service backed by a ChatTTS engine.

    The engine is created and its model loaded once, at construction time.
    ``generate`` is ``_generate`` passed through ``coroutinify``.
    """

    def __init__(self):
        """Create the ChatTTS engine and load the model from the configured path."""
        self.engine = ChatTTS.Chat()
        # Look the path up once and reuse it for the trace output and the load.
        custom_path = get_definition('chattts_model_path')
        print(custom_path)
        self.engine.load(source='custom',
                         custom_path=custom_path,
                         compile=True)

    def _generate(self, request, kw):
        """Synthesize speech for the request's prompt and return WAV bytes.

        Args:
            request: The incoming request object; not used by this method.
            kw: Mapping whose ``'params_kw'`` entry is itself a mapping
                holding the text to speak under ``'prompt'``.

        Returns:
            bytes: The content of the WAV file written by ``torchaudio.save``
            with a sample rate of 24000.
        """
        params_kw = kw.get('params_kw')
        text = params_kw.get('prompt')
        wavs = self.engine.infer(text)
        fn = temp_file(suffix='.wav')
        try:
            # Only the first inferred waveform is used; unsqueeze(0) adds the
            # leading dimension (presumably the channel axis - confirm).
            torchaudio.save(fn, torch.from_numpy(wavs[0]).unsqueeze(0), 24000, format='wav')
            with open(fn, 'rb') as f:
                return f.read()
        finally:
            # Always drop the temporary file, on success and on failure, and
            # only after the handle above has been closed.
            try:
                os.remove(fn)
            except FileNotFoundError:
                # Nothing was written (e.g. the save failed early); there is
                # nothing to clean up and the original error must propagate.
                pass

    generate = coroutinify(_generate)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, load the configuration,
    # set up logging, register the TTS handler and start the server.
    parser = argparse.ArgumentParser(prog="tts")
    for short_flag, long_flag in (('-w', '--workdir'), ('-p', '--port')):
        parser.add_argument(short_flag, long_flag)
    args = parser.parse_args()
    print(args)

    # Fall back to the current directory when no workdir is given.
    workdir = args.workdir if args.workdir else os.getcwd()
    program_path = ProgramPath()
    config = getConfig(workdir, NS=dict(workdir=workdir, ProgramPath=program_path))

    log_conf = config.logger
    if not log_conf:
        logger = MyLogger('sage', levelname='debug')
    else:
        logger = MyLogger(
            log_conf.name or 'tts',
            levelname=log_conf.levelname or 'debug',
            logfile=log_conf.logfile or None,
        )

    info(f'========tts version={__version__}========')

    # The server is built from the workdir only; no auth class is passed.
    server = ConfiguredServer(workdir=workdir)
    tts = TTS()
    rf = RegisterFunction()
    rf.register('generate', tts.generate)
    info(f'{rf.registKW=}')

    # Port precedence: command line, then configuration, then 8080.
    port = int(args.port or config.website.port or 8080)
    server.run(port=port)
|
|
|