bugfix

commit c225d69aba
parent ba9ee97673
rtcllm/aav.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import random
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack

class MyAudioTrack(AudioStreamTrack):
    def __init__(self):
        super().__init__()
        self.source = None

    def set_source(self, source):
        self.source = source.audio

    async def recv(self):
        print('MyAudioTrack::recv(): called')

        if self.source is None:
            return None
        f = self.source.recv()
        if f is None:
            print('MyAudioTrack::recv():return None')
        return f

class MyVideoTrack(VideoStreamTrack):
    def __init__(self):
        super().__init__()
        self.source = None

    def set_source(self, source):
        self.source = source.video

    async def recv(self):
        print('MyVideoTrack::recv(): called')
        if self.source is None:
            return None
        f = self.source.recv()
        return f
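Editor's note, not part of the commit: in aiortc, MediaStreamTrack.recv() is a coroutine, so the `f = self.source.recv()` calls above hand back a coroutine object rather than an audio/video frame unless they are awaited. Below is a minimal sketch of an awaited pass-through audio track that reuses the set_source() wiring from aav.py; the silent-frame fallback via super().recv() is an illustration of one way to avoid returning None, not the project's code.

from aiortc import AudioStreamTrack

class AwaitingAudioTrack(AudioStreamTrack):
    """Pass-through track that awaits the upstream recv() (sketch only)."""

    def __init__(self):
        super().__init__()
        self.source = None  # upstream track holder, e.g. a MediaPlayer instance

    def set_source(self, source):
        # Same wiring as MyAudioTrack.set_source() in aav.py above.
        self.source = source.audio

    async def recv(self):
        if self.source is None:
            # AudioStreamTrack.recv() yields silent frames; better than returning None.
            return await super().recv()
        # recv() is async in aiortc, so it must be awaited to obtain a frame.
        return await self.source.recv()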
@@ -1,4 +1,5 @@
import asyncio
import random
import json
from functools import partial

@@ -9,19 +10,24 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med
# from websockets.asyncio.client import connect
from websockets.asyncio.client import connect
import websockets

from stt import asr
from vad import AudioTrackVad
from aav import MyAudioTrack, MyVideoTrack

videos = ['./1.mp4', './2.mp4']

class RTCLLM:
    def __init__(self, ws_url, iceServers):
        self.ws_url = ws_url
        self.iceServers = iceServers
        self.peers = DictObject()
        self.vid = 0
        self.info = {
            'id':'rtcllm_agent',
            'name':'rtcllm_agent'
        }
        self.dc = None
        self.loop = asyncio.get_event_loop()

    def get_pc(self, data):
        return self.peers[data['from'].id].pc
@@ -58,8 +64,10 @@ class RTCLLM:
        print(f'{self}, {type(self)}')
        self.onlineList = data.onlineList

    async def vad_voiceend(self, fn):
        print(f'vad_voiceend():{fn=}')
    async def vad_voiceend(self, peer, audio):
        txt = await asr(audio)
        # await peer.dc.send(txt)


    async def auto_accept_call(self, data):
        opts = DictObject(iceServers=self.iceServers)
@@ -74,14 +82,27 @@ class RTCLLM:
        peer = self.peers[peerid]
        pc = peer.pc
        if track.kind == 'audio':
            vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend)
            f = partial(self.vad_voiceend, peer)
            vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f)
            peer.vadtrack = vadtrack
            vadtrack.start_vad()

    def play_random_media(self, vt, at):
        i = random.randint(0,1)
        player = MediaPlayer(videos[i])
        vt.set_source(player)
        at.set_source(player)
        f = partial(self.play_random_media, vt, at)
        self.loop.call_later(180, f)

    async def pc_connectionState_changed(self, peerid):
        peer = self.peers[peerid]
        pc = peer.pc
        print(f'************************************{pc.connectionState=}')
        peer.audiotrack = MyAudioTrack()
        peer.videotrack = MyVideoTrack()
        pc.addTrack(peer.audiotrack)
        pc.addTrack(peer.videotrack)
        self.play_random_media(peer.videotrack, peer.audiotrack)
        if pc.connectionState == 'connected':
            peer.dc = pc.createDataChannel(peer.info.name)
            return
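Editor's note, not part of the commit: play_random_media() above re-arms itself with loop.call_later(180, partial(...)), i.e. a synchronous callback that schedules its own next run every 180 seconds. A self-contained sketch of that scheduling pattern; the names and the 1-second demo runtime are placeholders, not project code.

import asyncio
from functools import partial

def play_random_media(loop, label):
    # Stand-in for the media switch; the real method calls set_source() on the tracks.
    print(f'switching media for {label}')
    # Re-arm the same callback 180 seconds from now, like the method above.
    loop.call_later(180, partial(play_random_media, loop, label))

async def main():
    loop = asyncio.get_running_loop()
    play_random_media(loop, "demo-peer")  # hypothetical peer label
    await asyncio.sleep(1)  # keep the loop alive briefly for the demo

asyncio.run(main())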
@@ -97,6 +118,20 @@ class RTCLLM:
        if len([k for k in self.peers.keys()]) == 0:
            await self.ws.close()

    def play_video(self, peerid):
        print('play video ........................')
        pc = self.peers[peerid].pc
        """
        player = MediaPlayer(videos[0])
        if player:
            pc.addTrack(player.audio)
            pc.addTrack(player.video)
        """
        player = MediaPlayer(videos[1])
        if player:
            pc.addTrack(player.audio)
            pc.addTrack(player.video)

    async def response_offer(self, data):
        pc = self.get_pc(data)
        if pc is None:
@@ -105,10 +140,7 @@ class RTCLLM:
        pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
        pc.on('track', partial(self.pc_track, data['from'].id))
        pc.on('icecandidate', partial(self.on_icecandidate, pc))
        player = MediaPlayer('./1.mp4')
        if player:
            pc.addTrack(player.audio)
            pc.addTrack(player.video)
        # self.play_video(data)

        offer = RTCSessionDescription(** data.offer)
        await pc.setRemoteDescription(offer)
rtcllm/stt.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from appPublic.dictObject import DictObject
from appPublic.oauth_client import OAuthClient

desc = {
    "path":"/asr/generate",
    "method":"POST",
    "headers":[
        {
            "name":"Content-Type",
            "value":"application/json"
        }
    ],
    "data":[{
        "name":"audio",
        "value":"${b64audio}"
    },{
        "name":"model",
        "value":"whisper"
    }
    ],
    "resp":[
        {
            "name":"content",
            "value":"content"
        }
    ]
}
opts = {
    "data":{},
    "asr":desc
}

async def asr(audio):
    oc = OAuthClient(DictObject(**opts))
    r = await oc("http://open-computing.cn", "asr", {"b64audio":audio})
    print(f'{r=}')
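Editor's note, not part of the commit: vad_voiceend() in the RTCLLM class does `txt = await asr(audio)`, while asr() as shown in this hunk only prints the OAuthClient response. A hedged sketch of handing the transcript back to the caller; it reuses OAuthClient, DictObject and opts from stt.py above, and the `content` attribute is an assumption taken from the "resp" mapping in desc, not a documented OAuthClient guarantee.

async def asr(audio):
    oc = OAuthClient(DictObject(**opts))
    r = await oc("http://open-computing.cn", "asr", {"b64audio": audio})
    print(f'{r=}')
    # Assumption: the mapped response exposes the transcript under "content",
    # mirroring desc["resp"]; adjust to whatever OAuthClient actually returns.
    return r.content if r else None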
rtcllm/vad.py
@@ -1,3 +1,4 @@
import base64
from traceback import print_exc
import asyncio
import collections
@@ -20,8 +21,8 @@ class AudioTrackVad(MediaStreamTrack):
        # self.sample_rate = self.track.getSettings().sampleRate
        # frameSize = self.track.getSettings().frameSize
        # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
        self.frame_duration_ms = 0.00008
        self.num_padding_frames = 20
        self.frame_duration_ms = 0.00002
        self.num_padding_frames = 10
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames = []
@@ -104,10 +105,16 @@ class AudioTrackVad(MediaStreamTrack):
        if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
            self.triggered = False
            audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
            await self.write_wave(audio_data)
            # await self.write_wave(audio_data)
            await self.gen_base64(audio_data)
            self.ring_buffer.clear()
            self.voiced_frames = []

        print('end voice .....', len(self.voiced_frames))
    async def gen_base64(self, audio_data):
        b64 = base64.b64encode(audio_data).decode('utf-8')
        if self.onvoiceend:
            await self.onvoiceend(b64)

    async def write_wave(self, audio_data):
        """Writes a .wav file.
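Editor's note, not part of the commit: the hunks above fit together through functools.partial. RTCLLM binds the peer into the callback with partial(self.vad_voiceend, peer), and AudioTrackVad.gen_base64() later awaits self.onvoiceend(b64), so the handler receives the peer plus the base64 audio. A toy illustration of that wiring with stand-in names (not the project's classes):

import asyncio
from functools import partial

async def vad_voiceend(peer, audio_b64):
    # Receives the bound peer plus the base64 audio passed by the VAD track.
    print(f'voice ended for {peer}: {len(audio_b64)} base64 chars')

async def main():
    onvoiceend = partial(vad_voiceend, "peer-123")  # hypothetical peer id
    await onvoiceend("UklGRg==")  # what AudioTrackVad's gen_base64() would await

asyncio.run(main())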