bugfix
This commit is contained in:
parent
b71e533e85
commit
a4dc6ef93c
@ -1,18 +1,19 @@
|
|||||||
import io
|
import io
|
||||||
from mini_omni.inference import OmniInference
|
|
||||||
from av import AudioFrame
|
from av import AudioFrame
|
||||||
|
|
||||||
from aiortc import VideoStreamTrack, AudioStreamTrack
|
from aiortc import VideoStreamTrack, AudioStreamTrack
|
||||||
|
|
||||||
class LLMAudioStreamTrack(AudioStreamTrack):
|
class LLMAudioStreamTrack(AudioStreamTrack):
|
||||||
def __init__(self, ckpt_dir):
|
def __init__(self, omni_infer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ckpt_dir = ckpt_dir
|
self.ckpt_dir = ckpt_dir
|
||||||
self.oi = OmniInference(ckpt_dir=ckpt_dir)
|
# self.oi = OmniInference(ckpt_dir=ckpt_dir)
|
||||||
|
self.oi = omni_infer
|
||||||
self.audio_iters = []
|
self.audio_iters = []
|
||||||
self.cur_iters = None
|
self.cur_iters = None
|
||||||
|
|
||||||
async def recv(self):
|
async def recv(self):
|
||||||
|
print(f'LLMAudioStreamTrack():recv() called ....')
|
||||||
b = self.get_audio_bytes()
|
b = self.get_audio_bytes()
|
||||||
if b is None:
|
if b is None:
|
||||||
return b
|
return b
|
||||||
@ -27,17 +28,18 @@ class LLMAudioStreamTrack(AudioStreamTrack):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def get_audio_bytes(self):
|
def get_audio_bytes(self):
|
||||||
if self.cur_iters is None:
|
if self.cur_iters is None:
|
||||||
if not set_cur_audio_iter():
|
if not set_cur_audio_iter():
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
b = next(self.cur_iters)
|
b = next(self.cur_iters)
|
||||||
return b
|
return b
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.cur_iters = None
|
self.cur_iters = None
|
||||||
return self.get_audio_bytes()
|
return self.get_audio_bytes()
|
||||||
|
|
||||||
def _feed(self, audio_file):
|
def _feed(self, audio_file):
|
||||||
x = self.oi.run_AT_batch_stream(audio_file)
|
x = self.oi.run_AT_batch_stream(audio_file)
|
||||||
self.audio_iters.append(x)
|
self.audio_iters.append(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@ -50,8 +50,8 @@ class RTCLLM:
|
|||||||
def __init__(self, ws_url, iceServers):
|
def __init__(self, ws_url, iceServers):
|
||||||
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
|
# self.stt_model = WhisperModel('large-v3', device="cuda", compute_type="float16")
|
||||||
self.ws_url = ws_url
|
self.ws_url = ws_url
|
||||||
self.llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
|
self.omni_infer =
|
||||||
self.feed = awaitify(self.llmtrack._feed)
|
self.omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni')
|
||||||
self.iceServers = iceServers
|
self.iceServers = iceServers
|
||||||
self.peers = DictObject()
|
self.peers = DictObject()
|
||||||
self.vid = 0
|
self.vid = 0
|
||||||
@ -108,15 +108,19 @@ class RTCLLM:
|
|||||||
|
|
||||||
async def vad_voiceend(self, peer, audio):
|
async def vad_voiceend(self, peer, audio):
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
ret = await self.feed(audio)
|
feed = awaitify(peer.llmtrack._feed)
|
||||||
|
ret = await feed(audio)
|
||||||
print(f'self.feed("{audio}") return {ret}')
|
print(f'self.feed("{audio}") return {ret}')
|
||||||
os.remove(audio)
|
os.remove(audio)
|
||||||
|
|
||||||
async def auto_accept_call(self, data):
|
async def auto_accept_call(self, data):
|
||||||
opts = DictObject(iceServers=self.iceServers)
|
opts = DictObject(iceServers=self.iceServers)
|
||||||
pc = RTCPeerConnection(opts)
|
pc = RTCPeerConnection(opts)
|
||||||
|
llmtrack = LLMAudioStreamTrack('/d/models/mini-omni')
|
||||||
|
feed = awaitify(self.llmtrack._feed)
|
||||||
self.peers[data['from'].id] = DictObject(**{
|
self.peers[data['from'].id] = DictObject(**{
|
||||||
'info':data['from'],
|
'info':data['from'],
|
||||||
|
'llmtrack':llmtrack,
|
||||||
'pc':pc
|
'pc':pc
|
||||||
})
|
})
|
||||||
pc.addTrack(self.llmtrack)
|
pc.addTrack(self.llmtrack)
|
||||||
|
Loading…
Reference in New Issue
Block a user