chore: initial commit
This commit is contained in:
141
choir-mixer/main.py
Normal file
141
choir-mixer/main.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import numpy as np
|
||||
from livekit import agents, rtc
|
||||
from mixer import normalize_rms, mix_streams, soft_limit
|
||||
|
||||
logger = logging.getLogger("choir-mixer")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
NUM_CHANNELS = 1
|
||||
FRAME_DURATION_MS = 20
|
||||
SAMPLES_PER_FRAME = SAMPLE_RATE * FRAME_DURATION_MS // 1000 # 960
|
||||
MAX_STREAMS = 6
|
||||
TARGET_DBFS = -20.0
|
||||
FRAME_MAX_AGE_S = 0.06 # Only use frames received in the last 60ms
|
||||
|
||||
server = agents.AgentServer()
|
||||
|
||||
|
||||
@server.rtc_session(agent_name="choir-mixer")
|
||||
async def choir_mixer_agent(ctx: agents.JobContext):
|
||||
"""Choir mixer agent: subscribes to participant audio, mixes, and publishes."""
|
||||
room = ctx.room
|
||||
|
||||
# Track active audio streams: {track_sid: AudioStream}
|
||||
audio_streams: dict[str, rtc.AudioStream] = {}
|
||||
# Latest frame per track: {track_sid: (timestamp, np.ndarray)}
|
||||
latest_frames: dict[str, tuple[float, np.ndarray]] = {}
|
||||
lock = asyncio.Lock()
|
||||
|
||||
# Set up audio output
|
||||
source = rtc.AudioSource(SAMPLE_RATE, NUM_CHANNELS)
|
||||
track = rtc.LocalAudioTrack.create_audio_track("choir_mix", source)
|
||||
options = rtc.TrackPublishOptions()
|
||||
options.source = rtc.TrackSource.SOURCE_MICROPHONE
|
||||
await ctx.connect(auto_subscribe=agents.AutoSubscribe.AUDIO_ONLY)
|
||||
await room.local_participant.publish_track(track, options)
|
||||
logger.info("Choir mixer joined room: %s", room.name)
|
||||
|
||||
async def read_track(track_sid: str, stream: rtc.AudioStream):
|
||||
"""Continuously read frames from one participant's audio stream."""
|
||||
try:
|
||||
async for event in stream:
|
||||
frame = event.frame
|
||||
# Convert int16 PCM to float32 [-1.0, 1.0]
|
||||
pcm = np.frombuffer(frame.data, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
async with lock:
|
||||
latest_frames[track_sid] = (time.monotonic(), pcm)
|
||||
except Exception as e:
|
||||
logger.warning("Stream read error for %s: %s", track_sid, e)
|
||||
finally:
|
||||
async with lock:
|
||||
latest_frames.pop(track_sid, None)
|
||||
audio_streams.pop(track_sid, None)
|
||||
logger.info("Stream ended: %s (active: %d)", track_sid, len(audio_streams))
|
||||
|
||||
@room.on("track_subscribed")
|
||||
def on_track_subscribed(
|
||||
subscribed_track: rtc.Track,
|
||||
publication: rtc.RemoteTrackPublication,
|
||||
participant: rtc.RemoteParticipant,
|
||||
):
|
||||
if subscribed_track.kind != rtc.TrackKind.KIND_AUDIO:
|
||||
return
|
||||
if len(audio_streams) >= MAX_STREAMS:
|
||||
logger.info("At max streams (%d), ignoring track from %s",
|
||||
MAX_STREAMS, participant.identity)
|
||||
return
|
||||
|
||||
sid = subscribed_track.sid
|
||||
stream = rtc.AudioStream(
|
||||
subscribed_track,
|
||||
sample_rate=SAMPLE_RATE,
|
||||
num_channels=NUM_CHANNELS,
|
||||
)
|
||||
audio_streams[sid] = stream
|
||||
asyncio.create_task(read_track(sid, stream))
|
||||
logger.info("Subscribed to %s from %s (active: %d)",
|
||||
sid, participant.identity, len(audio_streams))
|
||||
|
||||
@room.on("track_unsubscribed")
|
||||
def on_track_unsubscribed(
|
||||
unsubscribed_track: rtc.Track,
|
||||
publication: rtc.RemoteTrackPublication,
|
||||
participant: rtc.RemoteParticipant,
|
||||
):
|
||||
sid = unsubscribed_track.sid
|
||||
stream = audio_streams.get(sid)
|
||||
if stream:
|
||||
asyncio.create_task(stream.aclose())
|
||||
logger.info("Unsubscribed from %s (%s)", sid, participant.identity)
|
||||
|
||||
# Mixing loop: runs at frame rate (~20ms intervals)
|
||||
async def mixing_loop():
|
||||
while True:
|
||||
now = time.monotonic()
|
||||
async with lock:
|
||||
# Only use frames that arrived recently (discard stale ones)
|
||||
frames = [
|
||||
pcm for ts, pcm in latest_frames.values()
|
||||
if now - ts < FRAME_MAX_AGE_S
|
||||
]
|
||||
|
||||
if frames:
|
||||
# Normalize each stream, mix, and limit
|
||||
normalized = [normalize_rms(f, TARGET_DBFS) for f in frames]
|
||||
mixed = mix_streams(normalized)
|
||||
limited = soft_limit(mixed)
|
||||
else:
|
||||
limited = np.zeros(SAMPLES_PER_FRAME, dtype=np.float32)
|
||||
|
||||
# Convert float32 back to int16 PCM
|
||||
pcm_int16 = (limited * 32767).astype(np.int16)
|
||||
audio_frame = rtc.AudioFrame(
|
||||
data=pcm_int16.tobytes(),
|
||||
sample_rate=SAMPLE_RATE,
|
||||
num_channels=NUM_CHANNELS,
|
||||
samples_per_channel=len(pcm_int16),
|
||||
)
|
||||
await source.capture_frame(audio_frame)
|
||||
|
||||
await asyncio.sleep(FRAME_DURATION_MS / 1000)
|
||||
|
||||
# Start mixing loop
|
||||
mix_task = asyncio.create_task(mixing_loop())
|
||||
|
||||
# Keep agent alive until mixing_loop exits (room disconnect cancels it)
|
||||
try:
|
||||
await mix_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
for stream in list(audio_streams.values()):
|
||||
await stream.aclose()
|
||||
logger.info("Choir mixer exiting room: %s", room.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
agents.cli.run_app(server)
|
||||
Reference in New Issue
Block a user