This commit is contained in:
yumoqing 2024-09-03 16:12:34 +08:00
parent 096c2c4d8e
commit 81c8c08490
3 changed files with 30 additions and 12 deletions

View File

@ -1,6 +1,9 @@
import asyncio import asyncio
import random import random
import json import json
from faster_whisper import WhisperModel
from functools import partial from functools import partial
from appPublic.dictObject import DictObject from appPublic.dictObject import DictObject
@ -19,6 +22,7 @@ videos = ['./1.mp4', './2.mp4']
class RTCLLM: class RTCLLM:
def __init__(self, ws_url, iceServers): def __init__(self, ws_url, iceServers):
self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
self.ws_url = ws_url self.ws_url = ws_url
self.iceServers = iceServers self.iceServers = iceServers
self.peers = DictObject() self.peers = DictObject()
@ -69,7 +73,7 @@ class RTCLLM:
self.onlineList = data.onlineList self.onlineList = data.onlineList
async def vad_voiceend(self, peer, audio): async def vad_voiceend(self, peer, audio):
txt = await asr(audio) ret = await asr(self.stt_model, audio)
# await peer.dc.send(txt) # await peer.dc.send(txt)

View File

@ -1,5 +1,6 @@
from appPublic.dictObject import DictObject from appPublic.dictObject import DictObject
from appPublic.oauth_client import OAuthClient from appPublic.oauth_client import OAuthClient
from faster_whisper import WhisperModel
desc = { desc = {
"path":"/asr/generate", "path":"/asr/generate",
@ -30,8 +31,17 @@ opts = {
"asr":desc "asr":desc
} }
async def asr(audio): async def asr(model, a_file):
"""
oc = OAuthClient(DictObject(**opts)) oc = OAuthClient(DictObject(**opts))
r = await oc("http://open-computing.cn", "asr", {"b64audio":audio}) r = await oc("http://open-computing.cn", "asr", {"b64audio":audio})
print(f'{r=}') print(f'{r=}')
"""
segments, info = model.transcribe(a_file, beam_size=5)
txt = ''
for s in segments:
txt += s.text
return {
'content': txt,
'language': info.language
}

View File

@ -104,35 +104,39 @@ class AudioTrackVad(MediaStreamTrack):
# audio we've collected. # audio we've collected.
if num_unvoiced > 0.9 * self.ring_buffer.maxlen: if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = False self.triggered = False
# audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames]) ret = await self.write_wave()
# await self.write_wave(audio_data) # ret = await self.gen_base64()
await self.gen_base64() if self.onvoiceend:
await self.onvoiceend(ret)
self.ring_buffer.clear() self.ring_buffer.clear()
self.voiced_frames = [] self.voiced_frames = []
print('end voice .....', len(self.voiced_frames)) def to_mono16000_data(self):
async def gen_base64(self):
lst = [] lst = []
for f in self.voiced_frames: for f in self.voiced_frames:
fs = self.resample(f, sample_rate=16000) fs = self.resample(f, sample_rate=16000)
lst += fs lst += fs
audio_data = b''.join([self.frame2bytes(f) for f in lst]) audio_data = b''.join([self.frame2bytes(f) for f in lst])
return to_mono16000_data
async def gen_base64(self):
audio_data = self.to_mono16000_data()
b64 = base64.b64encode(audio_data).decode('utf-8') b64 = base64.b64encode(audio_data).decode('utf-8')
if self.onvoiceend: return b64
await self.onvoiceend(b64)
async def write_wave(self, audio_data): async def write_wave(self):
"""Writes a .wav file. """Writes a .wav file.
Takes path, PCM audio data, and sample rate. Takes path, PCM audio data, and sample rate.
""" """
audio_data = self.to_mono16000_data()
path = temp_file(suffix='.wav') path = temp_file(suffix='.wav')
print(f'temp_file={path}') print(f'temp_file={path}')
with contextlib.closing(wave.open(path, 'wb')) as wf: with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1) wf.setnchannels(1)
wf.setsampwidth(2) wf.setsampwidth(2)
wf.setframerate(self.sample_rate) wf.setframerate(16000)
wf.writeframes(audio_data) wf.writeframes(audio_data)
print('************wrote*******') print('************wrote*******')