From 2476b7b76cefc2ed25a8f9939eca29991cf9ef87 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Wed, 28 Aug 2024 18:26:47 +0800 Subject: [PATCH] bugfix --- rtcllm/rtc.py | 121 +++++++++++++++++++++++++++++++++++++------------- rtcllm/vad.py | 35 ++++++++++++--- 2 files changed, 119 insertions(+), 37 deletions(-) diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py index 572c75f..d8bea47 100644 --- a/rtcllm/rtc.py +++ b/rtcllm/rtc.py @@ -1,4 +1,6 @@ import asyncio +import json +from functools import partial from appPublic.dictObject import DictObject from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate @@ -8,65 +10,111 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med from websockets.asyncio.client import connect import websockets +from vad import AudioTrackVad + class RTCLLM: def __init__(self, ws_url, iceServers): self.ws_url = ws_url self.iceServers = iceServers self.peers = DictObject() + self.info = { + 'id':'rtcllm_agent', + 'name':'rtcllm_agent' + } + self.dc = None def get_pc(self, data): return self.peers[data['from'].id].pc async def run(self): async with connect(self.ws_url) as self.ws: - self.login() + await self.login() while True: msg = await self.ws.recv() data = DictObject(**json.loads(msg)) - f = self.handler.get(data.type) - await f(data) + d = data.data + print(f'ws recv(): {d.type=}') + func = self.handlers.get(d.type) + if func: + f = partial(func, self) + await f(d) self.ws.close() async def login(self): await self.ws.send(json.dumps({ 'type':'login', - 'info':{ - 'id':'rtcllm_agent', - 'name':'rtcllm_agent' - } + 'info':self.info })) + async def on_icecandidate(self, pc, candidate): + print('on_icecandidate()', self, pc, candidate) + if candidate: + await self.ws.send(json.dumps({ + "type":"candidate", + "candidate":candidate, + })) + async def save_onlineList(self, data): + print(f'{self}, {type(self)}') self.onlineList = data.onlineList + async def vad_voiceend(self, fn): + print(f'vad_voiceend():{fn=}') + async def auto_accept_call(self, data): - pc = await self.createPeerConnection(data['from']) + opts = DictObject(iceServers=self.iceServers) + pc = RTCPeerConnection(opts) self.peers[data['from'].id] = DictObject(**{ 'info':data['from'], 'pc':pc }) + print(f'{opts=}, {pc=}, {pc.myconfiguration=}') await self.ws.send(json.dumps({'type':'callAccepted', 'to':data['from']})) + async def pc_track(self, peerid, track): + peer = self.peers[peerid] + pc = peer.pc + if track.kind == 'audio': + vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend) + peer.vadtrack = vadtrack + vadtrack.start_vad() + + async def pc_connectionState_changed(self, peerid): + peer = self.peers[peerid] + pc = peer.pc + print(f'************************************{pc.connectionState=}') + if pc.connectionState == 'connected': + peer.dc = await pc.createDataChannel(peer.info.name) + return + if pc.connectionState == 'closed': + await pc.close() + if peer.dc: + await peer.dc.close() + + peers = { + k:v for k,v in self.peers.items() if k != peerid + } + self.peers = peers + if len([k for k in self.peers.keys()]) == 0: + await self.ws.close() + async def response_offer(self, data): pc = self.get_pc(data) - offer = RTCSessionDescription(data.offer) + if pc is None: + print(f'{self.peers=}, {data=}') + return + pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id)) + pc.on('track', partial(self.pc_track, data['from'].id)) + pc.on('icecandidate', partial(self.on_icecandidate, pc)) + offer = RTCSessionDescription(** data.offer) await pc.setRemoteDescription(offer) answer = await pc.createAnswer() await pc.setLocalDescription(answer) - self.ws.send(json.dumps({ + await self.ws.send(json.dumps({ 'type':'answer', - 'answer':pc.localDescription, + 'answer':{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp}, 'to':data['from'] })) - @pc.on("datachannel") - def datachannel_handle(channel): - @channel.on("message") - def recvdata(channel): - pass - """ - @pc.on("connectionstatechange") - @pc.on("track") - """ async def accept_answer(self, data): pc = self.get_pc(data) @@ -75,15 +123,27 @@ class RTCLLM: async def accept_iceCandidate(self, data): pc = self.get_pc(data) - candidate = RTCIceCandidate(data.candidate) - # pc.addIceCandidate(RTCIceCandidate.from_string(candidate)) - await pc.addIceCandidate(candidate) - - async def createPeerConnection(self, peerinfo): - opts = { - iceServers:self.iceServers - } - pc = RTCPeerConnection(opts) + candidate = data.candidate + ip = candidate['candidate'].split(' ')[4] + port = candidate['candidate'].split(' ')[5] + protocol = candidate['candidate'].split(' ')[7] + priority = candidate['candidate'].split(' ')[3] + foundation = candidate['candidate'].split(' ')[0] + component = candidate['candidate'].split(' ')[1] + type = candidate['candidate'].split(' ')[7] + rtc_candidate = RTCIceCandidate( + ip=ip, + port=port, + protocol=protocol, + priority=priority, + foundation=foundation, + component=component, + type=type, + sdpMid=candidate['sdpMid'], + sdpMLineIndex=candidate['sdpMLineIndex'] + ) + await pc.addIceCandidate(rtc_candidate) + print('addIceCandidate ok') handlers = { 'onlineList':save_onlineList, @@ -94,13 +154,14 @@ class RTCLLM: } async def main(): - agent = RTCLLM(ws_url='https://sage.opencomputing.cn/wss/ws/rtc_signaling.ws', + agent = RTCLLM(ws_url='wss://sage.opencomputing.cn/wss/ws/rtc_signaling.ws', iceServers=[{ 'urls':'stun:stun.opencomputing.cn'},{ 'urls':'turn:stun.opencomputing.cn', 'username':'turn', 'credential':'server' }]) + print('running ...') await agent.run() if __name__ == '__main__': diff --git a/rtcllm/vad.py b/rtcllm/vad.py index 895077c..4bd2021 100644 --- a/rtcllm/vad.py +++ b/rtcllm/vad.py @@ -1,3 +1,4 @@ +import asyncio import collections import contextlib from aiortc import MediaStreamTrack @@ -9,20 +10,40 @@ class AudioTrackVad(MediaStreamTrack): def __init__(self, track, stage=3, onvoiceend=None): super().__init__() self.track = track + print(dir(track), 'AudioTrackVad.__init__()') self.onvoiceend = onvoiceend self.vad = webrtcvad.Vad(stage) - self.sample_rate = self.track.getSettings().sampleRate - frameSize = self.track.getSettings().frameSize - self.frame_duration_ms = (1000 * frameSize) / self.sample_rate + # self.sample_rate = self.track.getSettings().sampleRate + # frameSize = self.track.getSettings().frameSize + # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate + self.frame_duration_ms = 0.00008 self.num_padding_frames = 10 self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.triggered = False self.voiced_frames = [] + self.loop = asyncio.get_event_loop() + self.task = None + self.running = False + + def start_vad(self): + self.running = True + self.task = self.loop.call_later(self.frame_duration_ms, self._recv) + def _recv(self): + asyncio.create_task(self.recv()) + if self.task: + self.task.cancel() + if self.running: + self.task = self.loop.call_later(self.frame_duration_ms, self._recv) + + def stop(self): + self.running = False + async def recv(self): - frame = await self.track.recv() - self.vad_check(frame) - return frame + f = await self.track.recv() + print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}') + self.vad_check(f) + return f async def vad_check(self, frame): is_speech = self.vad.is_speech(frame, self.sample_rate) @@ -69,5 +90,5 @@ class AudioTrackVad(MediaStreamTrack): wf.writeframes(audio) if self.onvoiceend: - self.onvoiceend(path) + await self.onvoiceend(path)