import os
import asyncio
import soundfile as sf
import numpy as np
from traceback import format_exc
from dia.model import Dia
from appPublic.worker import awaitify
from appPublic.hf import hf_socks5proxy
from appPublic.folderUtils import _mkdir
from appPublic.uniqueID import getID
from appPublic.log import debug, exception, error
from appPublic.jsonConfig import getConfig
from ahserver.filestorage import FileStorage
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
import dac

# Called once at import time, before any model is loaded. Presumably this sets
# up SOCKS5 proxying for Hugging Face access -- see appPublic.hf to confirm.
hf_socks5proxy()


class DiaTTS:
    """Text-to-speech engine backed by a locally stored Dia model.

    The model, device and DAC codec locations are read from the application
    config (``dia_model_path``, ``device``, ``dac_model_path``). Generated
    audio is written as a WAV file into the server's FileStorage, and the
    file's web path is returned to the caller.
    """

    # Sample rate (Hz) passed to soundfile when writing the generated audio.
    SAMPLE_RATE = 44100

    def __init__(self):
        config = getConfig()
        model_dir = config.dia_model_path
        self.model = Dia.from_local(
            os.path.join(model_dir, 'config.json'),
            device=config.device or 'cpu',
            # The DAC codec is loaded separately from dac_model_path below.
            load_dac=False,
            checkpoint_path=os.path.join(model_dir, 'dia-v0_1.pth'),
            compute_dtype="float16",
        )
        self.load_dac(config.dac_model_path)
        # Serializes calls to generate() so only one synthesis runs at a time.
        self.lock = asyncio.Lock()
        self.fs = FileStorage()

    def load_dac(self, dac_model_path):
        """Load the DAC codec from *dac_model_path* and attach it to the model.

        The codec is moved to the same device as the Dia model and put in
        eval mode.
        """
        dac_model = dac.DAC.load(dac_model_path).to(self.model.device)
        dac_model.eval()  # Ensure DAC is in eval mode
        self.model.dac_model = dac_model

    def _generate(self, prompt):
        """Synchronously synthesize *prompt* and save it as a WAV file.

        Returns:
            The web path of the written file, as given by ``FileStorage.webpath``.

        Raises:
            Exception: if the model returns ``None`` for the prompt.
        """
        name = getID() + '.wav'
        fp = self.fs._name2path(name)
        webpath = self.fs.webpath(fp)
        debug(f'{prompt=}')
        output = self.model.generate(prompt, use_torch_compile=True, verbose=True)
        # Guard against a None result *before* calling .astype() on it;
        # otherwise the failure would surface as an AttributeError.
        if output is None:
            e = Exception(f'"{prompt}" to audio null')
            # Not inside an except block, so format_exc() would carry no
            # traceback; log the message itself.
            error(f'{e}')
            raise e
        output = output.astype(np.float32)
        debug(f'{output.dtype.name=}')
        _mkdir(os.path.dirname(fp))
        sf.write(fp, output, self.SAMPLE_RATE)
        return webpath

    async def generate(self, prompt):
        """Asynchronously synthesize *prompt*; return the WAV file's web path.

        Calls are serialized by ``self.lock``. The blocking ``_generate`` is
        wrapped with ``awaitify`` (presumably it runs in a worker thread so the
        event loop is not blocked -- confirm in appPublic.worker).
        """
        async with self.lock:
            f = awaitify(self._generate)
            return await f(prompt)


def init():
    """Create the TTS engine and register it as ``ServerEnv().etts_engine``."""
    g = ServerEnv()
    g.etts_engine = DiaTTS()


if __name__ == '__main__':
    # init is handed to ahserver's webapp (presumably invoked during startup).
    webapp(init)