# 44 lines, 1.0 KiB, Python
import asyncio
import fractions
import io
import time

import numpy as np
from aiortc import VideoStreamTrack, AudioStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame

from mini_omni.inference import OmniInference


class LLMAudioStreamTrack(AudioStreamTrack):
    """aiortc audio track that plays back audio produced by a mini-omni model.

    Inputs are queued with ``_feed``. Each one becomes an iterator over audio
    byte chunks returned by ``OmniInference.run_AT_batch_stream``. ``recv``
    drains those iterators in FIFO order and wraps each chunk in a signed
    16-bit mono ``AudioFrame``. When nothing is queued it emits short silent
    frames, so the RTP sender always receives a frame.
    """

    # Length of each filler-silence frame in seconds. This is the same 20 ms
    # packet time aiortc uses for its own silent AudioStreamTrack.
    _SILENCE_PTIME = 0.02

    def __init__(self, ckpt_dir, sample_rate=24000):
        """Load the model and set up an empty playback queue.

        Args:
            ckpt_dir: checkpoint directory passed to ``OmniInference``.
            sample_rate: sample rate (Hz) of the PCM chunks the model yields.
                NOTE(review): 24000 assumes mini-omni's 24 kHz decoder output;
                confirm against the model that is actually loaded.
        """
        super().__init__()
        self.ckpt_dir = ckpt_dir
        self.sample_rate = sample_rate
        self.oi = OmniInference(ckpt_dir=ckpt_dir)
        # Pending per-input chunk iterators, consumed oldest-first.
        self.audio_iters = []
        # Iterator currently being drained, or None between inputs.
        self.cur_iters = None
        # Playback clock. _next_pts is the timestamp (in samples) of the next
        # frame. _clock_start is the monotonic time the first frame went out.
        # The names are distinct from the base class's private _start and
        # _timestamp.
        self._next_pts = 0
        self._clock_start = None

    async def recv(self):
        """Return the next audio frame for the RTP sender.

        If the model has output pending, the next chunk becomes an ``s16`` mono
        frame. Otherwise the frame is ``_SILENCE_PTIME`` seconds of silence;
        the sender expects a frame on every call, never None. Frames are
        released against a monotonic clock at roughly real-time rate.

        Raises:
            MediaStreamError: if the track has been stopped, as the base
                ``AudioStreamTrack.recv`` does.
        """
        if self.readyState != "live":
            raise MediaStreamError

        chunk = self.get_audio_bytes()
        if chunk:
            # PyAV expects packed (interleaved) audio as shape (1, n_samples).
            # NOTE(review): assumes chunks are native-endian int16 mono PCM of
            # even byte length; confirm against OmniInference's output.
            samples = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
        else:
            n_silent = max(1, int(self._SILENCE_PTIME * self.sample_rate))
            samples = np.zeros((1, n_silent), dtype=np.int16)

        frame = AudioFrame.from_ndarray(samples, format="s16", layout="mono")
        frame.sample_rate = self.sample_rate
        frame.time_base = fractions.Fraction(1, self.sample_rate)
        frame.pts = self._next_pts

        now = time.monotonic()
        if self._clock_start is None:
            self._clock_start = now
        # Sleep until this frame is due. When we are already late this is
        # sleep(0), which still yields to the event loop instead of spinning.
        due = self._clock_start + self._next_pts / self.sample_rate
        await asyncio.sleep(max(0.0, due - now))

        self._next_pts += samples.shape[1]
        return frame

    def set_cur_audio_iter(self):
        """Move the oldest queued iterator into ``cur_iters``.

        Returns:
            True if an iterator was dequeued. False if the queue was empty, in
            which case ``cur_iters`` is left unchanged.
        """
        if not self.audio_iters:
            return False
        self.cur_iters = self.audio_iters.pop(0)
        return True

    def get_audio_bytes(self):
        """Return the next audio chunk from the queued inputs.

        Exhausted iterators are dropped and the next queued one is started.
        This uses a loop rather than recursion, so many empty inputs in a row
        cannot hit the recursion limit.

        Returns:
            The next chunk, or None when every queued iterator is drained.
        """
        while True:
            if self.cur_iters is None and not self.set_cur_audio_iter():
                return None
            # NOTE(review): next() drives the iterator from
            # OmniInference.run_AT_batch_stream synchronously. If that performs
            # model inference, it blocks the event loop during recv(); consider
            # moving it to an executor.
            try:
                return next(self.cur_iters)
            except StopIteration:
                self.cur_iters = None

    def _feed(self, audio_file):
        """Queue ``audio_file`` so its audio plays after any already-queued input.

        The object returned by ``OmniInference.run_AT_batch_stream`` is stored
        and later iterated by ``get_audio_bytes``.
        """
        self.audio_iters.append(self.oi.run_AT_batch_stream(audio_file))