diff --git a/rtcllm/a2a.py b/rtcllm/a2a.py
new file mode 100644
index 0000000..0315f31
--- /dev/null
+++ b/rtcllm/a2a.py
@@ -0,0 +1,48 @@
+import numpy as np
+from aiortc.mediastreams import AudioStreamTrack
+from av import AudioFrame
+
+from mini_omni.inference import OmniInference
+
+
+class LLMAudioStreamTrack(AudioStreamTrack):
+    """WebRTC audio track that streams audio generated by mini-omni.
+
+    `_feed()` enqueues one byte-chunk generator per input utterance;
+    `recv()` drains the generators in FIFO order, one chunk per frame.
+    """
+
+    def __init__(self, ckpt_dir):
+        super().__init__()
+        self.ckpt_dir = ckpt_dir
+        # fix: was OmniInference(ckpt_dir=chpt_dir) -- NameError (typo)
+        self.oi = OmniInference(ckpt_dir=ckpt_dir)
+        self.audio_iters = []  # pending generators queued by _feed()
+        self.cur_iters = None  # generator currently being drained
+
+    async def recv(self):
+        b = self.get_audio_bytes()
+        if b is None:
+            return b
+        # fix: from_ndarray() expects an ndarray, not io.BytesIO.
+        # Decode the raw s16 mono bytes into shape (channels, samples).
+        samples = np.frombuffer(b, dtype=np.int16).reshape(1, -1)
+        frame = AudioFrame.from_ndarray(samples, format='s16', layout='mono')
+        return frame
+
+    def set_cur_audio_iter(self):
+        # Promote the oldest queued generator to current; False when idle.
+        if len(self.audio_iters) == 0:
+            return False
+        self.cur_iters = self.audio_iters.pop(0)  # fix: was 'rteturn True' (SyntaxError) below
+        return True
+
+    def get_audio_bytes(self):
+        if self.cur_iters is None:
+            # fix: was a bare set_cur_audio_iter() call -- NameError
+            if not self.set_cur_audio_iter():
+                return None
+        try:
+            return next(self.cur_iters)
+        except StopIteration:
+            # Current generator exhausted; recurse to try the next queued one.
+            self.cur_iters = None
+            return self.get_audio_bytes()
+
+    def _feed(self, audio_file):
+        # Kick off inference for one utterance; result is a lazy byte stream.
+        x = self.oi.run_AT_batch_stream(audio_file)
+        self.audio_iters.append(x)
diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py
index 1333f8f..5c8fccb 100644
--- a/rtcllm/rtc.py
+++ b/rtcllm/rtc.py
@@ -16,7 +16,7 @@
 from websockets.client import connect
 import websockets
 from stt import asr
 from vad import AudioTrackVad
-from aav import MyAudioTrack, MyVideoTrack
+from a2a import LLMAudioStreamTrack
 
 videos = ['./1.mp4', './2.mp4']
@@ -45,6 +45,8 @@ class RTCLLM:
     def __init__(self, ws_url, iceServers):
         # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
         self.ws_url = ws_url
+        self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
+        self.feed = awaitify(self.llmtrack._feed)  # NOTE(review): awaitify is defined outside this diff -- confirm it is imported in rtc.py
         self.iceServers = iceServers
         self.peers = DictObject()
         self.vid = 0
@@ -100,9 +102,7 @@ class RTCLLM:
         self.onlineList = data.onlineList
 
     async def vad_voiceend(self, peer, audio):
-        ret = await asr(audio)
-        # await peer.dc.send(txt)
-
+        await self.feed(audio)
 
     async def auto_accept_call(self, data):
         opts = DictObject(iceServers=self.iceServers)
@@ -111,6 +111,7 @@ class RTCLLM:
             'info':data['from'],
             'pc':pc
         })
+        pc.addTrack(self.llmtrack)
         await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
 
     async def pc_track(self, peerid, track):
@@ -122,14 +123,6 @@ class RTCLLM:
         peer.vadtrack = vadtrack
         vadtrack.start_vad()
 
-    def play_random_media(self, vt, at):
-        i = random.randint(0,1)
-        player = MediaPlayer(videos[i])
-        vt.set_source(player)
-        at.set_source(player)
-        f = partial(self.play_random_media, vt, at)
-        self.loop.call_later(180, f)
-
     async def pc_connectionState_changed(self, peerid):
         peer = self.peers[peerid]
         pc = peer.pc
@@ -159,11 +152,6 @@ class RTCLLM:
         pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
         pc.on('track', partial(self.pc_track, data['from'].id))
         pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from']))
-        peer.audiotrack = MyAudioTrack()
-        peer.videotrack = MyVideoTrack()
-        pc.addTrack(peer.audiotrack)
-        pc.addTrack(peer.videotrack)
-        self.play_random_media(peer.videotrack, peer.audiotrack)
         offer = RTCSessionDescription(** data.offer)
         await pc.setRemoteDescription(offer)