diff --git a/rtcllm/vad.py b/rtcllm/vad.py
index 7c32c01..56a0da6 100644
--- a/rtcllm/vad.py
+++ b/rtcllm/vad.py
@@ -72,7 +72,7 @@ class AudioTrackVad(MediaStreamTrack):
 
     async def vad_check(self, inframe):
         frames = self.resample(inframe)
-        frame = frames[0]:
+        frame = frames[0]
         is_speech = self.vad.is_speech(self.frame2bytes(frame), self.sample_rate)
 
         if not self.triggered:
@@ -101,10 +101,13 @@ class AudioTrackVad(MediaStreamTrack):
             # audio we've collected.
             if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
                 self.triggered = False
-                ret = await self.write_wave()
-                # ret = await self.gen_base64()
-                if self.onvoiceend:
+                duration = self.voice_duration()
+                if duration > 500 and self.onvoiceend:
+                    ret = await self.write_wave()
                     await self.onvoiceend(ret)
+                else:
+                    print(f'{duration=} {self.onvoiceend=}')
+
                 self.ring_buffer.clear()
                 self.voiced_frames = []
 
@@ -120,12 +123,18 @@ class AudioTrackVad(MediaStreamTrack):
         audio_data = self.to_mono16000_data()
         b64 = base64.b64encode(audio_data).decode('utf-8')
         return b64
+
+    def voice_duration(self):
+        # Total duration of the collected voiced frames, in milliseconds.
+        duration = sum(f.samples * 1000 / f.sample_rate for f in self.voiced_frames)
+        return duration
 
     async def write_wave(self):
         """Writes a .wav file.
 
        Takes path, PCM audio data, and sample rate.
        """
+        audio_data = self.to_mono16000_data()
         path = temp_file(suffix='.wav')
         # print(f'temp_file={path}')
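
For reference, a minimal standalone sketch of the new duration gate introduced above, assuming the resampled frames expose samples and sample_rate attributes the way PyAV audio frames do; the Frame class and the example values below are illustrative only and not part of rtcllm:

    # Hypothetical illustration only; not part of rtcllm/vad.py.
    from dataclasses import dataclass

    @dataclass
    class Frame:
        samples: int      # samples per channel in this frame
        sample_rate: int  # sampling rate in Hz

    def voice_duration(frames):
        # Total duration of the voiced frames, in milliseconds.
        return sum(f.samples * 1000 / f.sample_rate for f in frames)

    # Twenty 30 ms frames at 16 kHz (480 samples each) add up to 600 ms,
    # which clears the 500 ms threshold checked before calling onvoiceend.
    frames = [Frame(samples=480, sample_rate=16000)] * 20
    assert voice_duration(frames) == 600.0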