This commit is contained in:
yumoqing 2024-09-02 18:43:58 +08:00
parent ba9ee97673
commit c225d69aba
4 changed files with 124 additions and 12 deletions

36
rtcllm/aav.py Normal file
View File

@ -0,0 +1,36 @@
import random
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack
class MyAudioTrack(AudioStreamTrack):
def __init__(self):
super().__init__()
self.source = None
def set_source(self, source):
self.source = source.audio
async def recv(self):
print('MyAudioTrack::recv(): called')
if self.source is None:
return None
f = self.source.recv()
if f is None:
print('MyAudioTrack::recv():return None')
return f
class MyVideoTrack(VideoStreamTrack):
def __init__(self):
super().__init__()
self.source = None
def set_source(self, source):
self.source = source.video
async def recv(self):
print('MyVideoTrack::recv(): called')
if self.source is None:
return None
f = self.source.recv()
return f

View File

@ -1,4 +1,5 @@
import asyncio
import random
import json
from functools import partial
@ -9,19 +10,24 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med
# from websockets.asyncio.client import connect
from websockets.asyncio.client import connect
import websockets
from stt import asr
from vad import AudioTrackVad
from aav import MyAudioTrack, MyVideoTrack
videos = ['./1.mp4', './2.mp4']
class RTCLLM:
def __init__(self, ws_url, iceServers):
self.ws_url = ws_url
self.iceServers = iceServers
self.peers = DictObject()
self.vid = 0
self.info = {
'id':'rtcllm_agent',
'name':'rtcllm_agent'
}
self.dc = None
self.loop = asyncio.get_event_loop()
def get_pc(self, data):
return self.peers[data['from'].id].pc
@ -58,8 +64,10 @@ class RTCLLM:
print(f'{self}, {type(self)}')
self.onlineList = data.onlineList
async def vad_voiceend(self, fn):
print(f'vad_voiceend():{fn=}')
async def vad_voiceend(self, peer, audio):
txt = await asr(audio)
# await peer.dc.send(txt)
async def auto_accept_call(self, data):
opts = DictObject(iceServers=self.iceServers)
@ -74,14 +82,27 @@ class RTCLLM:
peer = self.peers[peerid]
pc = peer.pc
if track.kind == 'audio':
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend)
f = partial(self.vad_voiceend, peer)
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f)
peer.vadtrack = vadtrack
vadtrack.start_vad()
def play_random_media(self, vt, at):
i = random.randint(0,1)
player = MediaPlayer(videos[i])
vt.set_source(player)
at.set_source(player)
f = partial(self.play_random_media, vt, at)
self.loop.call_later(180, f)
async def pc_connectionState_changed(self, peerid):
peer = self.peers[peerid]
pc = peer.pc
print(f'************************************{pc.connectionState=}')
peer.audiotrack = MyAudioTrack()
peer.videotrack = MyVideoTrack()
pc.addTrack(peer.audiotrack)
pc.addTrack(peer.videotrack)
self.play_random_media(peer.videotrack, peer.audiotrack)
if pc.connectionState == 'connected':
peer.dc = pc.createDataChannel(peer.info.name)
return
@ -97,6 +118,20 @@ class RTCLLM:
if len([k for k in self.peers.keys()]) == 0:
await self.ws.close()
def play_video(self, peerid):
print('play video ........................')
pc = self.peers[peerid].pc
"""
player = MediaPlayer(videos[0])
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
"""
player = MediaPlayer(videos[1])
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
async def response_offer(self, data):
pc = self.get_pc(data)
if pc is None:
@ -105,10 +140,7 @@ class RTCLLM:
pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
pc.on('track', partial(self.pc_track, data['from'].id))
pc.on('icecandidate', partial(self.on_icecandidate, pc))
player = MediaPlayer('./1.mp4')
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
# self.play_video(data)
offer = RTCSessionDescription(** data.offer)
await pc.setRemoteDescription(offer)

37
rtcllm/stt.py Normal file
View File

@ -0,0 +1,37 @@
from appPublic.dictObject import DictObject
from appPublic.oauth_client import OAuthClient
desc = {
"path":"/asr/generate",
"method":"POST",
"headers":[
{
"name":"Content-Type",
"value":"application/json"
}
],
"data":[{
"name":"audio",
"value":"${b64audio}"
},{
"name":"model",
"value":"whisper"
}
],
"resp":[
{
"name":"content",
"value":"content"
}
]
}
opts = {
"data":{},
"asr":desc
}
async def asr(audio):
oc = OAuthClient(DictObject(**opts))
r = await oc("http://open-computing.cn", "asr", {"b64audio":audio})
print(f'{r=}')

View File

@ -1,3 +1,4 @@
import base64
from traceback import print_exc
import asyncio
import collections
@ -20,8 +21,8 @@ class AudioTrackVad(MediaStreamTrack):
# self.sample_rate = self.track.getSettings().sampleRate
# frameSize = self.track.getSettings().frameSize
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
self.frame_duration_ms = 0.00008
self.num_padding_frames = 20
self.frame_duration_ms = 0.00002
self.num_padding_frames = 10
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False
self.voiced_frames = []
@ -104,10 +105,16 @@ class AudioTrackVad(MediaStreamTrack):
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = False
audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
await self.write_wave(audio_data)
# await self.write_wave(audio_data)
await self.gen_base64(audio_data)
self.ring_buffer.clear()
self.voiced_frames = []
print('end voice .....', len(self.voiced_frames))
async def gen_base64(self, audio_data):
b64 = base64.b64encode(audio_data).decode('utf-8')
if self.onvoiceend:
await self.onvoiceend(b64)
async def write_wave(self, audio_data):
"""Writes a .wav file.