diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py index d8bea47..276e037 100644 --- a/rtcllm/rtc.py +++ b/rtcllm/rtc.py @@ -75,7 +75,7 @@ class RTCLLM: peer = self.peers[peerid] pc = peer.pc if track.kind == 'audio': - vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend) + vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend) peer.vadtrack = vadtrack vadtrack.start_vad() @@ -103,7 +103,7 @@ class RTCLLM: if pc is None: print(f'{self.peers=}, {data=}') return - pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id)) + pc.on("connectionstate", partial(self.pc_connectionState_changed, data['from'].id)) pc.on('track', partial(self.pc_track, data['from'].id)) pc.on('icecandidate', partial(self.on_icecandidate, pc)) offer = RTCSessionDescription(** data.offer) diff --git a/rtcllm/vad.py b/rtcllm/vad.py index 4bd2021..dd60c0c 100644 --- a/rtcllm/vad.py +++ b/rtcllm/vad.py @@ -1,10 +1,14 @@ +from traceback import print_exc import asyncio import collections import contextlib +from appPublic.folderUtils import temp_file from aiortc import MediaStreamTrack from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay import webrtcvad import wave +import numpy as np +from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat class AudioTrackVad(MediaStreamTrack): def __init__(self, track, stage=3, onvoiceend=None): @@ -17,12 +21,13 @@ class AudioTrackVad(MediaStreamTrack): # frameSize = self.track.getSettings().frameSize # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate self.frame_duration_ms = 0.00008 - self.num_padding_frames = 10 + self.num_padding_frames = 20 self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.triggered = False self.voiced_frames = [] self.loop = asyncio.get_event_loop() self.task = None + self.debug = True self.running = False def start_vad(self): @@ -31,22 +36,47 @@ class AudioTrackVad(MediaStreamTrack): def _recv(self): asyncio.create_task(self.recv()) - if self.task: - self.task.cancel() - if self.running: - self.task = self.loop.call_later(self.frame_duration_ms, self._recv) def stop(self): self.running = False + def frame2bytes(self, frame): + # 假设你有一个 AudioFrame 对象 audio_frame + audio_array = frame.to_ndarray() + # 将 numpy 数组转换为字节数组 + dtype = audio_array.dtype + audio_bytes = audio_array.tobytes() + return audio_bytes + async def recv(self): - f = await self.track.recv() - print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}') - self.vad_check(f) + oldf = await self.track.recv() + frames = self.resample(oldf) + for f in frames: + if self.debug: + self.debug = False + print(f'{type(f)}, {f.samples=}, {f.format.bytes=}, {f.sample_rate=}, {f.format=}, {f.is_corrupt=}, {f.layout=}, {f.planes=}, {f.side_data=}') + self.sample_rate = f.sample_rate + try: + await self.vad_check(f) + except Exception as e: + print(f'{e=}') + print_exc() + return + if self.task: + self.task.cancel() + if self.running: + self.task = self.loop.call_later(self.frame_duration_ms, self._recv) return f + def resample(self, frame): + fmt = AudioFormat('s16') + al = AudioLayout(1) + r = AudioResampler(format=fmt, layout=al, rate=frame.rate) + frame = r.resample(frame) + return frame + async def vad_check(self, frame): - is_speech = self.vad.is_speech(frame, self.sample_rate) + is_speech = self.vad.is_speech(self.frame2bytes(frame), self.sample_rate) if not self.triggered: self.ring_buffer.append((frame, is_speech)) num_voiced = len([f for f, speech in self.ring_buffer if speech]) @@ -61,6 +91,7 @@ class AudioTrackVad(MediaStreamTrack): for f, s in self.ring_buffer: self.voiced_frames.append(f) self.ring_buffer.clear() + print('start voice .....', len(self.voiced_frames)) else: # We're in the TRIGGERED state, so collect the audio data # and add it to the ring buffer. @@ -72,23 +103,28 @@ class AudioTrackVad(MediaStreamTrack): # audio we've collected. if num_unvoiced > 0.9 * self.ring_buffer.maxlen: self.triggered = False - audio_data = b''.join([f.bytes for f in voiced_frames]) - self.write_wave(audio_data) + audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames]) + await self.write_wave(audio_data) self.ring_buffer.clear() - voiced_frames = [] + self.voiced_frames = [] + print('end voice .....', len(self.voiced_frames)) async def write_wave(self, audio_data): """Writes a .wav file. Takes path, PCM audio data, and sample rate. """ - path = make_temp(subfix='.wav') + path = temp_file(suffix='.wav') + print(f'temp_file={path}') + with contextlib.closing(wave.open(path, 'wb')) as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(self.sample_rate) - wf.writeframes(audio) + wf.writeframes(audio_data) + print('************wrote*******') if self.onvoiceend: await self.onvoiceend(path) + print('************over*******')