66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
import os
|
|
import asyncio
|
|
import soundfile as sf
|
|
import numpy as np
|
|
from traceback import format_exc
|
|
from dia.model import Dia
|
|
from appPublic.worker import awaitify
|
|
from appPublic.hf import hf_socks5proxy
|
|
from appPublic.folderUtils import _mkdir
|
|
from appPublic.uniqueID import getID
|
|
from appPublic.log import debug, exception, error
|
|
from appPublic.jsonConfig import getConfig
|
|
from ahserver.filestorage import FileStorage
|
|
from ahserver.webapp import webapp
|
|
from ahserver.serverenv import ServerEnv
|
|
import dac
|
|
# Runs once at import time, before any model is loaded below.
# NOTE(review): presumably routes Hugging Face traffic through a SOCKS5
# proxy (inferred from the helper's name) — confirm in appPublic.hf.
hf_socks5proxy()
|
|
|
|
class DiaTTS:
    """Text-to-speech engine wrapping a locally stored Dia model.

    The Dia model and the DAC codec are loaded once, at construction time.
    Each successful :meth:`generate` call writes one WAV file through
    ``FileStorage`` and returns the web path of that file.
    """

    # Sample rate (Hz) handed to soundfile when the generated audio is written.
    SAMPLE_RATE = 44100

    def __init__(self):
        """Load the Dia model and DAC codec named in the application config.

        Reads ``dia_model_path``, ``dac_model_path`` and ``device`` from
        ``getConfig()``; a falsy ``device`` falls back to ``'cpu'``.
        """
        config = getConfig()
        # load_dac=False: the codec is attached right after by load_dac(),
        # from the separately configured dac_model_path.
        self.model = Dia.from_local(
            os.path.join(config.dia_model_path, 'config.json'),
            device=config.device or 'cpu',
            load_dac=False,
            checkpoint_path=os.path.join(config.dia_model_path, 'dia-v0_1.pth'),
            compute_dtype="float16",
        )
        self.load_dac(config.dac_model_path)
        # Held for the whole of generate(), so only one synthesis runs at a time.
        self.lock = asyncio.Lock()
        self.fs = FileStorage()

    def load_dac(self, dac_model_path):
        """Load the DAC codec from *dac_model_path* and attach it to the model.

        The codec is moved to the same device as ``self.model`` and switched
        to eval mode before being stored as ``self.model.dac_model``.
        """
        dac_model = dac.DAC.load(dac_model_path).to(self.model.device)
        dac_model.eval()  # Ensure DAC is in eval mode
        self.model.dac_model = dac_model

    @staticmethod
    def _to_float32(output, prompt):
        """Return *output* converted to float32, rejecting a missing result.

        Raises:
            Exception: if *output* is ``None``; the message names *prompt*.
        """
        # The None test must come before .astype(); otherwise a missing
        # result surfaces as an unrelated AttributeError.
        if output is None:
            raise Exception(f'"{prompt}" to audio null')
        return output.astype(np.float32)

    def _generate(self, prompt):
        """Synthesise *prompt* to a WAV file and return its web path.

        Blocking; :meth:`generate` is the async entry point.

        Raises:
            Exception: if the model returns ``None`` for *prompt*.
        """
        name = getID() + '.wav'
        fp = self.fs._name2path(name)
        webpath = self.fs.webpath(fp)
        debug(f'{prompt=}')
        output = self.model.generate(prompt, use_torch_compile=True, verbose=True)
        try:
            samples = self._to_float32(output, prompt)
        except Exception as e:
            # Logged inside the handler so format_exc() carries a real traceback.
            exception(f'{e}\n{format_exc()}')
            raise
        debug(f'{samples.dtype.name=}')
        _mkdir(os.path.dirname(fp))
        sf.write(fp, samples, self.SAMPLE_RATE)
        return webpath

    async def generate(self, prompt):
        """Synthesise *prompt* and return the web path of the WAV file.

        Calls are serialised through ``self.lock``. The blocking
        :meth:`_generate` is wrapped with ``awaitify`` so it can be awaited.
        NOTE(review): presumably awaitify runs it off the event loop —
        confirm in appPublic.worker.
        """
        async with self.lock:
            run_blocking = awaitify(self._generate)
            return await run_blocking(prompt)
|
|
|
|
def init():
    """Build the Dia TTS engine and publish it on the server environment.

    A new ``DiaTTS`` instance is stored as ``etts_engine`` on ``ServerEnv()``.
    """
    server_env = ServerEnv()
    server_env.etts_engine = DiaTTS()
|
|
|
|
# Script entry point: pass the start-up hook to ahserver's webapp runner.
if __name__ == '__main__':
    webapp(init)
|
|
|