This commit is contained in:
yumoqing 2024-09-11 11:08:26 +08:00
parent 46c84f08cb
commit 141fe2fd7b
2 changed files with 23 additions and 17 deletions

View File

@ -2,35 +2,42 @@ import random
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack
class MyAudioTrack(AudioStreamTrack):
def __init__(self):
class MyMediaPlayer(MediaPlayer):
pass
class MyAudioStreamTrack(AudioStreamTrack):
def __init__(self, source=None):
super().__init__()
self.source = None
self.source = source
def set_source(self, source):
self.source = source.audio
self.source = source
async def recv(self):
print('MyAudioTrack::recv(): called')
if self.source is None:
return None
f = self.source.recv()
if f is None:
print('MyAudioTrack::recv():return None')
f = await self.source.audio.recv()
while f is None:
self.set_source(MediaPlayer(self.source._file_path)
f = await self.source.audio.recv()
return f
class MyVideoTrack(VideoStreamTrack):
def __init__(self):
class MyVideoStreamTrack(VideoStreamTrack):
def __init__(self, source=None):
super().__init__()
self.source = None
self.source = source
def set_source(self, source):
self.source = source.video
self.source = source
async def recv(self):
print('MyVideoTrack::recv(): called')
if self.source is None:
return None
f = self.source.recv()
f = await self.source.video.recv()
while f is None:
self.set_source(MediaPlayer(self.source._file_path)
f = await self.source.video.recv()
return f

View File

@ -22,6 +22,8 @@ import websockets
from stt import asr
from vad import AudioTrackVad
from a2a import LLMAudioStreamTrack
from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
from mini_omni.inference import OmniInference
videos = ['./1.mp4', './2.mp4']
@ -47,9 +49,6 @@ async def pc_get_local_candidates(pc, peer):
RTCPeerConnection.get_local_candidates = pc_get_local_candidates
class MyMediaPlayer(MediaPlayer):
pass
class RTCLLM:
def __init__(self, ws_url, iceServers):
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
@ -127,8 +126,8 @@ class RTCLLM:
'player':player,
'pc':pc
})
pc.addTrack(AudioStreamTrack(player.audio))
pc.addTrack(VideoStreamTrack(player.video))
pc.addTrack(MyAudioStreamTrack(player))
pc.addTrack(MyVideoStreamTrack(player))
await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
async def pc_track(self, peerid, track):