rtcllm/rtcllm/a2a.py
2024-09-09 18:04:12 +08:00

44 lines
1.0 KiB
Python

import asyncio
import fractions
import io
import time

import numpy as np
from aiortc import VideoStreamTrack, AudioStreamTrack
from av import AudioFrame

from mini_omni.inference import OmniInference
class LLMAudioStreamTrack(AudioStreamTrack):
    """aiortc audio track that plays audio produced by mini-omni's OmniInference.

    Each call to ``_feed`` queues the generator returned by
    ``OmniInference.run_AT_batch_stream``. ``recv`` drains the queued
    generators in FIFO order and turns every yielded chunk into one
    ``av.AudioFrame``. While nothing is queued, short silent frames are
    emitted so the RTP sender always receives a real frame.
    """

    def __init__(self, ckpt_dir, sample_rate=24000, silence_ptime=0.02):
        """Load the model and initialise an empty playback queue.

        Args:
            ckpt_dir: Checkpoint directory passed to ``OmniInference``.
            sample_rate: Sample rate (Hz) stamped on outgoing frames.
                NOTE(review): 24000 is assumed to match the rate of the PCM
                yielded by ``run_AT_batch_stream`` -- confirm against mini-omni.
            silence_ptime: Duration in seconds of each silent filler frame.
        """
        super().__init__()
        self.ckpt_dir = ckpt_dir
        self.oi = OmniInference(ckpt_dir=ckpt_dir)
        self.audio_iters = []  # pending inference streams, oldest first
        self.cur_iters = None  # stream currently being played, or None
        self.sample_rate = sample_rate
        self.silence_ptime = silence_ptime
        self._pts = 0  # samples emitted so far == pts of the next frame
        self._start = None  # monotonic time of the first frame, for pacing

    async def recv(self):
        """Return the next audio frame, paced to real time.

        Returns:
            An ``av.AudioFrame`` (s16, mono) with ``pts``, ``time_base`` and
            ``sample_rate`` set. When no audio is queued a silent frame is
            returned instead of ``None``, which the sender cannot encode.
        """
        # NOTE(review): next() on the inference generator runs synchronously
        # inside this coroutine and may block the event loop while the model
        # generates -- consider offloading to an executor.
        chunk = self.get_audio_bytes()
        if chunk:
            # Assumes each chunk is native-endian int16 PCM -- TODO confirm.
            # Packed (non-planar) formats expect shape (1, samples * channels).
            pcm = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
        else:
            n_silent = max(1, int(self.silence_ptime * self.sample_rate))
            pcm = np.zeros((1, n_silent), dtype=np.int16)

        frame = AudioFrame.from_ndarray(pcm, format='s16', layout='mono')
        frame.sample_rate = self.sample_rate
        frame.time_base = fractions.Fraction(1, self.sample_rate)
        frame.pts = self._pts
        # Without pacing, silent filler frames would be produced as fast as the
        # sender can pull them.
        await self._wait_until(self._pts)
        self._pts += pcm.shape[1]
        return frame

    async def _wait_until(self, pts):
        """Sleep until the wall-clock time at which sample ``pts`` is due."""
        now = time.monotonic()
        if self._start is None:
            self._start = now
        delay = self._start + pts / self.sample_rate - now
        if delay > 0:
            await asyncio.sleep(delay)

    def set_cur_audio_iter(self):
        """Promote the oldest queued stream to ``cur_iters``.

        Returns:
            True if a stream was dequeued, False if the queue was empty.
        """
        if not self.audio_iters:
            return False
        self.cur_iters = self.audio_iters.pop(0)
        return True

    def get_audio_bytes(self):
        """Return the next chunk from the queued streams, or None if none remain.

        Exhausted streams are discarded and the next queued one is tried.
        """
        while True:
            if self.cur_iters is None and not self.set_cur_audio_iter():
                return None
            try:
                return next(self.cur_iters)
            except StopIteration:
                self.cur_iters = None

    def _feed(self, audio_file):
        """Start inference on ``audio_file`` and queue its output stream."""
        x = self.oi.run_AT_batch_stream(audio_file)
        self.audio_iters.append(x)