This commit is contained in:
yumoqing 2024-08-28 18:26:47 +08:00
parent 9e17d711b3
commit 2476b7b76c
2 changed files with 119 additions and 37 deletions

View File

@ -1,4 +1,6 @@
import asyncio import asyncio
import json
from functools import partial
from appPublic.dictObject import DictObject from appPublic.dictObject import DictObject
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
@ -8,65 +10,111 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
import websockets import websockets
from vad import AudioTrackVad
class RTCLLM: class RTCLLM:
def __init__(self, ws_url, iceServers): def __init__(self, ws_url, iceServers):
self.ws_url = ws_url self.ws_url = ws_url
self.iceServers = iceServers self.iceServers = iceServers
self.peers = DictObject() self.peers = DictObject()
self.info = {
'id':'rtcllm_agent',
'name':'rtcllm_agent'
}
self.dc = None
def get_pc(self, data): def get_pc(self, data):
return self.peers[data['from'].id].pc return self.peers[data['from'].id].pc
async def run(self): async def run(self):
async with connect(self.ws_url) as self.ws: async with connect(self.ws_url) as self.ws:
self.login() await self.login()
while True: while True:
msg = await self.ws.recv() msg = await self.ws.recv()
data = DictObject(**json.loads(msg)) data = DictObject(**json.loads(msg))
f = self.handler.get(data.type) d = data.data
await f(data) print(f'ws recv(): {d.type=}')
func = self.handlers.get(d.type)
if func:
f = partial(func, self)
await f(d)
self.ws.close() self.ws.close()
async def login(self): async def login(self):
await self.ws.send(json.dumps({ await self.ws.send(json.dumps({
'type':'login', 'type':'login',
'info':{ 'info':self.info
'id':'rtcllm_agent',
'name':'rtcllm_agent'
}
})) }))
async def on_icecandidate(self, pc, candidate):
print('on_icecandidate()', self, pc, candidate)
if candidate:
await self.ws.send(json.dumps({
"type":"candidate",
"candidate":candidate,
}))
async def save_onlineList(self, data): async def save_onlineList(self, data):
print(f'{self}, {type(self)}')
self.onlineList = data.onlineList self.onlineList = data.onlineList
async def vad_voiceend(self, fn):
print(f'vad_voiceend():{fn=}')
async def auto_accept_call(self, data): async def auto_accept_call(self, data):
pc = await self.createPeerConnection(data['from']) opts = DictObject(iceServers=self.iceServers)
pc = RTCPeerConnection(opts)
self.peers[data['from'].id] = DictObject(**{ self.peers[data['from'].id] = DictObject(**{
'info':data['from'], 'info':data['from'],
'pc':pc 'pc':pc
}) })
print(f'{opts=}, {pc=}, {pc.myconfiguration=}')
await self.ws.send(json.dumps({'type':'callAccepted', 'to':data['from']})) await self.ws.send(json.dumps({'type':'callAccepted', 'to':data['from']}))
async def pc_track(self, peerid, track):
peer = self.peers[peerid]
pc = peer.pc
if track.kind == 'audio':
vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend)
peer.vadtrack = vadtrack
vadtrack.start_vad()
async def pc_connectionState_changed(self, peerid):
peer = self.peers[peerid]
pc = peer.pc
print(f'************************************{pc.connectionState=}')
if pc.connectionState == 'connected':
peer.dc = await pc.createDataChannel(peer.info.name)
return
if pc.connectionState == 'closed':
await pc.close()
if peer.dc:
await peer.dc.close()
peers = {
k:v for k,v in self.peers.items() if k != peerid
}
self.peers = peers
if len([k for k in self.peers.keys()]) == 0:
await self.ws.close()
async def response_offer(self, data): async def response_offer(self, data):
pc = self.get_pc(data) pc = self.get_pc(data)
offer = RTCSessionDescription(data.offer) if pc is None:
print(f'{self.peers=}, {data=}')
return
pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id))
pc.on('track', partial(self.pc_track, data['from'].id))
pc.on('icecandidate', partial(self.on_icecandidate, pc))
offer = RTCSessionDescription(** data.offer)
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)
answer = await pc.createAnswer() answer = await pc.createAnswer()
await pc.setLocalDescription(answer) await pc.setLocalDescription(answer)
self.ws.send(json.dumps({ await self.ws.send(json.dumps({
'type':'answer', 'type':'answer',
'answer':pc.localDescription, 'answer':{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
'to':data['from'] 'to':data['from']
})) }))
@pc.on("datachannel")
def datachannel_handle(channel):
@channel.on("message")
def recvdata(channel):
pass
"""
@pc.on("connectionstatechange")
@pc.on("track")
"""
async def accept_answer(self, data): async def accept_answer(self, data):
pc = self.get_pc(data) pc = self.get_pc(data)
@ -75,15 +123,27 @@ class RTCLLM:
async def accept_iceCandidate(self, data): async def accept_iceCandidate(self, data):
pc = self.get_pc(data) pc = self.get_pc(data)
candidate = RTCIceCandidate(data.candidate) candidate = data.candidate
# pc.addIceCandidate(RTCIceCandidate.from_string(candidate)) ip = candidate['candidate'].split(' ')[4]
await pc.addIceCandidate(candidate) port = candidate['candidate'].split(' ')[5]
protocol = candidate['candidate'].split(' ')[7]
async def createPeerConnection(self, peerinfo): priority = candidate['candidate'].split(' ')[3]
opts = { foundation = candidate['candidate'].split(' ')[0]
iceServers:self.iceServers component = candidate['candidate'].split(' ')[1]
} type = candidate['candidate'].split(' ')[7]
pc = RTCPeerConnection(opts) rtc_candidate = RTCIceCandidate(
ip=ip,
port=port,
protocol=protocol,
priority=priority,
foundation=foundation,
component=component,
type=type,
sdpMid=candidate['sdpMid'],
sdpMLineIndex=candidate['sdpMLineIndex']
)
await pc.addIceCandidate(rtc_candidate)
print('addIceCandidate ok')
handlers = { handlers = {
'onlineList':save_onlineList, 'onlineList':save_onlineList,
@ -94,13 +154,14 @@ class RTCLLM:
} }
async def main(): async def main():
agent = RTCLLM(ws_url='https://sage.opencomputing.cn/wss/ws/rtc_signaling.ws', agent = RTCLLM(ws_url='wss://sage.opencomputing.cn/wss/ws/rtc_signaling.ws',
iceServers=[{ iceServers=[{
'urls':'stun:stun.opencomputing.cn'},{ 'urls':'stun:stun.opencomputing.cn'},{
'urls':'turn:stun.opencomputing.cn', 'urls':'turn:stun.opencomputing.cn',
'username':'turn', 'username':'turn',
'credential':'server' 'credential':'server'
}]) }])
print('running ...')
await agent.run() await agent.run()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,3 +1,4 @@
import asyncio
import collections import collections
import contextlib import contextlib
from aiortc import MediaStreamTrack from aiortc import MediaStreamTrack
@ -9,20 +10,40 @@ class AudioTrackVad(MediaStreamTrack):
def __init__(self, track, stage=3, onvoiceend=None): def __init__(self, track, stage=3, onvoiceend=None):
super().__init__() super().__init__()
self.track = track self.track = track
print(dir(track), 'AudioTrackVad.__init__()')
self.onvoiceend = onvoiceend self.onvoiceend = onvoiceend
self.vad = webrtcvad.Vad(stage) self.vad = webrtcvad.Vad(stage)
self.sample_rate = self.track.getSettings().sampleRate # self.sample_rate = self.track.getSettings().sampleRate
frameSize = self.track.getSettings().frameSize # frameSize = self.track.getSettings().frameSize
self.frame_duration_ms = (1000 * frameSize) / self.sample_rate # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
self.frame_duration_ms = 0.00008
self.num_padding_frames = 10 self.num_padding_frames = 10
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False self.triggered = False
self.voiced_frames = [] self.voiced_frames = []
self.loop = asyncio.get_event_loop()
self.task = None
self.running = False
def start_vad(self):
self.running = True
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
def _recv(self):
asyncio.create_task(self.recv())
if self.task:
self.task.cancel()
if self.running:
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
def stop(self):
self.running = False
async def recv(self): async def recv(self):
frame = await self.track.recv() f = await self.track.recv()
self.vad_check(frame) print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}')
return frame self.vad_check(f)
return f
async def vad_check(self, frame): async def vad_check(self, frame):
is_speech = self.vad.is_speech(frame, self.sample_rate) is_speech = self.vad.is_speech(frame, self.sample_rate)
@ -69,5 +90,5 @@ class AudioTrackVad(MediaStreamTrack):
wf.writeframes(audio) wf.writeframes(audio)
if self.onvoiceend: if self.onvoiceend:
self.onvoiceend(path) await self.onvoiceend(path)