This commit is contained in:
yumoqing 2024-08-03 17:35:06 +08:00
parent 07627e0648
commit 642e2f1ab7
2 changed files with 16 additions and 20 deletions

View File

@ -1,4 +0,0 @@
path = params_kw.get('audiofile')
print(f'whisper.generate.dspy, {params_kw=}, {file_realpath=}, {stt_engine=}' )
x = await stt_engine.stt(file_realpath(path))
return x

View File

@ -7,38 +7,38 @@ import whisper
# 编码 # 编码
def base64_encode(text): def base64_encode(text):
text_bytes = text.encode('utf-8') text_bytes = text.encode('utf-8')
encoded_bytes = base64.b64encode(text_bytes) encoded_bytes = base64.b64encode(text_bytes)
encoded_text = encoded_bytes.decode('utf-8') encoded_text = encoded_bytes.decode('utf-8')
return encoded_text return encoded_text
# 解码 # 解码
def base64_decode(encoded_text): def base64_decode(encoded_text):
encoded_bytes = encoded_text.encode('utf-8') encoded_bytes = encoded_text.encode('utf-8')
decoded_bytes = base64.b64decode(encoded_bytes) decoded_bytes = base64.b64decode(encoded_bytes)
decoded_text = decoded_bytes.decode('utf-8') decoded_text = decoded_bytes.decode('utf-8')
return decoded_text return decoded_text
class WhisperBase: class WhisperBase:
def __init__(self): def __init__(self):
model_name = get_definition('whisper_model') model_name = get_definition('whisper_model')
self.model = whisper.load_model(model_name) self.model = whisper.load_model(model_name)
def _stt(self, filepath): def _stt(self, filepath):
pass pass
stt = awaitify(_stt) stt = awaitify(_stt)
class WhisperFile(WhisperBase): class WhisperFile(WhisperBase):
def _stt(self, filepath): def _stt(self, filepath):
return self.model.transcribe(filepath) return self.model.transcribe(filepath)
class WhisperBase64(WhisperBase): class WhisperBase64(WhisperBase):
def _stt(self, audio_base64): def _stt(self, audio_base64):
raw = base64.decode(audio_base64) raw = base64.decode(audio_base64)
ndarr = np.frombuffer(raw, dtype=np.float32) ndarr = np.frombuffer(raw, dtype=np.float32)
return self.model.transcribe(raw) return self.model.transcribe(raw)
g = ServerEnv() g = ServerEnv()
g.whisper_engine = WhisperBase64() g.whisper_engine = WhisperBase64()