diff --git a/rtcllm/aav.py b/rtcllm/aav.py index 99227cb..e417872 100644 --- a/rtcllm/aav.py +++ b/rtcllm/aav.py @@ -2,35 +2,42 @@ import random from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay from aiortc import VideoStreamTrack, AudioStreamTrack -class MyAudioTrack(AudioStreamTrack): - def __init__(self): +class MyMediaPlayer(MediaPlayer): + pass + +class MyAudioStreamTrack(AudioStreamTrack): + def __init__(self, source=None): super().__init__() - self.source = None + self.source = source def set_source(self, source): - self.source = source.audio + self.source = source async def recv(self): print('MyAudioTrack::recv(): called') if self.source is None: return None - f = self.source.recv() - if f is None: - print('MyAudioTrack::recv():return None') + f = await self.source.audio.recv() + while f is None: + self.set_source(MediaPlayer(self.source._file_path) + f = await self.source.audio.recv() return f -class MyVideoTrack(VideoStreamTrack): - def __init__(self): +class MyVideoStreamTrack(VideoStreamTrack): + def __init__(self, source=None): super().__init__() - self.source = None + self.source = source def set_source(self, source): - self.source = source.video + self.source = source async def recv(self): print('MyVideoTrack::recv(): called') if self.source is None: return None - f = self.source.recv() + f = await self.source.video.recv() + while f is None: + self.set_source(MediaPlayer(self.source._file_path) + f = await self.source.video.recv() return f diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py index 74d5115..e7d1c96 100644 --- a/rtcllm/rtc.py +++ b/rtcllm/rtc.py @@ -22,6 +22,8 @@ import websockets from stt import asr from vad import AudioTrackVad from a2a import LLMAudioStreamTrack +from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack + from mini_omni.inference import OmniInference videos = ['./1.mp4', './2.mp4'] @@ -47,9 +49,6 @@ async def pc_get_local_candidates(pc, peer): RTCPeerConnection.get_local_candidates = pc_get_local_candidates -class MyMediaPlayer(MediaPlayer): - pass - class RTCLLM: def __init__(self, ws_url, iceServers): # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16") @@ -127,8 +126,8 @@ class RTCLLM: 'player':player, 'pc':pc }) - pc.addTrack(AudioStreamTrack(player.audio)) - pc.addTrack(VideoStreamTrack(player.video)) + pc.addTrack(MyAudioStreamTrack(player)) + pc.addTrack(MyVideoStreamTrack(player)) await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']})) async def pc_track(self, peerid, track):