import os
import codecs
import re

import numpy as np
import soundfile as sf
from cached_path import cached_path
import pycld2 as cld
import cn2an

from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
	infer_process,
	load_model,
	load_vocoder,
	preprocess_ref_audio_text,
)

import json
from time import sleep
from appPublic.jsonConfig import getConfig
from appPublic.worker import awaitify
from appPublic.uniqueID import getID
from appPublic.log import debug
from ahserver.webapp import webapp
from ahserver.serverenv import ServerEnv
from ahserver.filestorage import FileStorage

# Default inference settings, mirroring the f5_tts defaults. Only `speed` is
# actually referenced below (as the fallback when the config does not set one);
# the rest are kept for reference.
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
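
# The json config loaded via getConfig() is expected to provide at least the
# keys below. This is an illustrative sketch reconstructed from the uses in
# this file, not an authoritative schema:
#
# {
#     "modelname": "F5-TTS",              # or "E2-TTS"
#     "vocoder_name": "vocos",            # or "bigvgan"
#     "vocoder_local_path": "...",
#     "vocab_file": "...",
#     "device": "cuda",
#     "speed": 1.0,
#     "speakers_file": "speakers.json",
#     "ref_audio": "...",                 # reference pair for the "main" voice
#     "ref_text": "...",
#     "language": {"zh": {"sentence_splitter": "..."}}
# }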

def write_wav_buffer(wav, nchannels, framerate):
	# Write the generated waveform to a temporary wav file under the "tmp"
	# user of the FileStorage tree and return its web-visible path.
	# `nchannels` is kept for call compatibility; soundfile infers the channel
	# count from the array shape.
	fs = FileStorage()
	fn = fs._name2path(f'{getID()}.wav', userid='tmp')
	os.makedirs(os.path.dirname(fn), exist_ok=True)
	debug(fn)
	sf.write(fn, wav, framerate)
	return fs.webpath(fn)

async_write_wav_buffer = awaitify(write_wav_buffer)
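# awaitify() from appPublic.worker wraps the blocking writer so async handlers
# can await it without stalling the event loop.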

def detect_language(txt):
	# pycld2 returns (isReliable, textBytesFound, details); details[0] is the
	# best guess and its second field is the language code.
	isReliable, textBytesFound, details = cld.detect(txt)
	debug(f' detect_language():{isReliable=}, {textBytesFound=}, {details=} ')
	return details[0][1]
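
# Illustrative: detect_language("你好,世界") returns "zh" and
# detect_language("hello world") returns "en" (codes as reported by pycld2).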

class F5TTS:
	def __init__(self):
		self.config = getConfig()
		self.load_model()
		self.setup_voices()

	def load_model(self):
		self.vocoder = load_vocoder(vocoder_name=self.config.vocoder_name,
							is_local=True,
							local_path=self.config.vocoder_local_path)

		# Pick the model class, architecture and checkpoint for the configured
		# model name; checkpoints are fetched from the SWivid HuggingFace
		# repos via cached_path.
		if self.config.modelname == "F5-TTS":
			model_cls = DiT
			model_cfg = dict(dim=1024, depth=22, heads=16,
								ff_mult=2, text_dim=512, conv_layers=4)
			if self.config.vocoder_name == "vocos":
				repo_name = "F5-TTS"
				exp_name = "F5TTS_Base"
				ckpt_step = 1200000
				ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
				# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
			elif self.config.vocoder_name == "bigvgan":
				repo_name = "F5-TTS"
				exp_name = "F5TTS_Base_bigvgan"
				ckpt_step = 1250000
				ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
			else:
				raise ValueError(f'unsupported vocoder_name: {self.config.vocoder_name}')
		elif self.config.modelname == "E2-TTS":
			model_cls = UNetT
			model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
			repo_name = "E2-TTS"
			exp_name = "E2TTS_Base"
			ckpt_step = 1200000
			ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
		else:
			raise ValueError(f'unsupported modelname: {self.config.modelname}')

		self.model = load_model(model_cls, model_cfg, ckpt_file,
						mel_spec_type=self.config.vocoder_name,
						vocab_file=self.config.vocab_file)
		self.model = self.model.to(self.config.device)

	def f5tts_infer(self, ref_audio, ref_text, gen_text):
		# Run one f5_tts inference pass and return the raw waveform together
		# with its sample rate and spectrogram.
		audio, final_sample_rate, spectrogram = \
				infer_process(ref_audio,
								ref_text,
								gen_text,
								self.model,
								self.vocoder,
								mel_spec_type=self.config.vocoder_name,
								speed=self.config.speed or speed)
		return {
			'audio': audio,
			'sample_rate': final_sample_rate,
			'spectrogram': spectrogram
		}

	def get_speakers(self):
		t = [{'value':s, 'text':s} for s in self.speakers.keys() ]
		t.append({'value':'main', 'text':'main'})
		return t
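	# Illustrative result with a speakers file defining "alice" and "bob":
	# [{'value': 'alice', 'text': 'alice'}, {'value': 'bob', 'text': 'bob'},
	#  {'value': 'main', 'text': 'main'}]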

	async def split_text(self, text_gen, speaker):
		# Split the prompt into chunks. The splitter pattern can be overridden
		# per language in the config; the default splits ahead of "[tag]"
		# markers.
		reg1 = r"(?=\[\w+\])"
		lang = await awaitify(detect_language)(text_gen)
		if self.config.language.get(lang):
			reg1 = r"{}".format(self.config.language.get(lang).sentence_splitter)
		if lang == 'zh':
			# Normalize Arabic numerals to Chinese numerals so the model
			# reads them naturally.
			text_gen = await awaitify(cn2an.transform)(text_gen, 'an2cn')

		chunks = re.split(reg1, text_gen)
		# reg2 = self.config.speaker_match
		reg2 = r"\[\[(\w+)\]\]"
		ret = []
		for text in chunks:
			if not text.strip():
				continue
			voice = speaker
			match = re.match(reg2, text)
			if match:
				debug(f'{text=}, match {reg2=}')
				voice = match[1]
			if voice not in self.voices:
				voice = "main"
			debug(f'{text} inferred with speaker({voice})..{reg2=}')
			text = re.sub(reg2, "", text)
			gen_text = text.strip()
			ref_audio = self.voices[voice]["ref_audio"]
			ref_text = self.voices[voice]["ref_text"]
			ret.append({'text': gen_text, 'ref_audio': ref_audio, 'ref_text': ref_text})
		return ret
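	# Illustrative: a prompt such as "[[alice]]Hello. [[bob]]Hi." yields one
	# chunk per tag; each "[[name]]" selects that speaker's reference
	# audio/text, and unknown names fall back to the "main" voice.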

	async def infer_stream(self, prompt, speaker):
		# Streamed variant: yield the web path of one wav file per chunk.
		async for a in self._inference_stream(prompt, speaker):
			yield await async_write_wav_buffer(a['audio'], 1, a['sample_rate'])

	async def _inference_stream(self, prompt, speaker):
		chunks = await self.split_text(prompt, speaker)
		infer = awaitify(self.f5tts_infer)
		for chunk in chunks:
			gen_text = chunk['text']
			ref_audio = chunk['ref_audio']
			ref_text = chunk['ref_text']
			try:
				d = await infer(ref_audio, ref_text, gen_text)
				yield d
			except Exception as e:
				# Skip chunks that fail to synthesize instead of aborting the
				# whole stream, but leave a trace in the log.
				debug(f'inference failed for {gen_text=}: {e}')

	def setup_voices(self):
		# Load the speaker table from speakers_file and preprocess every
		# reference audio/text pair; the config-level pair becomes the
		# fallback "main" voice.
		config = getConfig()
		with codecs.open(config.speakers_file, 'r', 'utf-8') as f:
			self.speakers = json.loads(f.read())
		ref_audio, ref_text = preprocess_ref_audio_text(config.ref_audio, config.ref_text)
		self.voices = {
			"main": {
				'ref_text': ref_text,
				'ref_audio': ref_audio
			}
		}
		for k, v in self.speakers.items():
			ref_audio, ref_text = preprocess_ref_audio_text(v['ref_audio'], v['ref_text'])
			self.voices[k] = {
				'ref_text': ref_text,
				'ref_audio': ref_audio
			}
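
	# speakers_file is expected to hold a json object of this shape
	# (illustrative sketch; paths and texts are placeholders):
	# {
	#     "alice": {"ref_audio": "/path/to/alice.wav", "ref_text": "..."},
	#     "bob": {"ref_audio": "/path/to/bob.wav", "ref_text": "..."}
	# }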

	async def add_voice(self, speaker, ref_audio, ref_text):
		# Register a new speaker at runtime and persist the updated speaker
		# table back to speakers_file.
		debug(f'{speaker=}, {ref_audio=}, {ref_text=}')
		config = getConfig()
		ref_audio = FileStorage().realPath(ref_audio)
		self.speakers[speaker] = {
			'ref_text': ref_text,
			'ref_audio': ref_audio
		}
		preprocess = awaitify(preprocess_ref_audio_text)
		ref_audio, ref_text = await preprocess(ref_audio, ref_text)
		self.voices[speaker] = {
			'ref_text': ref_text,
			'ref_audio': ref_audio
		}
		with codecs.open(config.speakers_file, 'w', 'utf-8') as f:
			f.write(json.dumps(self.speakers, indent=4))
		return None

	async def _inference(self, prompt, speaker):
		# Non-streaming variant: synthesize every chunk, concatenate the
		# waveforms and return the web path of a single wav file.
		generated_audio_segments = []
		final_sample_rate = 24000
		async for d in self._inference_stream(prompt, speaker):
			generated_audio_segments.append(d['audio'])
			final_sample_rate = d['sample_rate']

		if generated_audio_segments:
			final_wave = np.concatenate(generated_audio_segments)
			return await async_write_wav_buffer(final_wave, 1, final_sample_rate)
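		# If every chunk fails, the segment list stays empty and this returns
		# None; callers should treat None as "no audio produced".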

def UiError(title="Error", message="Something went wrong", timeout=5):
	return {
		"widgettype":"Error",
		"options":{
			"author":"tr",
			"timeout":timeout,
			"cwidth":15,
			"cheight":10,
			"title":title,
			"auto_open":True,
			"auto_dismiss":True,
			"auto_destroy":True,
			"message":message
		}
	}


def UiMessage(title="Message", message="Background message", timeout=5):
	return {
		"widgettype":"Message",
		"options":{
			"author":"tr",
			"timeout":timeout,
			"cwidth":15,
			"cheight":10,
			"title":title,
			"auto_open":True,
			"auto_dismiss":True,
			"auto_destroy":True,
			"message":message
		}
	}

def test1():
	# Dummy handler that blocks for ten hours, for exercising long-running
	# request handling.
	sleep(36000)
	return {}

def init():
	# Create the single F5TTS instance and expose its entry points through
	# the server environment so the service scripts can call them.
	g = ServerEnv()
	f5 = F5TTS()
	g.infer_stream = f5.infer_stream
	g.get_speakers = f5.get_speakers
	g.infer = f5._inference
	g.test1 = awaitify(test1)
	g.add_voice = f5.add_voice
	g.UiError = UiError
	g.UiMessage = UiMessage

if __name__ == '__main__':
	webapp(init)