import os, sys
import torch
import torchaudio
import argparse
from appPublic.log import MyLogger, info, debug, warning
from appPublic.folderUtils import ProgramPath, temp_file
from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction
from appPublic.worker import coroutinify
from ahserver.configuredServer import ConfiguredServer
from ahserver.globalEnv import get_definition
from ahserver.serverenv import ServerEnv
from io import BytesIO
import ChatTTS

__version__ = '0.0.1'

# Sample rate (Hz) passed to torchaudio.save when writing the generated audio.
SAMPLE_RATE = 24000


class TTS:
    """Text-to-speech service backed by a ChatTTS engine.

    The model is loaded once at construction time from the path registered
    under the ``chattts_model_path`` definition.
    """

    def __init__(self):
        self.engine = ChatTTS.Chat()
        custom_path = get_definition('chattts_model_path')
        debug(f'chattts_model_path={custom_path}')
        self.engine.load(source='custom', custom_path=custom_path, compile=True)

    def _generate(self, request, kw):
        """Synthesize speech for ``kw['params_kw']['prompt']`` and return WAV bytes.

        Args:
            request: the incoming server request (not used here).
            kw: keyword mapping supplied by the registered-function dispatcher;
                expected to contain a ``params_kw`` mapping with a ``prompt``.

        Returns:
            The bytes of a WAV file encoding the first synthesized waveform.

        Raises:
            ValueError: if no prompt was supplied.
        """
        params_kw = kw.get('params_kw') or {}
        text = params_kw.get('prompt')
        if not text:
            raise ValueError('params_kw.prompt is required')
        # NOTE(review): assumes engine.infer returns a sequence of 1-D numpy
        # waveforms; unsqueeze(0) adds the channel axis torchaudio.save expects.
        wavs = self.engine.infer(text)
        fn = temp_file(suffix='.wav')
        try:
            torchaudio.save(fn, torch.from_numpy(wavs[0]).unsqueeze(0),
                            SAMPLE_RATE, format='wav')
            with open(fn, 'rb') as f:
                return f.read()
        finally:
            # Always delete the temp file, even if saving/reading failed.
            if os.path.exists(fn):
                os.remove(fn)

    # Async-callable wrapper so the server can await the blocking inference.
    generate = coroutinify(_generate)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog="tts")
    parser.add_argument('-w', '--workdir')
    parser.add_argument('-p', '--port')
    args = parser.parse_args()
    print(args)
    workdir = args.workdir or os.getcwd()
    p = ProgramPath()
    config = getConfig(workdir, NS={'workdir': workdir, 'ProgramPath': p})
    if config.logger:
        logger = MyLogger(config.logger.name or 'tts',
                          levelname=config.logger.levelname or 'debug',
                          logfile=config.logger.logfile or None)
    else:
        # Use the same default name as the configured branch above.
        logger = MyLogger('tts', levelname='debug')
    info(f'========tts version={__version__}========')
    server = ConfiguredServer(workdir=workdir)
    tts = TTS()
    rf = RegisterFunction()
    rf.register('generate', tts.generate)
    info(f'{rf.registKW=}')
    # Guard against a config without a "website" section so the 8080
    # fallback is reachable instead of raising on None.port.
    website = config.website
    port = args.port or (website.port if website else None) or 8080
    port = int(port)
    server.run(port=port)