This commit is contained in:
yumoqing 2024-08-09 15:43:28 +08:00
parent d1f3b10569
commit 504dfb3157

View File

@ -1,47 +1,25 @@
from traceback import print_exc from traceback import print_exc
import base64 import base64
import wave
from appPublic.log import info, debug, warning, error, exception, critical from appPublic.log import info, debug, warning, error, exception, critical
from appPublic.dictObject import DictObject from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file from appPublic.folderUtils import temp_file
from ahserver.serverenv import ServerEnv from ahserver.serverenv import ServerEnv
from aiohttp.web import StreamResponse from aiohttp.web import StreamResponse
from io import BytesIO def save_base64_wav(base64_data, output_file,sample_rate=16000, num_channels=1):
import struct # Decode the base64 data
wav_data = base64.b64decode(base64_data)
def audio_dic2list(audio): # Open a new WAV file for writing
ks = [k for k in audio.keys()] with wave.open(output_file, 'wb') as wf:
info(f'{type(audio)}, {ks=}') # Set the parameters of the WAV file
ks.sort() wf.setnchannels(num_channels) # Mono channel
return [audio[k] for k in ks] wf.setsampwidth(2) # 16-bit sample width
wf.setframerate(sample_rate) # 44.1 kHz sample rate
def float32array_to_wav(samples, sample_rate=16000, num_channels=1): # Write the decoded data to the WAV file
# Calculate the total number of samples wf.writeframes(wav_data)
num_samples = len(samples)
# Calculate the byte rate
byte_rate = sample_rate * num_channels * 4
# Calculate the block align
block_align = num_channels * 4
# Create the WAV header
header = struct.pack(
'<4sI4s4sIHHIIHH4sI',
b'RIFF', 36 + num_samples * 4, b'WAVE', b'fmt ', 16, 3, num_channels, sample_rate,
byte_rate, block_align, 32, b'data', num_samples * 4
)
# info(f'float32array_to_wav({samples[:10]}, ...)')
# Convert the Float32Array to bytes
data = struct.pack('f' * num_samples, *samples)
# Write the header and data to a file
tmpfile = temp_file(suffix='.wav')
with open(tmpfile, 'wb') as f:
f.write(header)
f.write(data)
return tmpfile
async def generate(request, **kw): async def generate(request, **kw):
params_kw = kw.get('params_kw', DictObject()) params_kw = kw.get('params_kw', DictObject())
@ -52,6 +30,8 @@ async def generate(request, **kw):
'status':'error', 'status':'error',
'message':'audio is null' 'message':'audio is null'
} }
fname = temp_file(suffix='.wav')
save_base64_wav(audio, fname)
engine = None engine = None
g = ServerEnv() g = ServerEnv()
if model=='whisper': if model=='whisper':
@ -63,8 +43,6 @@ async def generate(request, **kw):
'message':f'model={model} is not defined' 'message':f'model={model} is not defined'
} }
try: try:
audio = audio_dic2list(audio)
fname = float32array_to_wav(audio)
txt = await engine.stt(fname) txt = await engine.stt(fname)
os.remove(fname) os.remove(fname)
info(f'{txt=}') info(f'{txt=}')