rtcllm/rtcllm/a2a.py
2024-09-11 15:51:19 +08:00

61 lines
1.5 KiB
Python

import asyncio
import io
import os
from traceback import print_exc

import numpy as np
from aiortc import VideoStreamTrack, AudioStreamTrack
from av import AudioFrame


class LLMAudioStreamTrack(AudioStreamTrack):
    """Audio track fed by chunk iterators obtained from an inference object.

    Each call to ``_feed`` asks ``omni_infer.run_AT_batch_stream`` for an
    iterator of audio chunks and queues it.  ``recv`` drains the queued
    iterators in FIFO order, one chunk per call, and defers to the base
    class's ``recv`` whenever nothing is queued.

    ``audio_iters`` and ``tmp_files`` are parallel FIFO queues: the n-th
    queued iterator corresponds to the n-th queued input file, and the file
    is deleted once its iterator has been exhausted.
    """

    def __init__(self, omni_infer):
        """Create the track.

        Args:
            omni_infer: Object exposing ``run_AT_batch_stream(audio_file)``,
                which must return an iterator of audio chunks.
        """
        super().__init__()
        self.oi = omni_infer
        self.audio_iters = []   # pending chunk iterators, oldest first
        self.cur_iters = None   # iterator currently being drained, if any
        self.tmp_files = []     # input files matching the iterators, oldest first

    async def recv(self):
        """Return the next audio frame for this track.

        Returns:
            An ``AudioFrame`` built from the next queued chunk, or whatever
            the base class's ``recv`` returns when no chunk is available.
            ``None`` is returned if anything raises along the way.
        """
        try:
            b = self.get_audio_bytes()
            if b is None:
                # Nothing queued: let the base class supply the frame.
                return await super().recv()
            # from_ndarray() needs a NumPy array, not a file-like object.  For
            # the packed 's16' format it expects shape (1, samples * channels).
            # NOTE(review): assumes chunks are raw signed 16-bit mono PCM, as
            # the format/layout arguments imply - confirm against the producer.
            samples = np.frombuffer(b, dtype=np.int16).reshape(1, -1)
            frame = AudioFrame.from_ndarray(samples, format='s16', layout='mono')
            # NOTE(review): sample_rate and pts are left at their defaults
            # here - confirm whether the consumer needs them set.
            print('LLMAudioStreamTrack return frame ...')
            return frame
        except Exception:
            # Best-effort: report the failure instead of propagating it.
            # NOTE(review): returning None is kept for backward compatibility;
            # verify that the consumer of this track tolerates it.
            print_exc()
            print(f'{self.__class__.__name__} recv() exception happened')
            return None

    def set_cur_audio_iter(self):
        """Promote the oldest queued iterator to be the current one.

        Returns:
            ``True`` if an iterator was dequeued into ``cur_iters``,
            ``False`` if the queue was empty.
        """
        if not self.audio_iters:
            return False
        self.cur_iters = self.audio_iters.pop(0)
        return True

    def get_audio_bytes(self):
        """Return the next audio chunk, or ``None`` when nothing is queued.

        Exhausted iterators are dropped (and their input file removed) and
        the next queued iterator is tried, until a chunk is produced or the
        queue runs dry.
        """
        # A loop rather than recursion, so a long run of empty iterators
        # cannot grow the call stack.
        while True:
            if self.cur_iters is None and not self.set_cur_audio_iter():
                return None
            try:
                return next(self.cur_iters)
            except StopIteration:
                self.cur_iters = None
                self._discard_oldest_tmp_file()

    def _discard_oldest_tmp_file(self):
        """Dequeue the oldest tracked input file and delete it, best-effort."""
        if not self.tmp_files:
            return
        tf = self.tmp_files.pop(0)
        if tf is None:
            # Defensive: tmp_files is a public list and may hold a stray None.
            return
        try:
            os.remove(tf)
        except OSError:
            # A failed cleanup must not interrupt audio delivery.
            print_exc()
            print(f'{self.__class__.__name__} failed to remove {tf!r}')

    def _feed(self, audio_file):
        """Queue the audio generated for ``audio_file``.

        Args:
            audio_file: Input passed to ``run_AT_batch_stream``; it is deleted
                with ``os.remove`` once its iterator has been exhausted.
                ``None`` is logged and ignored.

        Returns:
            The iterator returned by ``run_AT_batch_stream``, or ``None`` when
            ``audio_file`` is ``None``.
        """
        if audio_file is None:
            print(f'*****{self.__class__.__name__}._feed(),{audio_file=}')
            return None
        x = self.oi.run_AT_batch_stream(audio_file)
        # Register the file only together with its iterator so the two FIFO
        # queues stay aligned; otherwise cleanup would delete the wrong file.
        self.audio_iters.append(x)
        self.tmp_files.append(audio_file)
        return x