bugfix
This commit is contained in:
parent
84f61c8b71
commit
9e17d711b3
214
rtcllm/examples.py
Normal file
214
rtcllm/examples.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from aiohttp import web
|
||||||
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
from av import VideoFrame
|
||||||
|
|
||||||
|
ROOT = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
logger = logging.getLogger("pc")
|
||||||
|
pcs = set()
|
||||||
|
relay = MediaRelay()
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTransformTrack(MediaStreamTrack):
|
||||||
|
"""
|
||||||
|
A video stream track that transforms frames from an another track.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kind = "video"
|
||||||
|
|
||||||
|
def __init__(self, track, transform):
|
||||||
|
super().__init__() # don't forget this!
|
||||||
|
self.track = track
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
frame = await self.track.recv()
|
||||||
|
|
||||||
|
if self.transform == "cartoon":
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
|
||||||
|
# prepare color
|
||||||
|
img_color = cv2.pyrDown(cv2.pyrDown(img))
|
||||||
|
for _ in range(6):
|
||||||
|
img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
|
||||||
|
img_color = cv2.pyrUp(cv2.pyrUp(img_color))
|
||||||
|
|
||||||
|
# prepare edges
|
||||||
|
img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||||
|
img_edges = cv2.adaptiveThreshold(
|
||||||
|
cv2.medianBlur(img_edges, 7),
|
||||||
|
255,
|
||||||
|
cv2.ADAPTIVE_THRESH_MEAN_C,
|
||||||
|
cv2.THRESH_BINARY,
|
||||||
|
9,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
|
||||||
|
|
||||||
|
# combine color and edges
|
||||||
|
img = cv2.bitwise_and(img_color, img_edges)
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
elif self.transform == "edges":
|
||||||
|
# perform edge detection
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
elif self.transform == "rotate":
|
||||||
|
# rotate image
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
rows, cols, _ = img.shape
|
||||||
|
M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
|
||||||
|
img = cv2.warpAffine(img, M, (cols, rows))
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
else:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
async def index(request):
|
||||||
|
content = open(os.path.join(ROOT, "index.html"), "r").read()
|
||||||
|
return web.Response(content_type="text/html", text=content)
|
||||||
|
|
||||||
|
|
||||||
|
async def javascript(request):
|
||||||
|
content = open(os.path.join(ROOT, "client.js"), "r").read()
|
||||||
|
return web.Response(content_type="application/javascript", text=content)
|
||||||
|
|
||||||
|
|
||||||
|
async def offer(request):
|
||||||
|
params = await request.json()
|
||||||
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
|
|
||||||
|
pc = RTCPeerConnection()
|
||||||
|
pc_id = "PeerConnection(%s)" % uuid.uuid4()
|
||||||
|
pcs.add(pc)
|
||||||
|
|
||||||
|
def log_info(msg, *args):
|
||||||
|
logger.info(pc_id + " " + msg, *args)
|
||||||
|
|
||||||
|
log_info("Created for %s", request.remote)
|
||||||
|
|
||||||
|
# prepare local media
|
||||||
|
player = MediaPlayer(os.path.join(ROOT, "demo-instruct.wav"))
|
||||||
|
if args.record_to:
|
||||||
|
recorder = MediaRecorder(args.record_to)
|
||||||
|
else:
|
||||||
|
recorder = MediaBlackhole()
|
||||||
|
|
||||||
|
@pc.on("datachannel")
|
||||||
|
def on_datachannel(channel):
|
||||||
|
@channel.on("message")
|
||||||
|
def on_message(message):
|
||||||
|
if isinstance(message, str) and message.startswith("ping"):
|
||||||
|
channel.send("pong" + message[4:])
|
||||||
|
|
||||||
|
@pc.on("connectionstatechange")
|
||||||
|
async def on_connectionstatechange():
|
||||||
|
log_info("Connection state is %s", pc.connectionState)
|
||||||
|
if pc.connectionState == "failed":
|
||||||
|
await pc.close()
|
||||||
|
pcs.discard(pc)
|
||||||
|
|
||||||
|
@pc.on("track")
|
||||||
|
def on_track(track):
|
||||||
|
log_info("Track %s received", track.kind)
|
||||||
|
|
||||||
|
if track.kind == "audio":
|
||||||
|
pc.addTrack(player.audio)
|
||||||
|
recorder.addTrack(track)
|
||||||
|
elif track.kind == "video":
|
||||||
|
pc.addTrack(
|
||||||
|
VideoTransformTrack(
|
||||||
|
relay.subscribe(track), transform=params["video_transform"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if args.record_to:
|
||||||
|
recorder.addTrack(relay.subscribe(track))
|
||||||
|
|
||||||
|
@track.on("ended")
|
||||||
|
async def on_ended():
|
||||||
|
log_info("Track %s ended", track.kind)
|
||||||
|
await recorder.stop()
|
||||||
|
|
||||||
|
# handle offer
|
||||||
|
await pc.setRemoteDescription(offer)
|
||||||
|
await recorder.start()
|
||||||
|
|
||||||
|
# send answer
|
||||||
|
answer = await pc.createAnswer()
|
||||||
|
await pc.setLocalDescription(answer)
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown(app):
|
||||||
|
# close peer connections
|
||||||
|
coros = [pc.close() for pc in pcs]
|
||||||
|
await asyncio.gather(*coros)
|
||||||
|
pcs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="WebRTC audio / video / data-channels demo"
|
||||||
|
)
|
||||||
|
parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
|
||||||
|
parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
|
||||||
|
)
|
||||||
|
parser.add_argument("--record-to", help="Write received media to a file.")
|
||||||
|
parser.add_argument("--verbose", "-v", action="count")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.verbose:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
if args.cert_file:
|
||||||
|
ssl_context = ssl.SSLContext()
|
||||||
|
ssl_context.load_cert_chain(args.cert_file, args.key_file)
|
||||||
|
else:
|
||||||
|
ssl_context = None
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.on_shutdown.append(on_shutdown)
|
||||||
|
app.router.add_get("/", index)
|
||||||
|
app.router.add_get("/client.js", javascript)
|
||||||
|
app.router.add_post("/offer", offer)
|
||||||
|
web.run_app(
|
||||||
|
app, access_log=None, host=args.host, port=args.port, ssl_context=ssl_context
|
||||||
|
)
|
73
rtcllm/vad.py
Normal file
73
rtcllm/vad.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import collections
|
||||||
|
import contextlib
|
||||||
|
from aiortc import MediaStreamTrack
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
import webrtcvad
|
||||||
|
import wave
|
||||||
|
|
||||||
|
class AudioTrackVad(MediaStreamTrack):
|
||||||
|
def __init__(self, track, stage=3, onvoiceend=None):
|
||||||
|
super().__init__()
|
||||||
|
self.track = track
|
||||||
|
self.onvoiceend = onvoiceend
|
||||||
|
self.vad = webrtcvad.Vad(stage)
|
||||||
|
self.sample_rate = self.track.getSettings().sampleRate
|
||||||
|
frameSize = self.track.getSettings().frameSize
|
||||||
|
self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
||||||
|
self.num_padding_frames = 10
|
||||||
|
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
||||||
|
self.triggered = False
|
||||||
|
self.voiced_frames = []
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
frame = await self.track.recv()
|
||||||
|
self.vad_check(frame)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
async def vad_check(self, frame):
|
||||||
|
is_speech = self.vad.is_speech(frame, self.sample_rate)
|
||||||
|
if not self.triggered:
|
||||||
|
self.ring_buffer.append((frame, is_speech))
|
||||||
|
num_voiced = len([f for f, speech in self.ring_buffer if speech])
|
||||||
|
# If we're NOTTRIGGERED and more than 90% of the frames in
|
||||||
|
# the ring buffer are voiced frames, then enter the
|
||||||
|
# TRIGGERED state.
|
||||||
|
if num_voiced > 0.9 * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = True
|
||||||
|
# We want to yield all the audio we see from now until
|
||||||
|
# we are NOTTRIGGERED, but we have to start with the
|
||||||
|
# audio that's already in the ring buffer.
|
||||||
|
for f, s in self.ring_buffer:
|
||||||
|
self.voiced_frames.append(f)
|
||||||
|
self.ring_buffer.clear()
|
||||||
|
else:
|
||||||
|
# We're in the TRIGGERED state, so collect the audio data
|
||||||
|
# and add it to the ring buffer.
|
||||||
|
self.voiced_frames.append(frame)
|
||||||
|
self.ring_buffer.append((frame, is_speech))
|
||||||
|
num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
|
||||||
|
# If more than 90% of the frames in the ring buffer are
|
||||||
|
# unvoiced, then enter NOTTRIGGERED and yield whatever
|
||||||
|
# audio we've collected.
|
||||||
|
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = False
|
||||||
|
audio_data = b''.join([f.bytes for f in voiced_frames])
|
||||||
|
self.write_wave(audio_data)
|
||||||
|
self.ring_buffer.clear()
|
||||||
|
voiced_frames = []
|
||||||
|
|
||||||
|
async def write_wave(self, audio_data):
|
||||||
|
"""Writes a .wav file.
|
||||||
|
|
||||||
|
Takes path, PCM audio data, and sample rate.
|
||||||
|
"""
|
||||||
|
path = make_temp(subfix='.wav')
|
||||||
|
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||||
|
wf.setnchannels(1)
|
||||||
|
wf.setsampwidth(2)
|
||||||
|
wf.setframerate(self.sample_rate)
|
||||||
|
wf.writeframes(audio)
|
||||||
|
|
||||||
|
if self.onvoiceend:
|
||||||
|
self.onvoiceend(path)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user