asr/app/whisper_model.py

55 lines
1.3 KiB
Python
Raw Normal View History

2024-08-03 17:27:14 +08:00
from ahserver.serverenv import ServerEnv
from appPublic.worker import awaitify
2024-08-03 18:26:44 +08:00
from appPublic.log import info, debug, warning, error, exception, critical
2024-08-03 17:57:04 +08:00
from ahserver.globalEnv import get_definition
2024-08-03 17:27:14 +08:00
import numpy as np
import base64
import whisper
def base64_encode(text):
    """Encode *text* as base64.

    The string is first converted to UTF-8 bytes, base64-encoded, and the
    resulting ASCII bytes are returned as a ``str``.
    """
    return base64.b64encode(text.encode('utf-8')).decode('utf-8')
2024-08-03 17:27:14 +08:00
def base64_decode(encoded_text):
    """Decode base64 *encoded_text* and return the payload as a UTF-8 string.

    Inverse of ``base64_encode``: the input string is encoded to UTF-8 bytes,
    base64-decoded, and the decoded bytes are interpreted as UTF-8 text.
    """
    payload = base64.b64decode(encoded_text.encode('utf-8'))
    return payload.decode('utf-8')
2024-08-03 17:27:14 +08:00
class WhisperBase:
    """Common base for the Whisper speech-to-text front ends in this module.

    On construction it loads the Whisper model whose name is returned by
    ``get_definition('whisper_model')``. Subclasses override ``_stt`` for
    their particular input format and expose it as ``stt`` by wrapping it
    with appPublic's ``awaitify``.
    """

    def __init__(self):
        model_name = get_definition('whisper_model')
        self.model = whisper.load_model(model_name)

    def _stt(self, filepath):
        """Abstract transcription hook; subclasses must override it.

        Raises:
            NotImplementedError: always. This is a subclass of ``Exception``,
                so callers that catch ``Exception`` still handle it.
        """
        # NotImplementedError is the idiomatic signal for an abstract hook.
        e = NotImplementedError(f'{filepath=} WhisperBase can not use')
        exception(f'{e=}')
        raise e
2024-08-03 17:27:14 +08:00
class WhisperFile(WhisperBase):
    """Whisper front end that transcribes audio given by a file path."""

    def _stt(self, filepath):
        """Pass *filepath* to the model's ``transcribe`` and return its result."""
        model = self.model
        result = model.transcribe(filepath)
        return result

    stt = awaitify(_stt)
2024-08-03 17:27:14 +08:00
class WhisperBase64(WhisperBase):
    """Whisper front end for audio sent as a mapping of sample values.

    Despite the class name, no base64 decoding happens here: ``audio`` is
    read as a mapping whose values are the audio samples, taken in the
    mapping's iteration order.
    NOTE(review): this looks like a JSON-serialised browser ``Float32Array``
    (``{"0": s0, "1": s1, ...}``); confirm against the client.
    """

    @staticmethod
    def _to_waveform(audio):
        """Flatten the values of *audio* into one 1-D float32 NumPy array.

        Each value may be a single number or a (possibly nested) sequence of
        numbers. Values are concatenated in iteration order. An empty mapping
        yields an empty array.
        """
        # np.array() cannot iterate a dict_values view directly, so each value
        # is converted on its own. ravel() turns scalars into length-1 arrays
        # and flattens any nested chunks.
        chunks = [np.asarray(value, dtype=np.float32).ravel()
                  for value in audio.values()]
        if not chunks:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(chunks)

    def _stt(self, audio):
        """Transcribe the samples held in the mapping *audio*.

        The samples are converted to a 1-D float32 array, the waveform form
        Whisper's ``transcribe`` accepts, and the model's result is returned
        unchanged.
        NOTE(review): Whisper assumes 16 kHz mono audio for array input.
        Confirm the client resamples before sending.
        """
        nparr = self._to_waveform(audio)
        info(f'ndarr={nparr}')
        return self.model.transcribe(nparr)

    stt = awaitify(_stt)