This commit is contained in:
yumoqing 2024-08-09 15:43:28 +08:00
parent d1f3b10569
commit 504dfb3157

View File

@ -1,47 +1,25 @@
from traceback import print_exc
import base64
import wave
from appPublic.log import info, debug, warning, error, exception, critical
from appPublic.dictObject import DictObject
from appPublic.folderUtils import temp_file
from ahserver.serverenv import ServerEnv
from aiohttp.web import StreamResponse
from io import BytesIO
import struct
def save_base64_wav(base64_data, output_file,sample_rate=16000, num_channels=1):
# Decode the base64 data
wav_data = base64.b64decode(base64_data)
def audio_dic2list(audio):
ks = [k for k in audio.keys()]
info(f'{type(audio)}, {ks=}')
ks.sort()
return [audio[k] for k in ks]
# Open a new WAV file for writing
with wave.open(output_file, 'wb') as wf:
# Set the parameters of the WAV file
wf.setnchannels(num_channels) # Mono channel
wf.setsampwidth(2) # 16-bit sample width
wf.setframerate(sample_rate) # 44.1 kHz sample rate
def float32array_to_wav(samples, sample_rate=16000, num_channels=1):
# Calculate the total number of samples
num_samples = len(samples)
# Calculate the byte rate
byte_rate = sample_rate * num_channels * 4
# Calculate the block align
block_align = num_channels * 4
# Create the WAV header
header = struct.pack(
'<4sI4s4sIHHIIHH4sI',
b'RIFF', 36 + num_samples * 4, b'WAVE', b'fmt ', 16, 3, num_channels, sample_rate,
byte_rate, block_align, 32, b'data', num_samples * 4
)
# info(f'float32array_to_wav({samples[:10]}, ...)')
# Convert the Float32Array to bytes
data = struct.pack('f' * num_samples, *samples)
# Write the header and data to a file
tmpfile = temp_file(suffix='.wav')
with open(tmpfile, 'wb') as f:
f.write(header)
f.write(data)
return tmpfile
# Write the decoded data to the WAV file
wf.writeframes(wav_data)
async def generate(request, **kw):
params_kw = kw.get('params_kw', DictObject())
@ -52,6 +30,8 @@ async def generate(request, **kw):
'status':'error',
'message':'audio is null'
}
fname = temp_file(suffix='.wav')
save_base64_wav(audio, fname)
engine = None
g = ServerEnv()
if model=='whisper':
@ -63,8 +43,6 @@ async def generate(request, **kw):
'message':f'model={model} is not defined'
}
try:
audio = audio_dic2list(audio)
fname = float32array_to_wav(audio)
txt = await engine.stt(fname)
os.remove(fname)
info(f'{txt=}')