diff --git a/rtcllm/a2a.py b/rtcllm/a2a.py index 6bf836c..71169ec 100644 --- a/rtcllm/a2a.py +++ b/rtcllm/a2a.py @@ -1,18 +1,19 @@ import io -from mini_omni.inference import OmniInference from av import AudioFrame from aiortc import VideoStreamTrack, AudioStreamTrack class LLMAudioStreamTrack(AudioStreamTrack): - def __init__(self, ckpt_dir): + def __init__(self, omni_infer): super().__init__() self.ckpt_dir = ckpt_dir - self.oi = OmniInference(ckpt_dir=ckpt_dir) + # self.oi = OmniInference(ckpt_dir=ckpt_dir) + self.oi = omni_infer self.audio_iters = [] self.cur_iters = None async def recv(self): + print(f'LLMAudioStreamTrack():recv() called ....') b = self.get_audio_bytes() if b is None: return b @@ -27,17 +28,18 @@ class LLMAudioStreamTrack(AudioStreamTrack): return True def get_audio_bytes(self): - if self.cur_iters is None: - if not set_cur_audio_iter(): - return None - try: - b = next(self.cur_iters) - return b - except StopIteration: - self.cur_iters = None - return self.get_audio_bytes() + if self.cur_iters is None: + if not set_cur_audio_iter(): + return None + try: + b = next(self.cur_iters) + return b + except StopIteration: + self.cur_iters = None + return self.get_audio_bytes() def _feed(self, audio_file): x = self.oi.run_AT_batch_stream(audio_file) self.audio_iters.append(x) + return x diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py index a96227b..1f96313 100644 --- a/rtcllm/rtc.py +++ b/rtcllm/rtc.py @@ -50,8 +50,8 @@ class RTCLLM: def __init__(self, ws_url, iceServers): # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16") self.ws_url = ws_url - self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni') - self.feed = awaitify(self.llmtrack._feed) + self.omni_infer = + self.omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni') self.iceServers = iceServers self.peers = DictObject() self.vid = 0 @@ -108,15 +108,19 @@ class RTCLLM: async def vad_voiceend(self, peer, audio): if audio is not None: - ret = await self.feed(audio) + feed = awaitify(peer.llmtrack._feed) + ret = await feed(audio) print(f'self.feed("{audio}") return {ret}') os.remove(audio) async def auto_accept_call(self, data): opts = DictObject(iceServers=self.iceServers) pc = RTCPeerConnection(opts) + llmtrack = LLMAudioStreamTrack('/d/models/mini-omni') + feed = awaitify(self.llmtrack._feed) self.peers[data['from'].id] = DictObject(**{ 'info':data['from'], + 'llmtrack':llmtrack, 'pc':pc }) pc.addTrack(self.llmtrack)