import asyncio
import base64
import collections
import contextlib
import wave
from traceback import print_exc

import numpy as np
import webrtcvad
from aiortc import MediaStreamTrack
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat

from appPublic.folderUtils import temp_file


class AudioTrackVad(MediaStreamTrack):
    """Audio track wrapper that runs voice-activity detection on each frame.

    Frames are pulled from the wrapped ``track``, classified as speech or
    non-speech with ``webrtcvad``, and grouped into utterances with a
    ring-buffer hysteresis.  When an utterance longer than 500 ms ends, it is
    written to a temporary 16 kHz mono WAV file and the file path is passed to
    the ``onvoiceend`` coroutine.
    """

    def __init__(self, track, stage=3, onvoiceend=None):
        """Create the VAD wrapper.

        Args:
            track: Source track; must provide an awaitable ``recv()`` that
                returns audio frames.
            stage: Mode value handed to ``webrtcvad.Vad``.
            onvoiceend: Optional coroutine function awaited with the path of
                the WAV file produced for each finished utterance.
        """
        super().__init__()
        self.track = track
        self.onvoiceend = onvoiceend
        self.vad = webrtcvad.Vad(stage)
        # NOTE: despite the name, this value is handed to loop.call_later(),
        # which takes a delay in seconds, so 0.02 means 20 ms.
        self.frame_duration_ms = 0.02
        self.num_padding_frames = 20
        # Sliding window of (frame, is_speech) pairs used for the hysteresis.
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames = []
        self.loop = asyncio.get_event_loop()
        self.task = None  # pending call_later() handle for the next poll
        self._recv_task = None  # keeps the in-flight recv() task referenced
        self.debug = True
        self.running = False

    def start_vad(self):
        """Start the self-rescheduling polling loop."""
        self.running = True
        self.task = self.loop.call_later(self.frame_duration_ms, self._recv)

    def _recv(self):
        """Timer callback: run ``recv()`` as a task on the running loop."""
        # Hold a reference so the task cannot be garbage-collected mid-flight.
        self._recv_task = asyncio.create_task(self.recv())

    def stop(self):
        """Stop the polling loop and cancel the pending timer, if any.

        NOTE(review): this overrides ``MediaStreamTrack.stop()`` without
        calling the base implementation, as the original code did - confirm
        whether the base-class stop handling is wanted here.
        """
        self.running = False
        if self.task:
            self.task.cancel()
            self.task = None

    def frame2bytes(self, frame):
        """Return the raw sample bytes of ``frame``."""
        return frame.to_ndarray().tobytes()

    async def recv(self):
        """Receive one frame from the wrapped track and feed it to the VAD.

        Returns:
            The received frame, or ``None`` when the VAD step raised.  In
            that case the error is printed and the polling loop is NOT
            rescheduled.
        """
        frame = await self.track.recv()
        self.sample_rate = frame.sample_rate
        try:
            await self.vad_check(frame)
        except Exception as e:
            print(f'{e=}')
            print_exc()
            return None

        # Drop any pending timer before (possibly) arming a new one so that
        # at most one poll is ever scheduled.
        if self.task:
            self.task.cancel()
        if self.running:
            self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
        return frame

    def resample(self, frame, sample_rate=None):
        """Convert ``frame`` to signed 16-bit mono.

        Args:
            frame: Audio frame to convert.
            sample_rate: Target rate; defaults to the frame's own rate.

        Returns:
            Whatever ``AudioResampler.resample`` returns; the rest of this
            class treats it as a list of frames.
        """
        if sample_rate is None:
            sample_rate = frame.rate
        resampler = AudioResampler(format='s16', layout='mono', rate=sample_rate)
        return resampler.resample(frame)

    async def vad_check(self, inframe):
        """Classify ``inframe`` and advance the utterance state machine.

        While not triggered, frames accumulate in the ring buffer until more
        than 90% of its capacity is voiced; the buffered frames then seed the
        utterance.  While triggered, frames are collected until more than 90%
        of the ring buffer is unvoiced, at which point the utterance is
        flushed via ``write_wave()`` (if long enough) and the state resets.
        """
        frames = self.resample(inframe)
        if not frames:
            # Nothing to classify - presumably the resampler can return an
            # empty list; skip instead of failing on frames[0].
            return
        frame = frames[0]
        is_speech = self.vad.is_speech(self.frame2bytes(frame), self.sample_rate)
        threshold = 0.9 * self.ring_buffer.maxlen

        if not self.triggered:
            self.ring_buffer.append((inframe, is_speech))
            num_voiced = sum(1 for _, speech in self.ring_buffer if speech)
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > threshold:
                self.triggered = True
                # Start the utterance with the audio that is already in
                # the ring buffer so its onset is not clipped.
                self.voiced_frames.extend(f for f, _ in self.ring_buffer)
                self.ring_buffer.clear()
            return

        # We're in the TRIGGERED state, so collect the audio data
        # and add it to the ring buffer.
        self.voiced_frames.append(inframe)
        self.ring_buffer.append((inframe, is_speech))
        num_unvoiced = sum(1 for _, speech in self.ring_buffer if not speech)
        # If more than 90% of the frames in the ring buffer are
        # unvoiced, then enter NOTTRIGGERED and flush whatever
        # audio we've collected.
        if num_unvoiced > threshold:
            self.triggered = False
            duration = self.voice_duration()
            if duration > 500 and self.onvoiceend:
                # write_wave() delivers the file path to onvoiceend itself;
                # calling the callback again here would fire it twice (the
                # second time with None).
                await self.write_wave()
            else:
                print(f'vad sound {duration=}')
            self.ring_buffer.clear()
            self.voiced_frames = []

    def to_mono16000_data(self):
        """Return the collected utterance as 16 kHz mono s16 PCM bytes."""
        mono_frames = []
        for f in self.voiced_frames:
            mono_frames.extend(self.resample(f, sample_rate=16000))
        return b''.join(self.frame2bytes(f) for f in mono_frames)

    async def gen_base64(self):
        """Return the collected utterance PCM data as a base64 string."""
        audio_data = self.to_mono16000_data()
        return base64.b64encode(audio_data).decode('utf-8')

    def voice_duration(self):
        """Return the total duration of the collected frames in milliseconds."""
        return sum(f.samples * 1000 / f.sample_rate for f in self.voiced_frames)

    async def write_wave(self):
        """Write the collected utterance to a temporary .wav file.

        The audio is stored as 16 kHz, mono, 16-bit PCM.  If ``onvoiceend``
        is set it is awaited with the file path.

        Returns:
            The path of the written WAV file.
        """
        audio_data = self.to_mono16000_data()
        path = temp_file(suffix='.wav')
        with contextlib.closing(wave.open(path, 'wb')) as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(16000)
            wf.writeframes(audio_data)

        if self.onvoiceend:
            await self.onvoiceend(path)
        return path