bugfix
This commit is contained in:
parent
f89bc2eaa5
commit
52e6386bc4
41
rtcllm/a2a.py
Normal file
41
rtcllm/a2a.py
Normal file
@ -0,0 +1,41 @@
|
||||
import io
|
||||
from mini_omni.inference import OmniInference
|
||||
from av import AudioFrame
|
||||
|
||||
class LLMAudioStreamTrack(AudioStreamTrack):
|
||||
def __init__(self, ckpt_dir):
|
||||
super().__init__()
|
||||
self.ckpt_dir = ckpt_dir
|
||||
self.oi = OmniInference(ckpt_dir=chpt_dir)
|
||||
self.audio_iters = []
|
||||
self.cur_iters = None
|
||||
|
||||
async def recv(self):
|
||||
b = self.get_audio_bytes()
|
||||
if b is None:
|
||||
return b
|
||||
frame = AudioFrame.from_ndarray(io.BytesIO(b), format='s16', layout='mono')
|
||||
return frame
|
||||
|
||||
def set_cur_audio_iter(self):
|
||||
if len(self.audio_iters) == 0:
|
||||
return False
|
||||
self.cur_iters = self.audio_iters[0]
|
||||
self.audio_iters.remove(self.cur_iters)
|
||||
rteturn True
|
||||
|
||||
def get_audio_bytes(self):
|
||||
if self.cur_iters is None:
|
||||
if not set_cur_audio_iter():
|
||||
return None
|
||||
try:
|
||||
b = next(self.cur_iters)
|
||||
return b
|
||||
except StopIteration:
|
||||
self.cur_iters = None
|
||||
return self.get_audio_bytes()
|
||||
|
||||
def _feed(self, audio_file):
|
||||
x = self.oi.run_AT_batch_stream(audio_file)
|
||||
self.audio_iters.append(x)
|
||||
|
@ -16,7 +16,7 @@ from websockets.client import connect
|
||||
import websockets
|
||||
from stt import asr
|
||||
from vad import AudioTrackVad
|
||||
from aav import MyAudioTrack, MyVideoTrack
|
||||
from a2a import LLMAudioStreamTrack
|
||||
|
||||
videos = ['./1.mp4', './2.mp4']
|
||||
|
||||
@ -45,6 +45,8 @@ class RTCLLM:
|
||||
def __init__(self, ws_url, iceServers):
|
||||
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
|
||||
self.ws_url = ws_url
|
||||
self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
|
||||
self.feed = awaitify(self.llmtrack._feed)
|
||||
self.iceServers = iceServers
|
||||
self.peers = DictObject()
|
||||
self.vid = 0
|
||||
@ -100,9 +102,7 @@ class RTCLLM:
|
||||
self.onlineList = data.onlineList
|
||||
|
||||
async def vad_voiceend(self, peer, audio):
|
||||
ret = await asr(audio)
|
||||
# await peer.dc.send(txt)
|
||||
|
||||
ret = await self.feed(audio)
|
||||
|
||||
async def auto_accept_call(self, data):
|
||||
opts = DictObject(iceServers=self.iceServers)
|
||||
@ -111,6 +111,7 @@ class RTCLLM:
|
||||
'info':data['from'],
|
||||
'pc':pc
|
||||
})
|
||||
pc.addTrack(self.llmtrack)
|
||||
await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
|
||||
|
||||
async def pc_track(self, peerid, track):
|
||||
@ -122,14 +123,6 @@ class RTCLLM:
|
||||
peer.vadtrack = vadtrack
|
||||
vadtrack.start_vad()
|
||||
|
||||
def play_random_media(self, vt, at):
|
||||
i = random.randint(0,1)
|
||||
player = MediaPlayer(videos[i])
|
||||
vt.set_source(player)
|
||||
at.set_source(player)
|
||||
f = partial(self.play_random_media, vt, at)
|
||||
self.loop.call_later(180, f)
|
||||
|
||||
async def pc_connectionState_changed(self, peerid):
|
||||
peer = self.peers[peerid]
|
||||
pc = peer.pc
|
||||
@ -159,11 +152,6 @@ class RTCLLM:
|
||||
pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
|
||||
pc.on('track', partial(self.pc_track, data['from'].id))
|
||||
pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from']))
|
||||
peer.audiotrack = MyAudioTrack()
|
||||
peer.videotrack = MyVideoTrack()
|
||||
pc.addTrack(peer.audiotrack)
|
||||
pc.addTrack(peer.videotrack)
|
||||
self.play_random_media(peer.videotrack, peer.audiotrack)
|
||||
|
||||
offer = RTCSessionDescription(** data.offer)
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
Loading…
Reference in New Issue
Block a user