This commit is contained in:
yumoqing 2024-09-09 17:34:08 +08:00
parent f89bc2eaa5
commit 52e6386bc4
2 changed files with 46 additions and 17 deletions

41
rtcllm/a2a.py Normal file
View File

@ -0,0 +1,41 @@
import io
from mini_omni.inference import OmniInference
from av import AudioFrame
class LLMAudioStreamTrack(AudioStreamTrack):
def __init__(self, ckpt_dir):
super().__init__()
self.ckpt_dir = ckpt_dir
self.oi = OmniInference(ckpt_dir=chpt_dir)
self.audio_iters = []
self.cur_iters = None
async def recv(self):
b = self.get_audio_bytes()
if b is None:
return b
frame = AudioFrame.from_ndarray(io.BytesIO(b), format='s16', layout='mono')
return frame
def set_cur_audio_iter(self):
if len(self.audio_iters) == 0:
return False
self.cur_iters = self.audio_iters[0]
self.audio_iters.remove(self.cur_iters)
rteturn True
def get_audio_bytes(self):
if self.cur_iters is None:
if not set_cur_audio_iter():
return None
try:
b = next(self.cur_iters)
return b
except StopIteration:
self.cur_iters = None
return self.get_audio_bytes()
def _feed(self, audio_file):
x = self.oi.run_AT_batch_stream(audio_file)
self.audio_iters.append(x)

View File

@ -16,7 +16,7 @@ from websockets.client import connect
import websockets import websockets
from stt import asr from stt import asr
from vad import AudioTrackVad from vad import AudioTrackVad
from aav import MyAudioTrack, MyVideoTrack from a2a import LLMAudioStreamTrack
videos = ['./1.mp4', './2.mp4'] videos = ['./1.mp4', './2.mp4']
@ -45,6 +45,8 @@ class RTCLLM:
def __init__(self, ws_url, iceServers): def __init__(self, ws_url, iceServers):
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16") # self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
self.ws_url = ws_url self.ws_url = ws_url
self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
self.feed = awaitify(self.llmtrack._feed)
self.iceServers = iceServers self.iceServers = iceServers
self.peers = DictObject() self.peers = DictObject()
self.vid = 0 self.vid = 0
@ -100,9 +102,7 @@ class RTCLLM:
self.onlineList = data.onlineList self.onlineList = data.onlineList
async def vad_voiceend(self, peer, audio): async def vad_voiceend(self, peer, audio):
ret = await asr(audio) ret = await self.feed(audio)
# await peer.dc.send(txt)
async def auto_accept_call(self, data): async def auto_accept_call(self, data):
opts = DictObject(iceServers=self.iceServers) opts = DictObject(iceServers=self.iceServers)
@ -111,6 +111,7 @@ class RTCLLM:
'info':data['from'], 'info':data['from'],
'pc':pc 'pc':pc
}) })
pc.addTrack(self.llmtrack)
await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']})) await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
async def pc_track(self, peerid, track): async def pc_track(self, peerid, track):
@ -122,14 +123,6 @@ class RTCLLM:
peer.vadtrack = vadtrack peer.vadtrack = vadtrack
vadtrack.start_vad() vadtrack.start_vad()
def play_random_media(self, vt, at):
i = random.randint(0,1)
player = MediaPlayer(videos[i])
vt.set_source(player)
at.set_source(player)
f = partial(self.play_random_media, vt, at)
self.loop.call_later(180, f)
async def pc_connectionState_changed(self, peerid): async def pc_connectionState_changed(self, peerid):
peer = self.peers[peerid] peer = self.peers[peerid]
pc = peer.pc pc = peer.pc
@ -159,11 +152,6 @@ class RTCLLM:
pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id)) pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
pc.on('track', partial(self.pc_track, data['from'].id)) pc.on('track', partial(self.pc_track, data['from'].id))
pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from'])) pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from']))
peer.audiotrack = MyAudioTrack()
peer.videotrack = MyVideoTrack()
pc.addTrack(peer.audiotrack)
pc.addTrack(peer.videotrack)
self.play_random_media(peer.videotrack, peer.audiotrack)
offer = RTCSessionDescription(** data.offer) offer = RTCSessionDescription(** data.offer)
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)