bugfix
This commit is contained in:
parent
2476b7b76c
commit
64fee77657
@ -75,7 +75,7 @@ class RTCLLM:
|
|||||||
peer = self.peers[peerid]
|
peer = self.peers[peerid]
|
||||||
pc = peer.pc
|
pc = peer.pc
|
||||||
if track.kind == 'audio':
|
if track.kind == 'audio':
|
||||||
vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend)
|
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend)
|
||||||
peer.vadtrack = vadtrack
|
peer.vadtrack = vadtrack
|
||||||
vadtrack.start_vad()
|
vadtrack.start_vad()
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ class RTCLLM:
|
|||||||
if pc is None:
|
if pc is None:
|
||||||
print(f'{self.peers=}, {data=}')
|
print(f'{self.peers=}, {data=}')
|
||||||
return
|
return
|
||||||
pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id))
|
pc.on("connectionstate", partial(self.pc_connectionState_changed, data['from'].id))
|
||||||
pc.on('track', partial(self.pc_track, data['from'].id))
|
pc.on('track', partial(self.pc_track, data['from'].id))
|
||||||
pc.on('icecandidate', partial(self.on_icecandidate, pc))
|
pc.on('icecandidate', partial(self.on_icecandidate, pc))
|
||||||
offer = RTCSessionDescription(** data.offer)
|
offer = RTCSessionDescription(** data.offer)
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
|
from traceback import print_exc
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from appPublic.folderUtils import temp_file
|
||||||
from aiortc import MediaStreamTrack
|
from aiortc import MediaStreamTrack
|
||||||
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
import webrtcvad
|
import webrtcvad
|
||||||
import wave
|
import wave
|
||||||
|
import numpy as np
|
||||||
|
from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat
|
||||||
|
|
||||||
class AudioTrackVad(MediaStreamTrack):
|
class AudioTrackVad(MediaStreamTrack):
|
||||||
def __init__(self, track, stage=3, onvoiceend=None):
|
def __init__(self, track, stage=3, onvoiceend=None):
|
||||||
@ -17,12 +21,13 @@ class AudioTrackVad(MediaStreamTrack):
|
|||||||
# frameSize = self.track.getSettings().frameSize
|
# frameSize = self.track.getSettings().frameSize
|
||||||
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
||||||
self.frame_duration_ms = 0.00008
|
self.frame_duration_ms = 0.00008
|
||||||
self.num_padding_frames = 10
|
self.num_padding_frames = 20
|
||||||
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
||||||
self.triggered = False
|
self.triggered = False
|
||||||
self.voiced_frames = []
|
self.voiced_frames = []
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
self.task = None
|
self.task = None
|
||||||
|
self.debug = True
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def start_vad(self):
|
def start_vad(self):
|
||||||
@ -31,22 +36,47 @@ class AudioTrackVad(MediaStreamTrack):
|
|||||||
|
|
||||||
def _recv(self):
|
def _recv(self):
|
||||||
asyncio.create_task(self.recv())
|
asyncio.create_task(self.recv())
|
||||||
if self.task:
|
|
||||||
self.task.cancel()
|
|
||||||
if self.running:
|
|
||||||
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
|
def frame2bytes(self, frame):
|
||||||
|
# 假设你有一个 AudioFrame 对象 audio_frame
|
||||||
|
audio_array = frame.to_ndarray()
|
||||||
|
# 将 numpy 数组转换为字节数组
|
||||||
|
dtype = audio_array.dtype
|
||||||
|
audio_bytes = audio_array.tobytes()
|
||||||
|
return audio_bytes
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
f = await self.track.recv()
|
oldf = await self.track.recv()
|
||||||
print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}')
|
frames = self.resample(oldf)
|
||||||
self.vad_check(f)
|
for f in frames:
|
||||||
|
if self.debug:
|
||||||
|
self.debug = False
|
||||||
|
print(f'{type(f)}, {f.samples=}, {f.format.bytes=}, {f.sample_rate=}, {f.format=}, {f.is_corrupt=}, {f.layout=}, {f.planes=}, {f.side_data=}')
|
||||||
|
self.sample_rate = f.sample_rate
|
||||||
|
try:
|
||||||
|
await self.vad_check(f)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'{e=}')
|
||||||
|
print_exc()
|
||||||
|
return
|
||||||
|
if self.task:
|
||||||
|
self.task.cancel()
|
||||||
|
if self.running:
|
||||||
|
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
def resample(self, frame):
|
||||||
|
fmt = AudioFormat('s16')
|
||||||
|
al = AudioLayout(1)
|
||||||
|
r = AudioResampler(format=fmt, layout=al, rate=frame.rate)
|
||||||
|
frame = r.resample(frame)
|
||||||
|
return frame
|
||||||
|
|
||||||
async def vad_check(self, frame):
|
async def vad_check(self, frame):
|
||||||
is_speech = self.vad.is_speech(frame, self.sample_rate)
|
is_speech = self.vad.is_speech(self.frame2bytes(frame), self.sample_rate)
|
||||||
if not self.triggered:
|
if not self.triggered:
|
||||||
self.ring_buffer.append((frame, is_speech))
|
self.ring_buffer.append((frame, is_speech))
|
||||||
num_voiced = len([f for f, speech in self.ring_buffer if speech])
|
num_voiced = len([f for f, speech in self.ring_buffer if speech])
|
||||||
@ -61,6 +91,7 @@ class AudioTrackVad(MediaStreamTrack):
|
|||||||
for f, s in self.ring_buffer:
|
for f, s in self.ring_buffer:
|
||||||
self.voiced_frames.append(f)
|
self.voiced_frames.append(f)
|
||||||
self.ring_buffer.clear()
|
self.ring_buffer.clear()
|
||||||
|
print('start voice .....', len(self.voiced_frames))
|
||||||
else:
|
else:
|
||||||
# We're in the TRIGGERED state, so collect the audio data
|
# We're in the TRIGGERED state, so collect the audio data
|
||||||
# and add it to the ring buffer.
|
# and add it to the ring buffer.
|
||||||
@ -72,23 +103,28 @@ class AudioTrackVad(MediaStreamTrack):
|
|||||||
# audio we've collected.
|
# audio we've collected.
|
||||||
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
|
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
|
||||||
self.triggered = False
|
self.triggered = False
|
||||||
audio_data = b''.join([f.bytes for f in voiced_frames])
|
audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
|
||||||
self.write_wave(audio_data)
|
await self.write_wave(audio_data)
|
||||||
self.ring_buffer.clear()
|
self.ring_buffer.clear()
|
||||||
voiced_frames = []
|
self.voiced_frames = []
|
||||||
|
print('end voice .....', len(self.voiced_frames))
|
||||||
|
|
||||||
async def write_wave(self, audio_data):
|
async def write_wave(self, audio_data):
|
||||||
"""Writes a .wav file.
|
"""Writes a .wav file.
|
||||||
|
|
||||||
Takes path, PCM audio data, and sample rate.
|
Takes path, PCM audio data, and sample rate.
|
||||||
"""
|
"""
|
||||||
path = make_temp(subfix='.wav')
|
path = temp_file(suffix='.wav')
|
||||||
|
print(f'temp_file={path}')
|
||||||
|
|
||||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||||
wf.setnchannels(1)
|
wf.setnchannels(1)
|
||||||
wf.setsampwidth(2)
|
wf.setsampwidth(2)
|
||||||
wf.setframerate(self.sample_rate)
|
wf.setframerate(self.sample_rate)
|
||||||
wf.writeframes(audio)
|
wf.writeframes(audio_data)
|
||||||
|
|
||||||
|
print('************wrote*******')
|
||||||
if self.onvoiceend:
|
if self.onvoiceend:
|
||||||
await self.onvoiceend(path)
|
await self.onvoiceend(path)
|
||||||
|
print('************over*******')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user