This commit is contained in:
yumoqing 2024-09-02 18:43:58 +08:00
parent ba9ee97673
commit c225d69aba
4 changed files with 124 additions and 12 deletions

36
rtcllm/aav.py Normal file
View File

@ -0,0 +1,36 @@
import random
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
from aiortc import VideoStreamTrack, AudioStreamTrack
class MyAudioTrack(AudioStreamTrack):
    """Audio track whose frames are pulled from a swappable source.

    The upstream source can be replaced at runtime (see set_source), so the
    peer connection can keep one stable outgoing track while the media
    behind it changes.
    """

    def __init__(self):
        super().__init__()
        # Upstream audio track; None until set_source() is called.
        self.source = None

    def set_source(self, source):
        """Attach a media source (e.g. an aiortc MediaPlayer); its .audio track is used."""
        self.source = source.audio

    async def recv(self):
        """Return the next audio frame from the current source, or None if no source.

        BUG FIX: aiortc's track recv() is a coroutine; the original called it
        without await, returning a coroutine object instead of an AudioFrame.
        NOTE(review): returning None when no source is attached violates the
        MediaStreamTrack contract (consumers expect a frame or MediaStreamError)
        and may busy-loop the consumer — kept as-is, but confirm intended.
        """
        print('MyAudioTrack::recv(): called')
        if self.source is None:
            return None
        f = await self.source.recv()
        if f is None:
            print('MyAudioTrack::recv():return None')
        return f
class MyVideoTrack(VideoStreamTrack):
    """Video track whose frames are pulled from a swappable source.

    Mirrors MyAudioTrack: the upstream source can be swapped at runtime via
    set_source() while the outgoing track object stays the same.
    """

    def __init__(self):
        super().__init__()
        # Upstream video track; None until set_source() is called.
        self.source = None

    def set_source(self, source):
        """Attach a media source (e.g. an aiortc MediaPlayer); its .video track is used."""
        self.source = source.video

    async def recv(self):
        """Return the next video frame from the current source, or None if no source.

        BUG FIX: aiortc's track recv() is a coroutine; the original called it
        without await, returning a coroutine object instead of a VideoFrame.
        NOTE(review): returning None when no source is attached violates the
        MediaStreamTrack contract — kept for parity with the original, but
        confirm intended.
        """
        print('MyVideoTrack::recv(): called')
        if self.source is None:
            return None
        f = await self.source.recv()
        return f

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import random
import json import json
from functools import partial from functools import partial
@ -9,19 +10,24 @@ from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, Med
# from websockets.asyncio.client import connect # from websockets.asyncio.client import connect
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
import websockets import websockets
from stt import asr
from vad import AudioTrackVad from vad import AudioTrackVad
from aav import MyAudioTrack, MyVideoTrack
videos = ['./1.mp4', './2.mp4']
class RTCLLM: class RTCLLM:
def __init__(self, ws_url, iceServers): def __init__(self, ws_url, iceServers):
self.ws_url = ws_url self.ws_url = ws_url
self.iceServers = iceServers self.iceServers = iceServers
self.peers = DictObject() self.peers = DictObject()
self.vid = 0
self.info = { self.info = {
'id':'rtcllm_agent', 'id':'rtcllm_agent',
'name':'rtcllm_agent' 'name':'rtcllm_agent'
} }
self.dc = None self.dc = None
self.loop = asyncio.get_event_loop()
def get_pc(self, data): def get_pc(self, data):
return self.peers[data['from'].id].pc return self.peers[data['from'].id].pc
@ -58,8 +64,10 @@ class RTCLLM:
print(f'{self}, {type(self)}') print(f'{self}, {type(self)}')
self.onlineList = data.onlineList self.onlineList = data.onlineList
async def vad_voiceend(self, fn): async def vad_voiceend(self, peer, audio):
print(f'vad_voiceend():{fn=}') txt = await asr(audio)
# await peer.dc.send(txt)
async def auto_accept_call(self, data): async def auto_accept_call(self, data):
opts = DictObject(iceServers=self.iceServers) opts = DictObject(iceServers=self.iceServers)
@ -74,14 +82,27 @@ class RTCLLM:
peer = self.peers[peerid] peer = self.peers[peerid]
pc = peer.pc pc = peer.pc
if track.kind == 'audio': if track.kind == 'audio':
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=self.vad_voiceend) f = partial(self.vad_voiceend, peer)
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f)
peer.vadtrack = vadtrack peer.vadtrack = vadtrack
vadtrack.start_vad() vadtrack.start_vad()
def play_random_media(self, vt, at):
i = random.randint(0,1)
player = MediaPlayer(videos[i])
vt.set_source(player)
at.set_source(player)
f = partial(self.play_random_media, vt, at)
self.loop.call_later(180, f)
async def pc_connectionState_changed(self, peerid): async def pc_connectionState_changed(self, peerid):
peer = self.peers[peerid] peer = self.peers[peerid]
pc = peer.pc pc = peer.pc
print(f'************************************{pc.connectionState=}') peer.audiotrack = MyAudioTrack()
peer.videotrack = MyVideoTrack()
pc.addTrack(peer.audiotrack)
pc.addTrack(peer.videotrack)
self.play_random_media(peer.videotrack, peer.audiotrack)
if pc.connectionState == 'connected': if pc.connectionState == 'connected':
peer.dc = pc.createDataChannel(peer.info.name) peer.dc = pc.createDataChannel(peer.info.name)
return return
@ -97,6 +118,20 @@ class RTCLLM:
if len([k for k in self.peers.keys()]) == 0: if len([k for k in self.peers.keys()]) == 0:
await self.ws.close() await self.ws.close()
def play_video(self, peerid):
print('play video ........................')
pc = self.peers[peerid].pc
"""
player = MediaPlayer(videos[0])
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
"""
player = MediaPlayer(videos[1])
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
async def response_offer(self, data): async def response_offer(self, data):
pc = self.get_pc(data) pc = self.get_pc(data)
if pc is None: if pc is None:
@ -105,10 +140,7 @@ class RTCLLM:
pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id)) pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
pc.on('track', partial(self.pc_track, data['from'].id)) pc.on('track', partial(self.pc_track, data['from'].id))
pc.on('icecandidate', partial(self.on_icecandidate, pc)) pc.on('icecandidate', partial(self.on_icecandidate, pc))
player = MediaPlayer('./1.mp4') # self.play_video(data)
if player:
pc.addTrack(player.audio)
pc.addTrack(player.video)
offer = RTCSessionDescription(** data.offer) offer = RTCSessionDescription(** data.offer)
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)

37
rtcllm/stt.py Normal file
View File

@ -0,0 +1,37 @@
from appPublic.dictObject import DictObject
from appPublic.oauth_client import OAuthClient
# Request template for the remote ASR (speech-to-text) endpoint, consumed by
# OAuthClient.  "${b64audio}" is a placeholder substituted by OAuthClient with
# the caller-supplied base64-encoded audio payload.
desc = {
    # Endpoint path and HTTP method of the ASR service.
    "path":"/asr/generate",
    "method":"POST",
    "headers":[
        {
            "name":"Content-Type",
            "value":"application/json"
        }
    ],
    # POST body fields: the audio payload and the model to run.
    "data":[{
        "name":"audio",
        "value":"${b64audio}"
    },{
        "name":"model",
        "value":"whisper"
    }
    ],
    # Response mapping: expose the service's "content" field (the recognized
    # text) under the name "content".
    "resp":[
        {
            "name":"content",
            "value":"content"
        }
    ]
}
# OAuthClient configuration: registers the template above under the
# service name "asr" (referenced by asr() below).
opts = {
    "data":{},
    "asr":desc
}
async def asr(audio):
    """Transcribe base64-encoded audio via the remote ASR service.

    audio: base64-encoded audio payload (str), as produced by
           AudioTrackVad.gen_base64().
    Returns the service response; per desc["resp"] it carries the recognized
    text under "content".  NOTE(review): exact response shape depends on
    OAuthClient — confirm whether callers need r or r.content.
    """
    oc = OAuthClient(DictObject(**opts))
    r = await oc("http://open-computing.cn", "asr", {"b64audio": audio})
    print(f'{r=}')
    # BUG FIX: the original had no return, so `txt = await asr(audio)` in the
    # caller always received None.
    return r

View File

@ -1,3 +1,4 @@
import base64
from traceback import print_exc from traceback import print_exc
import asyncio import asyncio
import collections import collections
@ -20,8 +21,8 @@ class AudioTrackVad(MediaStreamTrack):
# self.sample_rate = self.track.getSettings().sampleRate # self.sample_rate = self.track.getSettings().sampleRate
# frameSize = self.track.getSettings().frameSize # frameSize = self.track.getSettings().frameSize
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate # self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
self.frame_duration_ms = 0.00008 self.frame_duration_ms = 0.00002
self.num_padding_frames = 20 self.num_padding_frames = 10
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames) self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False self.triggered = False
self.voiced_frames = [] self.voiced_frames = []
@ -104,10 +105,16 @@ class AudioTrackVad(MediaStreamTrack):
if num_unvoiced > 0.9 * self.ring_buffer.maxlen: if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
self.triggered = False self.triggered = False
audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames]) audio_data = b''.join([self.frame2bytes(f) for f in self.voiced_frames])
await self.write_wave(audio_data) # await self.write_wave(audio_data)
await self.gen_base64(audio_data)
self.ring_buffer.clear() self.ring_buffer.clear()
self.voiced_frames = [] self.voiced_frames = []
print('end voice .....', len(self.voiced_frames)) print('end voice .....', len(self.voiced_frames))
async def gen_base64(self, audio_data):
b64 = base64.b64encode(audio_data).decode('utf-8')
if self.onvoiceend:
await self.onvoiceend(b64)
async def write_wave(self, audio_data): async def write_wave(self, audio_data):
"""Writes a .wav file. """Writes a .wav file.