diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py
index e9c55a0..ff18d4d 100644
--- a/rtcllm/rtc.py
+++ b/rtcllm/rtc.py
@@ -1,6 +1,9 @@
 import asyncio
 import random
 import json
+
+from faster_whisper import WhisperModel
+
 from functools import partial
 
 from appPublic.dictObject import DictObject
@@ -19,6 +22,7 @@ videos = ['./1.mp4', './2.mp4']
 
 class RTCLLM:
     def __init__(self, ws_url, iceServers):
+        self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
         self.ws_url = ws_url
         self.iceServers = iceServers
         self.peers = DictObject()
@@ -69,7 +73,7 @@ class RTCLLM:
         self.onlineList = data.onlineList
 
     async def vad_voiceend(self, peer, audio):
-        txt = await asr(audio)
+        ret = await asr(self.stt_model, audio)
         # await peer.dc.send(txt)
 
 
diff --git a/rtcllm/stt.py b/rtcllm/stt.py
index c306e51..a47bcf9 100644
--- a/rtcllm/stt.py
+++ b/rtcllm/stt.py
@@ -1,5 +1,6 @@
 from appPublic.dictObject import DictObject
 from appPublic.oauth_client import OAuthClient
+from faster_whisper import WhisperModel
 
 desc = {
     "path":"/asr/generate",
@@ -30,8 +31,18 @@ opts = {
     "asr":desc
 }
 
-async def asr(audio):
+async def asr(model, a_file):
+    """
     oc = OAuthClient(DictObject(**opts))
     r = await oc("http://open-computing.cn", "asr", {"b64audio":audio})
     print(f'{r=}')
-
+    """
+    # Transcribe the recorded clip and join the decoded segments into one string.
+    segments, info = model.transcribe(a_file, beam_size=5)
+    txt = ''
+    for s in segments:
+        txt += s.text
+    return {
+        'content': txt,
+        'language': info.language
+    }
diff --git a/rtcllm/vad.py b/rtcllm/vad.py
index e7ad2a0..114dcda 100644
--- a/rtcllm/vad.py
+++ b/rtcllm/vad.py
@@ -104,35 +104,41 @@ class AudioTrackVad(MediaStreamTrack):
             # audio we've collected.
             if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
                 self.triggered = False
-                # audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
-                # await self.write_wave(audio_data)
-                await self.gen_base64()
+                ret = await self.write_wave()
+                # ret = await self.gen_base64()
+                if self.onvoiceend:
+                    await self.onvoiceend(ret)
                 self.ring_buffer.clear()
                 self.voiced_frames = []
-                print('end voice .....', len(self.voiced_frames))
 
 
-    async def gen_base64(self):
+    def to_mono16000_data(self):
+        # Resample the buffered frames to mono 16kHz and join them into one PCM byte string.
         lst = []
         for f in self.voiced_frames:
             fs = self.resample(f, sample_rate=16000)
             lst += fs
         audio_data = b''.join([self.frame2bytes(f) for f in lst])
+        return audio_data
+
+    async def gen_base64(self):
+        audio_data = self.to_mono16000_data()
         b64 = base64.b64encode(audio_data).decode('utf-8')
-        if self.onvoiceend:
-            await self.onvoiceend(b64)
+        return b64
 
 
-    async def write_wave(self, audio_data):
+    async def write_wave(self):
         """Writes a .wav file.
-        Takes path, PCM audio data, and sample rate.
+        Saves the mono 16kHz voiced audio to a temp file and returns its path.
         """
+        audio_data = self.to_mono16000_data()
         path = temp_file(suffix='.wav')
         print(f'temp_file={path}')
         with contextlib.closing(wave.open(path, 'wb')) as wf:
             wf.setnchannels(1)
             wf.setsampwidth(2)
-            wf.setframerate(self.sample_rate)
+            wf.setframerate(16000)
             wf.writeframes(audio_data)
         print('************wrote*******')
+        return path
 
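
For reference, a minimal standalone sketch of the transcription path this patch wires in. 'sample.wav' is a hypothetical stand-in for the temp file returned by AudioTrackVad.write_wave(), and the model is loaded on CPU with int8 only so the sketch runs without CUDA (the patch itself uses device="cuda", compute_type="float16"):

    from faster_whisper import WhisperModel

    # Same 'large-v3' model as the patch; CPU/int8 keeps this runnable
    # on machines without a GPU.
    model = WhisperModel('large-v3', device='cpu', compute_type='int8')

    # 'sample.wav' is a hypothetical mono 16kHz clip, standing in for the
    # temp file that AudioTrackVad.write_wave() produces.
    segments, info = model.transcribe('sample.wav', beam_size=5)

    # transcribe() returns a lazy generator; iterating it performs the decode.
    txt = ''.join(s.text for s in segments)
    print(info.language, txt)

One caveat: asr() is declared async, but model.transcribe() is a blocking call, so long clips will stall the event loop; offloading it, e.g. with await asyncio.to_thread(model.transcribe, a_file, beam_size=5), is one possible mitigation.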