yumoqing 2024-09-10 10:39:38 +08:00
parent b71e533e85
commit a4dc6ef93c
2 changed files with 21 additions and 15 deletions

View File

@@ -1,18 +1,19 @@
 import io
-from mini_omni.inference import OmniInference
 from av import AudioFrame
 from aiortc import VideoStreamTrack, AudioStreamTrack


 class LLMAudioStreamTrack(AudioStreamTrack):
-    def __init__(self, ckpt_dir):
+    def __init__(self, omni_infer):
         super().__init__()
         self.ckpt_dir = ckpt_dir
-        self.oi = OmniInference(ckpt_dir=ckpt_dir)
+        # self.oi = OmniInference(ckpt_dir=ckpt_dir)
+        self.oi = omni_infer
         self.audio_iters = []
         self.cur_iters = None

     async def recv(self):
+        print(f'LLMAudioStreamTrack():recv() called ....')
         b = self.get_audio_bytes()
         if b is None:
             return b
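Below is a minimal usage sketch, not part of the commit, of the refactored constructor: one OmniInference instance is created up front and handed to every LLMAudioStreamTrack, instead of each track loading its own checkpoint. The checkpoint path and class names come from this commit; the variable names are illustrative, and the sketch only relies on the omni_infer argument (the diff keeps self.ckpt_dir = ckpt_dir even though that parameter is gone).

# Usage sketch, not part of the commit: one shared model instance for all tracks.
from mini_omni.inference import OmniInference
# LLMAudioStreamTrack comes from the module shown in this diff (filename not displayed here).

# Load the model once (checkpoint path as used elsewhere in this commit).
omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni')

# Every track now reuses the same instance; no per-track checkpoint load.
track_a = LLMAudioStreamTrack(omni_infer)
track_b = LLMAudioStreamTrack(omni_infer)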
@@ -27,17 +28,18 @@ class LLMAudioStreamTrack(AudioStreamTrack):
         return True

     def get_audio_bytes(self):
         if self.cur_iters is None:
             if not set_cur_audio_iter():
                 return None
         try:
             b = next(self.cur_iters)
             return b
         except StopIteration:
             self.cur_iters = None
             return self.get_audio_bytes()

     def _feed(self, audio_file):
         x = self.oi.run_AT_batch_stream(audio_file)
         self.audio_iters.append(x)
+        return x

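The added return x lets a caller keep the generator that run_AT_batch_stream() produces, rather than relying only on get_audio_bytes() to drain it. A rough consumption sketch follows; the input path and the handle() consumer are placeholders, not names from the project.

# Sketch only: what returning the generator from _feed() enables.
track = LLMAudioStreamTrack(omni_infer)      # omni_infer as in the previous sketch
gen = track._feed('/tmp/utterance.wav')      # placeholder path to a recorded utterance

# run_AT_batch_stream() yields chunks of synthesized audio; a caller can drain
# them directly instead of pulling them one at a time through get_audio_bytes().
for chunk in gen:
    handle(chunk)                            # handle() is a hypothetical consumer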

View File

@@ -50,8 +50,8 @@ class RTCLLM:
     def __init__(self, ws_url, iceServers):
         # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
         self.ws_url = ws_url
-        self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
-        self.feed = awaitify(self.llmtrack._feed)
+        self.omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni')
         self.iceServers = iceServers
         self.peers = DictObject()
         self.vid = 0
@@ -108,15 +108,19 @@ class RTCLLM:

     async def vad_voiceend(self, peer, audio):
         if audio is not None:
-            ret = await self.feed(audio)
+            feed = awaitify(peer.llmtrack._feed)
+            ret = await feed(audio)
             print(f'self.feed("{audio}") return {ret}')
             os.remove(audio)

     async def auto_accept_call(self, data):
         opts = DictObject(iceServers=self.iceServers)
         pc = RTCPeerConnection(opts)
+        llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
+        feed = awaitify(self.llmtrack._feed)
         self.peers[data['from'].id] = DictObject(**{
             'info':data['from'],
+            'llmtrack':llmtrack,
             'pc':pc
         })
         pc.addTrack(self.llmtrack)
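For reference, a sketch of per-peer wiring that is consistent with the new LLMAudioStreamTrack(omni_infer) constructor. This is an assumed target state, not what the hunk above literally does: the diff still constructs the track with the checkpoint path and still calls pc.addTrack(self.llmtrack). DictObject and awaitify are the project's own helpers, reused here as-is.

# Sketch only -- per-peer track creation using the shared OmniInference instance
# built in RTCLLM.__init__ (an assumption; the commit still passes the checkpoint path).
async def auto_accept_call(self, data):
    opts = DictObject(iceServers=self.iceServers)
    pc = RTCPeerConnection(opts)
    llmtrack = LLMAudioStreamTrack(self.omni_infer)   # per-peer track, shared model
    self.peers[data['from'].id] = DictObject(**{
        'info': data['from'],
        'llmtrack': llmtrack,
        'pc': pc
    })
    pc.addTrack(llmtrack)                             # add this peer's own track

async def vad_voiceend(self, peer, audio):
    if audio is not None:
        feed = awaitify(peer.llmtrack._feed)          # wrap the blocking feed for await
        ret = await feed(audio)                       # hand the recorded file to the model
        os.remove(audio)                              # drop the temporary recording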