This commit is contained in:
yumoqing 2024-09-11 11:08:26 +08:00
parent 46c84f08cb
commit 141fe2fd7b
2 changed files with 23 additions and 17 deletions

View File

@ -2,35 +2,42 @@ import random
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack from aiortc import VideoStreamTrack, AudioStreamTrack
class MyAudioTrack(AudioStreamTrack): class MyMediaPlayer(MediaPlayer):
def __init__(self): pass
class MyAudioStreamTrack(AudioStreamTrack):
def __init__(self, source=None):
super().__init__() super().__init__()
self.source = None self.source = source
def set_source(self, source): def set_source(self, source):
self.source = source.audio self.source = source
async def recv(self): async def recv(self):
print('MyAudioTrack::recv(): called') print('MyAudioTrack::recv(): called')
if self.source is None: if self.source is None:
return None return None
f = self.source.recv() f = await self.source.audio.recv()
if f is None: while f is None:
print('MyAudioTrack::recv():return None') self.set_source(MediaPlayer(self.source._file_path)
f = await self.source.audio.recv()
return f return f
class MyVideoTrack(VideoStreamTrack): class MyVideoStreamTrack(VideoStreamTrack):
def __init__(self): def __init__(self, source=None):
super().__init__() super().__init__()
self.source = None self.source = source
def set_source(self, source): def set_source(self, source):
self.source = source.video self.source = source
async def recv(self): async def recv(self):
print('MyVideoTrack::recv(): called') print('MyVideoTrack::recv(): called')
if self.source is None: if self.source is None:
return None return None
f = self.source.recv() f = await self.source.video.recv()
while f is None:
self.set_source(MediaPlayer(self.source._file_path)
f = await self.source.video.recv()
return f return f

View File

@ -22,6 +22,8 @@ import websockets
from stt import asr from stt import asr
from vad import AudioTrackVad from vad import AudioTrackVad
from a2a import LLMAudioStreamTrack from a2a import LLMAudioStreamTrack
from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
from mini_omni.inference import OmniInference from mini_omni.inference import OmniInference
videos = ['./1.mp4', './2.mp4'] videos = ['./1.mp4', './2.mp4']
@ -47,9 +49,6 @@ async def pc_get_local_candidates(pc, peer):
RTCPeerConnection.get_local_candidates = pc_get_local_candidates RTCPeerConnection.get_local_candidates = pc_get_local_candidates
class MyMediaPlayer(MediaPlayer):
pass
class RTCLLM: class RTCLLM:
def __init__(self, ws_url, iceServers): def __init__(self, ws_url, iceServers):
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16") # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
@ -127,8 +126,8 @@ class RTCLLM:
'player':player, 'player':player,
'pc':pc 'pc':pc
}) })
pc.addTrack(AudioStreamTrack(player.audio)) pc.addTrack(MyAudioStreamTrack(player))
pc.addTrack(VideoStreamTrack(player.video)) pc.addTrack(MyVideoStreamTrack(player))
await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']})) await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
async def pc_track(self, peerid, track): async def pc_track(self, peerid, track):