This commit is contained in:
yumoqing 2024-08-29 18:45:36 +08:00
parent 2476b7b76c
commit 64fee77657
2 changed files with 52 additions and 16 deletions

View File

@ -75,7 +75,7 @@ class RTCLLM:
peer = self.peers[peerid] peer = self.peers[peerid]
pc = peer.pc pc = peer.pc
if track.kind == 'audio': if track.kind == 'audio':
vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend) vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend)
peer.vadtrack = vadtrack peer.vadtrack = vadtrack
vadtrack.start_vad() vadtrack.start_vad()
@ -103,7 +103,7 @@ class RTCLLM:
if pc is None: if pc is None:
print(f'{self.peers=}, {data=}') print(f'{self.peers=}, {data=}')
return return
pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id)) pc.on("connectionstate", partial(self.pc_connectionState_changed, data['from'].id))
pc.on('track', partial(self.pc_track, data['from'].id)) pc.on('track', partial(self.pc_track, data['from'].id))
pc.on('icecandidate', partial(self.on_icecandidate, pc)) pc.on('icecandidate', partial(self.on_icecandidate, pc))
offer = RTCSessionDescription(** data.offer) offer = RTCSessionDescription(** data.offer)

View File

@ -1,10 +1,14 @@
from traceback import print_exc
import asyncio import asyncio
import collections import collections
import contextlib import contextlib
from appPublic.folderUtils import temp_file
from aiortc import MediaStreamTrack from aiortc import MediaStreamTrack
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
import webrtcvad import webrtcvad
import wave import wave
import numpy as np
from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat
class AudioTrackVad(MediaStreamTrack): class AudioTrackVad(MediaStreamTrack):
def __init__(self, track, stage=3, onvoiceend=None): def __init__(self, track, stage=3, onvoiceend=None):
@ -17,12 +21,13 @@ class AudioTrackVad(MediaStreamTrack):
# frameSize = self.track.getSettings().frameSize # frameSize = self.track.getSettings().frameSize
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
self.frame_duration_ms = 0.00008 self.frame_duration_ms = 0.00008
self.num_padding_frames = 10 self.num_padding_frames = 20
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False self.triggered = False
self.voiced_frames = [] self.voiced_frames = []
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.task = None self.task = None
self.debug = True
self.running = False self.running = False
def start_vad(self): def start_vad(self):
@ -31,22 +36,47 @@ class AudioTrackVad(MediaStreamTrack):
def _recv(self): def _recv(self):
asyncio.create_task(self.recv()) asyncio.create_task(self.recv())
if self.task:
self.task.cancel()
if self.running:
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
def stop(self): def stop(self):
self.running = False self.running = False
def frame2bytes(self, frame):
# 假设你有一个 AudioFrame 对象 audio_frame
audio_array = frame.to_ndarray()
# 将 numpy 数组转换为字节数组
dtype = audio_array.dtype
audio_bytes = audio_array.tobytes()
return audio_bytes
async def recv(self): async def recv(self):
f = await self.track.recv() oldf = await self.track.recv()
print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}') frames = self.resample(oldf)
self.vad_check(f) for f in frames:
if self.debug:
self.debug = False
print(f'{type(f)}, {f.samples=}, {f.format.bytes=}, {f.sample_rate=}, {f.format=}, {f.is_corrupt=}, {f.layout=}, {f.planes=}, {f.side_data=}')
self.sample_rate = f.sample_rate
try:
await self.vad_check(f)
except Exception as e:
print(f'{e=}')
print_exc()
return
if self.task:
self.task.cancel()
if self.running:
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
return f return f
def resample(self, frame):
fmt = AudioFormat('s16')
al = AudioLayout(1)
r = AudioResampler(format=fmt, layout=al, rate=frame.rate)
frame = r.resample(frame)
return frame
async def vad_check(self, frame): async def vad_check(self, frame):
is_speech = self.vad.is_speech(frame, self.sample_rate) is_speech = self.vad.is_speech(self.frame2bytes(frame), self.sample_rate)
if not self.triggered: if not self.triggered:
self.ring_buffer.append((frame, is_speech)) self.ring_buffer.append((frame, is_speech))
num_voiced = len([f for f, speech in self.ring_buffer if speech]) num_voiced = len([f for f, speech in self.ring_buffer if speech])
@ -61,6 +91,7 @@ class AudioTrackVad(MediaStreamTrack):
for f, s in self.ring_buffer: for f, s in self.ring_buffer:
self.voiced_frames.append(f) self.voiced_frames.append(f)
self.ring_buffer.clear() self.ring_buffer.clear()
print('start voice .....', len(self.voiced_frames))
else: else:
# We're in the TRIGGERED state, so collect the audio data # We're in the TRIGGERED state, so collect the audio data
# and add it to the ring buffer. # and add it to the ring buffer.
@ -72,23 +103,28 @@ class AudioTrackVad(MediaStreamTrack):
# audio we've collected. # audio we've collected.
if num_unvoiced > 0.9 * self.ring_buffer.maxlen: if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = False self.triggered = False
audio_data = b''.join([f.bytes for f in voiced_frames]) audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
self.write_wave(audio_data) await self.write_wave(audio_data)
self.ring_buffer.clear() self.ring_buffer.clear()
voiced_frames = [] self.voiced_frames = []
print('end voice .....', len(self.voiced_frames))
async def write_wave(self, audio_data): async def write_wave(self, audio_data):
"""Writes a .wav file. """Writes a .wav file.
Takes path, PCM audio data, and sample rate. Takes path, PCM audio data, and sample rate.
""" """
path = make_temp(subfix='.wav') path = temp_file(suffix='.wav')
print(f'temp_file={path}')
with contextlib.closing(wave.open(path, 'wb')) as wf: with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1) wf.setnchannels(1)
wf.setsampwidth(2) wf.setsampwidth(2)
wf.setframerate(self.sample_rate) wf.setframerate(self.sample_rate)
wf.writeframes(audio) wf.writeframes(audio_data)
print('************wrote*******')
if self.onvoiceend: if self.onvoiceend:
await self.onvoiceend(path) await self.onvoiceend(path)
print('************over*******')