This commit is contained in:
yumoqing 2024-08-07 13:06:53 +08:00
parent 8beb2dca7a
commit 329d7d0545
2 changed files with 20 additions and 12 deletions

View File

@ -1,11 +1,14 @@
import os, sys import os, sys
import torch
import torchaudio
import argparse import argparse
from appPublic.log import MyLogger, info, debug, warning from appPublic.log import MyLogger, info, debug, warning
from appPublic.folderUtils import ProgramPath from appPublic.folderUtils import ProgramPath, temp_file
from appPublic.jsonConfig import getConfig from appPublic.jsonConfig import getConfig
from appPublic.registerfunction import RegisterFunction from appPublic.registerfunction import RegisterFunction
from appPublic.worker import awaitify from appPublic.worker import coroutinify
from ahserver.configuredServer import ConfiguredServer from ahserver.configuredServer import ConfiguredServer
from ahserver.globalEnv import get_definition
from ahserver.serverenv import ServerEnv from ahserver.serverenv import ServerEnv
from io import BytesIO from io import BytesIO
@ -15,20 +18,25 @@ __version__ = '0.0.1'
class TTS: class TTS:
def __init__(self): def __init__(self):
self.engine = ChatTTS.Chat() self.engine = ChatTTS.Chat()
custom_path=get_definition('chattts_model_path')
print(custom_path)
self.engine.load(source='custom', self.engine.load(source='custom',
custom_path=get_definition(custom_path='chattts_model_path', custom_path=get_definition('chattts_model_path'),
compile=True) compile=True)
def _generate(request, **kw): def _generate(self, request, kw):
params_kw = kw.get('params_kw') params_kw = kw.get('params_kw')
text = params_kw.get('prompt') text = params_kw.get('prompt')
wavs = self.engine.refer(text) wavs = self.engine.infer(text)
f = BytesIO() fn = temp_file(suffix='.wav')
torchaudio.save(f, touch.from_numpy(wavs[0]).unsqueeze(0), 24000) torchaudio.save(fn, torch.from_numpy(wavs[0]).unsqueeze(0), 24000, format='wav')
f.seek(0,0) with open(fn, 'rb') as f:
return f.read() b = f.read()
os.remove(fn)
generate = awaitify(_generate) return b
os.remove(fn)
generate = coroutinify(_generate)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(prog="tts") parser = argparse.ArgumentParser(prog="tts")

View File

@ -5,7 +5,7 @@
"logfile":"$[workdir]$/logs/tts.log" "logfile":"$[workdir]$/logs/tts.log"
}, },
"definitions":{ "definitions":{
"chattts_model_path":"/d/ymq/osc/models/ChaTTTS" "chattts_model_path":"/Users/ymq/models/ChaTTTS"
}, },
"filesroot":"$[workdir]$/files", "filesroot":"$[workdir]$/files",
"website":{ "website":{