bugfix
This commit is contained in:
parent
9e17d711b3
commit
2476b7b76c
121
rtcllm/rtc.py
121
rtcllm/rtc.py
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
from appPublic.dictObject import DictObject
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
|
||||
@ -8,65 +10,111 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med
|
||||
from websockets.asyncio.client import connect
|
||||
import websockets
|
||||
|
||||
from vad import AudioTrackVad
|
||||
|
||||
class RTCLLM:
|
||||
def __init__(self, ws_url, iceServers):
|
||||
self.ws_url = ws_url
|
||||
self.iceServers = iceServers
|
||||
self.peers = DictObject()
|
||||
self.info = {
|
||||
'id':'rtcllm_agent',
|
||||
'name':'rtcllm_agent'
|
||||
}
|
||||
self.dc = None
|
||||
|
||||
def get_pc(self, data):
|
||||
return self.peers[data['from'].id].pc
|
||||
|
||||
async def run(self):
|
||||
async with connect(self.ws_url) as self.ws:
|
||||
self.login()
|
||||
await self.login()
|
||||
while True:
|
||||
msg = await self.ws.recv()
|
||||
data = DictObject(**json.loads(msg))
|
||||
f = self.handler.get(data.type)
|
||||
await f(data)
|
||||
d = data.data
|
||||
print(f'ws recv(): {d.type=}')
|
||||
func = self.handlers.get(d.type)
|
||||
if func:
|
||||
f = partial(func, self)
|
||||
await f(d)
|
||||
self.ws.close()
|
||||
|
||||
async def login(self):
|
||||
await self.ws.send(json.dumps({
|
||||
'type':'login',
|
||||
'info':{
|
||||
'id':'rtcllm_agent',
|
||||
'name':'rtcllm_agent'
|
||||
}
|
||||
'info':self.info
|
||||
}))
|
||||
|
||||
async def on_icecandidate(self, pc, candidate):
|
||||
print('on_icecandidate()', self, pc, candidate)
|
||||
if candidate:
|
||||
await self.ws.send(json.dumps({
|
||||
"type":"candidate",
|
||||
"candidate":candidate,
|
||||
}))
|
||||
|
||||
async def save_onlineList(self, data):
|
||||
print(f'{self}, {type(self)}')
|
||||
self.onlineList = data.onlineList
|
||||
|
||||
async def vad_voiceend(self, fn):
|
||||
print(f'vad_voiceend():{fn=}')
|
||||
|
||||
async def auto_accept_call(self, data):
|
||||
pc = await self.createPeerConnection(data['from'])
|
||||
opts = DictObject(iceServers=self.iceServers)
|
||||
pc = RTCPeerConnection(opts)
|
||||
self.peers[data['from'].id] = DictObject(**{
|
||||
'info':data['from'],
|
||||
'pc':pc
|
||||
})
|
||||
print(f'{opts=}, {pc=}, {pc.myconfiguration=}')
|
||||
await self.ws.send(json.dumps({'type':'callAccepted', 'to':data['from']}))
|
||||
|
||||
async def pc_track(self, peerid, track):
|
||||
peer = self.peers[peerid]
|
||||
pc = peer.pc
|
||||
if track.kind == 'audio':
|
||||
vadtrack = AudioTrackVad(track, onvoiceend=self.vad_voiceend)
|
||||
peer.vadtrack = vadtrack
|
||||
vadtrack.start_vad()
|
||||
|
||||
async def pc_connectionState_changed(self, peerid):
|
||||
peer = self.peers[peerid]
|
||||
pc = peer.pc
|
||||
print(f'************************************{pc.connectionState=}')
|
||||
if pc.connectionState == 'connected':
|
||||
peer.dc = await pc.createDataChannel(peer.info.name)
|
||||
return
|
||||
if pc.connectionState == 'closed':
|
||||
await pc.close()
|
||||
if peer.dc:
|
||||
await peer.dc.close()
|
||||
|
||||
peers = {
|
||||
k:v for k,v in self.peers.items() if k != peerid
|
||||
}
|
||||
self.peers = peers
|
||||
if len([k for k in self.peers.keys()]) == 0:
|
||||
await self.ws.close()
|
||||
|
||||
async def response_offer(self, data):
|
||||
pc = self.get_pc(data)
|
||||
offer = RTCSessionDescription(data.offer)
|
||||
if pc is None:
|
||||
print(f'{self.peers=}, {data=}')
|
||||
return
|
||||
pc.on("connectionState", partial(self.pc_connectionState_changed, data['from'].id))
|
||||
pc.on('track', partial(self.pc_track, data['from'].id))
|
||||
pc.on('icecandidate', partial(self.on_icecandidate, pc))
|
||||
offer = RTCSessionDescription(** data.offer)
|
||||
await pc.setRemoteDescription(offer)
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
self.ws.send(json.dumps({
|
||||
await self.ws.send(json.dumps({
|
||||
'type':'answer',
|
||||
'answer':pc.localDescription,
|
||||
'answer':{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
|
||||
'to':data['from']
|
||||
}))
|
||||
@pc.on("datachannel")
|
||||
def datachannel_handle(channel):
|
||||
@channel.on("message")
|
||||
def recvdata(channel):
|
||||
pass
|
||||
"""
|
||||
@pc.on("connectionstatechange")
|
||||
@pc.on("track")
|
||||
"""
|
||||
|
||||
async def accept_answer(self, data):
|
||||
pc = self.get_pc(data)
|
||||
@ -75,15 +123,27 @@ class RTCLLM:
|
||||
|
||||
async def accept_iceCandidate(self, data):
|
||||
pc = self.get_pc(data)
|
||||
candidate = RTCIceCandidate(data.candidate)
|
||||
# pc.addIceCandidate(RTCIceCandidate.from_string(candidate))
|
||||
await pc.addIceCandidate(candidate)
|
||||
|
||||
async def createPeerConnection(self, peerinfo):
|
||||
opts = {
|
||||
iceServers:self.iceServers
|
||||
}
|
||||
pc = RTCPeerConnection(opts)
|
||||
candidate = data.candidate
|
||||
ip = candidate['candidate'].split(' ')[4]
|
||||
port = candidate['candidate'].split(' ')[5]
|
||||
protocol = candidate['candidate'].split(' ')[7]
|
||||
priority = candidate['candidate'].split(' ')[3]
|
||||
foundation = candidate['candidate'].split(' ')[0]
|
||||
component = candidate['candidate'].split(' ')[1]
|
||||
type = candidate['candidate'].split(' ')[7]
|
||||
rtc_candidate = RTCIceCandidate(
|
||||
ip=ip,
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
priority=priority,
|
||||
foundation=foundation,
|
||||
component=component,
|
||||
type=type,
|
||||
sdpMid=candidate['sdpMid'],
|
||||
sdpMLineIndex=candidate['sdpMLineIndex']
|
||||
)
|
||||
await pc.addIceCandidate(rtc_candidate)
|
||||
print('addIceCandidate ok')
|
||||
|
||||
handlers = {
|
||||
'onlineList':save_onlineList,
|
||||
@ -94,13 +154,14 @@ class RTCLLM:
|
||||
}
|
||||
|
||||
async def main():
|
||||
agent = RTCLLM(ws_url='https://sage.opencomputing.cn/wss/ws/rtc_signaling.ws',
|
||||
agent = RTCLLM(ws_url='wss://sage.opencomputing.cn/wss/ws/rtc_signaling.ws',
|
||||
iceServers=[{
|
||||
'urls':'stun:stun.opencomputing.cn'},{
|
||||
'urls':'turn:stun.opencomputing.cn',
|
||||
'username':'turn',
|
||||
'credential':'server'
|
||||
}])
|
||||
print('running ...')
|
||||
await agent.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
from aiortc import MediaStreamTrack
|
||||
@ -9,20 +10,40 @@ class AudioTrackVad(MediaStreamTrack):
|
||||
def __init__(self, track, stage=3, onvoiceend=None):
|
||||
super().__init__()
|
||||
self.track = track
|
||||
print(dir(track), 'AudioTrackVad.__init__()')
|
||||
self.onvoiceend = onvoiceend
|
||||
self.vad = webrtcvad.Vad(stage)
|
||||
self.sample_rate = self.track.getSettings().sampleRate
|
||||
frameSize = self.track.getSettings().frameSize
|
||||
self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
||||
# self.sample_rate = self.track.getSettings().sampleRate
|
||||
# frameSize = self.track.getSettings().frameSize
|
||||
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
||||
self.frame_duration_ms = 0.00008
|
||||
self.num_padding_frames = 10
|
||||
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
||||
self.triggered = False
|
||||
self.voiced_frames = []
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.task = None
|
||||
self.running = False
|
||||
|
||||
def start_vad(self):
|
||||
self.running = True
|
||||
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
||||
|
||||
def _recv(self):
|
||||
asyncio.create_task(self.recv())
|
||||
if self.task:
|
||||
self.task.cancel()
|
||||
if self.running:
|
||||
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
async def recv(self):
|
||||
frame = await self.track.recv()
|
||||
self.vad_check(frame)
|
||||
return frame
|
||||
f = await self.track.recv()
|
||||
print(f'{f.pts=}, {f.rate=}, {f.sample_rate=}, {f.format=}, {f.dts=}, {f.samples=}')
|
||||
self.vad_check(f)
|
||||
return f
|
||||
|
||||
async def vad_check(self, frame):
|
||||
is_speech = self.vad.is_speech(frame, self.sample_rate)
|
||||
@ -69,5 +90,5 @@ class AudioTrackVad(MediaStreamTrack):
|
||||
wf.writeframes(audio)
|
||||
|
||||
if self.onvoiceend:
|
||||
self.onvoiceend(path)
|
||||
await self.onvoiceend(path)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user