From c225d69aba946d364037ccc7355289b304d8039a Mon Sep 17 00:00:00 2001 From: yumoqing Date: Mon, 2 Sep 2024 18:43:58 +0800 Subject: [PATCH] bugfix --- rtcllm/aav.py | 36 ++++++++++++++++++++++++++++++++++++ rtcllm/rtc.py | 50 +++++++++++++++++++++++++++++++++++++++++--------- rtcllm/stt.py | 37 +++++++++++++++++++++++++++++++++++++ rtcllm/vad.py | 13 ++++++++++--- 4 files changed, 124 insertions(+), 12 deletions(-) create mode 100644 rtcllm/aav.py create mode 100644 rtcllm/stt.py diff --git a/rtcllm/aav.py b/rtcllm/aav.py new file mode 100644 index 0000000..99227cb --- /dev/null +++ b/rtcllm/aav.py @@ -0,0 +1,36 @@ +import random +from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay +from aiortc import VideoStreamTrack, AudioStreamTrack + +class MyAudioTrack(AudioStreamTrack): + def __init__(self): + super().__init__() + self.source = None + + def set_source(self, source): + self.source = source.audio + + async def recv(self): + print('MyAudioTrack::recv(): called') + + if self.source is None: + return None + f = self.source.recv() + if f is None: + print('MyAudioTrack::recv():return None') + return f + +class MyVideoTrack(VideoStreamTrack): + def __init__(self): + super().__init__() + self.source = None + + def set_source(self, source): + self.source = source.video + + async def recv(self): + print('MyVideoTrack::recv(): called') + if self.source is None: + return None + f = self.source.recv() + return f diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py index 5b31062..4df14e7 100644 --- a/rtcllm/rtc.py +++ b/rtcllm/rtc.py @@ -1,4 +1,5 @@ import asyncio +import random import json from functools import partial @@ -9,19 +10,24 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med # from websockets.asyncio.client import connect from websockets.asyncio.client import connect import websockets - +from stt import asr from vad import AudioTrackVad +from aav import MyAudioTrack, MyVideoTrack + +videos = ['./1.mp4', './2.mp4'] class RTCLLM: def __init__(self, ws_url, iceServers): self.ws_url = ws_url self.iceServers = iceServers self.peers = DictObject() + self.vid = 0 self.info = { 'id':'rtcllm_agent', 'name':'rtcllm_agent' } self.dc = None + self.loop = asyncio.get_event_loop() def get_pc(self, data): return self.peers[data['from'].id].pc @@ -58,8 +64,10 @@ class RTCLLM: print(f'{self}, {type(self)}') self.onlineList = data.onlineList - async def vad_voiceend(self, fn): - print(f'vad_voiceend():{fn=}') + async def vad_voiceend(self, peer, audio): + txt = await asr(audio) + # await peer.dc.send(txt) + async def auto_accept_call(self, data): opts = DictObject(iceServers=self.iceServers) @@ -74,14 +82,27 @@ class RTCLLM: peer = self.peers[peerid] pc = peer.pc if track.kind == 'audio': - vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend) + f = partial(self.vad_voiceend, peer) + vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f) peer.vadtrack = vadtrack vadtrack.start_vad() + def play_random_media(self, vt, at): + i = random.randint(0,1) + player = MediaPlayer(videos[i]) + vt.set_source(player) + at.set_source(player) + f = partial(self.play_random_media, vt, at) + self.loop.call_later(180, f) + async def pc_connectionState_changed(self, peerid): peer = self.peers[peerid] pc = peer.pc - print(f'************************************{pc.connectionState=}') + peer.audiotrack = MyAudioTrack() + peer.videotrack = MyVideoTrack() + pc.addTrack(peer.audiotrack) + pc.addTrack(peer.videotrack) + self.play_random_media(peer.videotrack, peer.audiotrack) if pc.connectionState == 'connected': peer.dc = pc.createDataChannel(peer.info.name) return @@ -97,6 +118,20 @@ class RTCLLM: if len([k for k in self.peers.keys()]) == 0: await self.ws.close() + def play_video(self, peerid): + print('play video ........................') + pc = self.peers[peerid].pc + """ + player = MediaPlayer(videos[0]) + if player: + pc.addTrack(player.audio) + pc.addTrack(player.video) + """ + player = MediaPlayer(videos[1]) + if player: + pc.addTrack(player.audio) + pc.addTrack(player.video) + async def response_offer(self, data): pc = self.get_pc(data) if pc is None: @@ -105,10 +140,7 @@ class RTCLLM: pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id)) pc.on('track', partial(self.pc_track, data['from'].id)) pc.on('icecandidate', partial(self.on_icecandidate, pc)) - player = MediaPlayer('./1.mp4') - if player: - pc.addTrack(player.audio) - pc.addTrack(player.video) + # self.play_video(data) offer = RTCSessionDescription(** data.offer) await pc.setRemoteDescription(offer) diff --git a/rtcllm/stt.py b/rtcllm/stt.py new file mode 100644 index 0000000..c306e51 --- /dev/null +++ b/rtcllm/stt.py @@ -0,0 +1,37 @@ +from appPublic.dictObject import DictObject +from appPublic.oauth_client import OAuthClient + +desc = { + "path":"/asr/generate", + "method":"POST", + "headers":[ + { + "name":"Content-Type", + "value":"application/json" + } + ], + "data":[{ + "name":"audio", + "value":"${b64audio}" + },{ + "name":"model", + "value":"whisper" + } + ], + "resp":[ + { + "name":"content", + "value":"content" + } + ] +} +opts = { + "data":{}, + "asr":desc +} + +async def asr(audio): + oc = OAuthClient(DictObject(**opts)) + r = await oc("http://open-computing.cn", "asr", {"b64audio":audio}) + print(f'{r=}') + diff --git a/rtcllm/vad.py b/rtcllm/vad.py index dd60c0c..d6abf22 100644 --- a/rtcllm/vad.py +++ b/rtcllm/vad.py @@ -1,3 +1,4 @@ +import base64 from traceback import print_exc import asyncio import collections @@ -20,8 +21,8 @@ class AudioTrackVad(MediaStreamTrack): # self.sample_rate = self.track.getSettings().sampleRate # frameSize = self.track.getSettings().frameSize # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate - self.frame_duration_ms = 0.00008 - self.num_padding_frames = 20 + self.frame_duration_ms = 0.00002 + self.num_padding_frames = 10 self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.triggered = False self.voiced_frames = [] @@ -104,10 +105,16 @@ class AudioTrackVad(MediaStreamTrack): if num_unvoiced > 0.9 * self.ring_buffer.maxlen: self.triggered = False audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames]) - await self.write_wave(audio_data) + # await self.write_wave(audio_data) + await self.gen_base64(audio_data) self.ring_buffer.clear() self.voiced_frames = [] + print('end voice .....', len(self.voiced_frames)) + async def gen_base64(self, audio_data): + b64 = base64.b64encode(audio_data).decode('utf-8') + if self.onvoiceend: + await self.onvoiceend(b64) async def write_wave(self, audio_data): """Writes a .wav file.