# rtcllm/rtcllm/vad.py
import asyncio
import base64
import collections
from traceback import print_exc

import av
from av import AudioResampler
from aiortc import MediaStreamTrack
import webrtcvad

from appPublic.folderUtils import temp_file
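
# The AudioTrackVad class below segments speech with the ring-buffer state
# machine used in the py-webrtcvad example: frames are buffered until most of
# a padding window is voiced (speech starts), then collected until most of the
# window is unvoiced (speech ends), at which point the segment is written to a
# .wav file and the onvoiceend callback is awaited with the resulting path.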


class AudioTrackVad(MediaStreamTrack):
    # aiortc tracks advertise their kind; this one carries audio
    kind = 'audio'

    def __init__(self, track, stage=3, onvoiceend=None):
        super().__init__()
        self.track = track
        self.onvoiceend = onvoiceend
        # webrtcvad aggressiveness: 0 (least) .. 3 (most aggressive)
        self.vad = webrtcvad.Vad(stage)
        # 20 ms frames, expressed in seconds because call_later() takes seconds
        self.frame_duration = 0.02
        self.num_padding_frames = 20
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames = []
        self.loop = asyncio.get_event_loop()
        self.task = None
        self.debug = True
        self.running = False

    def start_vad(self):
        self.running = True
        # poll the source track roughly once per frame
        self.task = self.loop.call_later(self.frame_duration, self._recv)

    def _recv(self):
        asyncio.create_task(self.recv())

    def stop(self):
        self.running = False

    def frame2bytes(self, frame):
        # raw little-endian 16-bit PCM bytes of the frame
        return frame.to_ndarray().tobytes()

    async def recv(self):
        frame = await self.track.recv()
        self.sample_rate = frame.sample_rate
        try:
            await self.vad_check(frame)
        except Exception as e:
            print(f'{e=}')
            print_exc()
            return
        if self.task:
            self.task.cancel()
        if self.running:
            self.task = self.loop.call_later(self.frame_duration, self._recv)
        return frame

    def resample(self, frame, sample_rate=None):
        # convert to 16-bit mono, keeping the source rate unless one is given
        if sample_rate is None:
            sample_rate = frame.sample_rate
        r = AudioResampler(format='s16', layout='mono', rate=sample_rate)
        return r.resample(frame)  # returns a list of frames

    async def vad_check(self, inframe):
        frames = self.resample(inframe)
        frame = frames[0]
        # webrtcvad expects 16-bit mono PCM at 8/16/32/48 kHz,
        # in 10, 20 or 30 ms chunks
        is_speech = self.vad.is_speech(self.frame2bytes(frame),
                                       self.sample_rate)
        if not self.triggered:
            self.ring_buffer.append((inframe, is_speech))
            num_voiced = len([f for f, speech in self.ring_buffer if speech])
            # If we're NOTTRIGGERED and more than 90% of the frames in
            # the ring buffer are voiced frames, then enter the
            # TRIGGERED state.
            if num_voiced > 0.9 * self.ring_buffer.maxlen:
                self.triggered = True
                # We want to collect all the audio we see from now until
                # we are NOTTRIGGERED, but we have to start with the
                # audio that's already in the ring buffer.
                for f, s in self.ring_buffer:
                    self.voiced_frames.append(f)
                self.ring_buffer.clear()
        else:
            # We're in the TRIGGERED state, so collect the audio data
            # and add it to the ring buffer.
            self.voiced_frames.append(inframe)
            self.ring_buffer.append((inframe, is_speech))
            num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
            # If more than 90% of the frames in the ring buffer are
            # unvoiced, then enter NOTTRIGGERED and flush whatever
            # audio we've collected.
            if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
                self.triggered = False
                duration = self.voice_duration()
                # ignore segments shorter than 500 ms
                if duration > 500 and self.onvoiceend:
                    ret = await self.write_wave()
                    await self.onvoiceend(ret)
                else:
                    print(f'vad sound {duration=}')
                self.ring_buffer.clear()
                self.voiced_frames = []
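
    # With num_padding_frames=20 and 20 ms frames, the window above spans
    # about 400 ms: speech starts once more than 18 of the last 20 frames
    # are voiced, and ends once more than 18 of the last 20 are unvoiced.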

    def to_mono16000_data(self):
        # concatenated 16 kHz mono s16 PCM of all voiced frames
        lst = []
        for f in self.voiced_frames:
            lst += self.resample(f, sample_rate=16000)
        return b''.join(self.frame2bytes(f) for f in lst)

    async def gen_base64(self):
        audio_data = self.to_mono16000_data()
        return base64.b64encode(audio_data).decode('utf-8')
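
    # A consumer can recover the PCM samples from the base64 string, e.g.:
    #   import base64, numpy as np
    #   pcm = np.frombuffer(base64.b64decode(b64), dtype=np.int16)  # 16 kHz mono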

    def voice_duration(self):
        # total duration of the collected speech, in milliseconds
        duration = 0
        for f in self.voiced_frames:
            duration += f.samples * 1000 / f.sample_rate
        return duration

    def frames_resample(self, frames, sr=None):
        fs = []
        for f in frames:
            fs += self.resample(f, sample_rate=sr)
        return fs

    def frames_write_wave(self, frames):
        """Encode a list of AudioFrames into a temporary .wav file; return its path."""
        path = temp_file(suffix='.wav')
        output_container = av.open(path, 'w')
        # match the output stream to the source sample rate
        out_stream = output_container.add_stream('pcm_s16le', rate=frames[0].sample_rate)
        for frame in frames:
            for packet in out_stream.encode(frame):
                output_container.mux(packet)
        # flush the encoder
        for packet in out_stream.encode(None):
            output_container.mux(packet)
        output_container.close()
        return path

    async def write_wave(self):
        """Write the voiced frames to .wav files; return the 16 kHz mono one."""
"""
############
# Method:1
############
audio_data = self.to_mono16000_data()
path = temp_file(suffix='.wav')
# print(f'temp_file={path}')
with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(16000)
wf.writeframes(audio_data)
# print('************wrote*******')
if self.onvoiceend:
await self.onvoiceend(path)
# print('************over*******')
return
############
# Method:2
############
path = temp_file(suffix='.wav')
output_container = av.open(path, 'w')
out_stream = output_container.add_stream('pcm_s16le', rate=16000, layout='mono')
resampler = AudioResampler(format=out_stream.format, layout=out_stream.layout, rate=out_stream.rate)
for frame in self.voiced_frames:
for f in resampler.resample(frame):
output_container.mux(out_stream.encode(f))
output_container.mux(out_stream.encode())
output_container.close()
return path
"""
        f1 = self.frames_write_wave(self.voiced_frames)
        frames = self.frames_resample(self.voiced_frames, sr=16000)
        fn = self.frames_write_wave(frames)
        print(f'source wave filename={f1}, mono 16000 wave filename={fn}')
        return fn
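

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how AudioTrackVad might be wired into an aiortc peer
# connection. The handler name `on_voice_end` and the surrounding signaling
# are assumptions; only AudioTrackVad itself comes from this module.
async def _example_usage():
    from aiortc import RTCPeerConnection

    pc = RTCPeerConnection()

    async def on_voice_end(wav_path):
        # write_wave() hands back the path of a 16 kHz mono .wav file
        print('utterance saved to', wav_path)

    @pc.on('track')
    def on_track(track):
        if track.kind == 'audio':
            vad_track = AudioTrackVad(track, stage=3, onvoiceend=on_voice_end)
            vad_track.start_vad()

    # ... offer/answer signaling with the remote peer would go here ...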