# rtcllm/rtcllm/vad.py
import asyncio
import base64
import collections
import contextlib
import wave
from traceback import print_exc

import webrtcvad
from aiortc import MediaStreamTrack
from appPublic.folderUtils import temp_file
from av import AudioResampler

class AudioTrackVad(MediaStreamTrack):
	kind = 'audio'
def __init__(self, track, stage=3, onvoiceend=None):
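		"""Run webrtcvad over the frames of an aiortc audio track.

		track      -- source MediaStreamTrack to pull audio frames from
		stage      -- webrtcvad aggressiveness, 0 (least) to 3 (most)
		onvoiceend -- async callback awaited with the path of a .wav
		              file when a voiced segment (>500 ms) ends
		"""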
super().__init__()
self.track = track
self.onvoiceend = onvoiceend
self.vad = webrtcvad.Vad(stage)
		# self.sample_rate = self.track.getSettings().sampleRate
		# frameSize = self.track.getSettings().frameSize
		# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
		# 20 ms per frame; the value is kept in seconds because it is
		# used as the delay argument of loop.call_later().
		self.frame_duration = 0.02
self.num_padding_frames = 20
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False
self.voiced_frames = []
self.loop = asyncio.get_event_loop()
self.task = None
self.debug = True
self.running = False
	def start_vad(self):
		self.running = True
		self.task = self.loop.call_later(self.frame_duration, self._recv)
	def _recv(self):
		asyncio.create_task(self.recv())
	def stop(self):
		self.running = False
		if self.task:
			self.task.cancel()
			self.task = None
	def frame2bytes(self, frame):
		# Interleaved s16 PCM bytes, as expected by webrtcvad and wave.
		audio_array = frame.to_ndarray()
		return audio_array.tobytes()
async def recv(self):
frame = await self.track.recv()
self.sample_rate = frame.sample_rate
duration = (frame.samples * 1000) / frame.sample_rate
# print(f'{self.__class__.__name__}.recv(): {duration=}, {frame.samples=}, {frame.sample_rate=}')
try:
await self.vad_check(frame)
except Exception as e:
print(f'{e=}')
print_exc()
return
if self.task:
self.task.cancel()
if self.running:
			self.task = self.loop.call_later(self.frame_duration, self._recv)
return frame
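	# Note: unlike a typical MediaStreamTrack, recv() is not pulled by a
	# media consumer; start_vad() reschedules it every frame_duration
	# seconds through loop.call_later(), so the track drives itself.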
	def resample(self, frame, sample_rate=None):
		if sample_rate is None:
			sample_rate = frame.sample_rate
		# AudioResampler.resample() returns a list of frames.
		r = AudioResampler(format='s16', layout='mono', rate=sample_rate)
		return r.resample(frame)
	async def vad_check(self, inframe):
		# webrtcvad accepts only mono s16 PCM in 10/20/30 ms frames at
		# 8000/16000/32000/48000 Hz, so convert to s16/mono first.
		frames = self.resample(inframe)
		frame = frames[0]
		is_speech = self.vad.is_speech(self.frame2bytes(frame),
				self.sample_rate)
if not self.triggered:
self.ring_buffer.append((inframe, is_speech))
num_voiced = len([f for f, speech in self.ring_buffer if speech])
# If we're NOTTRIGGERED and more than 90% of the frames in
# the ring buffer are voiced frames, then enter the
# TRIGGERED state.
if num_voiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = True
# We want to yield all the audio we see from now until
# we are NOTTRIGGERED, but we have to start with the
# audio that's already in the ring buffer.
for f, s in self.ring_buffer:
self.voiced_frames.append(f)
self.ring_buffer.clear()
# print('start voice .....', len(self.voiced_frames))
else:
# We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer.
self.voiced_frames.append(inframe)
			self.ring_buffer.append((inframe, is_speech))
num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
# If more than 90% of the frames in the ring buffer are
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = False
				duration = self.voice_duration()
				if duration > 500 and self.onvoiceend:
					# write_wave() returns the path of the temporary
					# .wav file holding the voiced segment.
					path = await self.write_wave()
					await self.onvoiceend(path)
				else:
					print(f'vad sound {duration=}')
				self.ring_buffer.clear()
				self.voiced_frames = []
	def to_mono16000_data(self):
		# Resample every collected frame to mono 16 kHz s16 and join
		# the raw PCM payloads into one bytes object.
		lst = []
		for f in self.voiced_frames:
			fs = self.resample(f, sample_rate=16000)
			lst += fs
		audio_data = b''.join(self.frame2bytes(f) for f in lst)
		return audio_data
async def gen_base64(self):
audio_data = self.to_mono16000_data()
b64 = base64.b64encode(audio_data).decode('utf-8')
return b64
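	# A sketch of how the base64 payload might be shipped, e.g. over a
	# data channel or websocket; the message shape below is an
	# illustration, not a protocol this module defines:
	#
	#   b64 = await vadtrack.gen_base64()
	#   msg = json.dumps({'type': 'audio', 'format': 'pcm_s16le',
	#                     'rate': 16000, 'data': b64})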
	def voice_duration(self):
		# Total duration of the collected voiced frames, in milliseconds.
		return sum(f.samples * 1000 / f.sample_rate for f in self.voiced_frames)
	async def write_wave(self):
		"""Write the collected voiced frames to a temporary mono
		16 kHz .wav file and return its path."""
		audio_data = self.to_mono16000_data()
		path = temp_file(suffix='.wav')
		# print(f'temp_file={path}')
		with contextlib.closing(wave.open(path, 'wb')) as wf:
			wf.setnchannels(1)
			wf.setsampwidth(2)
			wf.setframerate(16000)
			wf.writeframes(audio_data)
		return path
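
# --- usage sketch (illustrative) ---
# A minimal example of wiring AudioTrackVad into an aiortc peer
# connection; `pc` and `handle_voice` are assumptions for illustration,
# not part of this module:
#
#   from aiortc import RTCPeerConnection
#
#   pc = RTCPeerConnection()
#
#   async def handle_voice(wav_path):
#       # wav_path is the temporary .wav file written by write_wave()
#       print('voiced segment saved to', wav_path)
#
#   @pc.on('track')
#   def on_track(track):
#       if track.kind == 'audio':
#           vadtrack = AudioTrackVad(track, stage=3, onvoiceend=handle_voice)
#           vadtrack.start_vad()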