142 lines
5.1 KiB
Python
142 lines
5.1 KiB
Python
import asyncio
|
|
import logging
|
|
import time
|
|
import numpy as np
|
|
from livekit import agents, rtc
|
|
from mixer import normalize_rms, mix_streams, soft_limit
|
|
|
|
# Configure root logging at INFO level (a no-op if the root logger already
# has handlers attached).
logging.basicConfig(level=logging.INFO)

# Named logger used for all of this agent's messages.
logger = logging.getLogger("choir-mixer")
|
|
|
|
# Audio format shared by the subscribed input streams and the published mix.
SAMPLE_RATE = 48_000  # Hz
NUM_CHANNELS = 1  # mono

# Mixing cadence: one output frame per FRAME_DURATION_MS.
FRAME_DURATION_MS = 20
SAMPLES_PER_FRAME = SAMPLE_RATE * FRAME_DURATION_MS // 1000  # 960

# Maximum number of participant audio tracks read at once; extra tracks
# are ignored.
MAX_STREAMS = 6

# Loudness target each stream is normalized to before mixing, in dBFS.
TARGET_DBFS = -20.0

# Stale-audio window (60 ms): incoming audio older than this is not mixed.
FRAME_MAX_AGE_S = 0.06
|
|
|
|
# LiveKit agent server; the choir-mixer session handler below is registered on it.
server = agents.AgentServer()
|
|
|
|
|
|
@server.rtc_session(agent_name="choir-mixer")
async def choir_mixer_agent(ctx: agents.JobContext):
    """Choir mixer agent: subscribes to participant audio, mixes, and publishes.

    Up to MAX_STREAMS remote audio tracks are read concurrently, and each
    track's samples are appended to its own float32 buffer.

    A mixing loop runs on a fixed FRAME_DURATION_MS clock. On each tick it:
      * takes exactly SAMPLES_PER_FRAME samples from every track that has a
        full frame buffered;
      * passes each through normalize_rms(..., TARGET_DBFS);
      * combines them with mix_streams and soft_limit;
      * publishes the result on the "choir_mix" track.
    Silence is published when no track has a full frame ready.

    Returns after the room emits "disconnected". If the task running this
    coroutine is cancelled, cleanup still runs and the cancellation propagates.
    """
    room = ctx.room
    frame_period_s = FRAME_DURATION_MS / 1000
    # Cap on buffered input per track (FRAME_MAX_AGE_S worth of audio).
    # Older samples are dropped, so a track that delivers faster than we
    # consume cannot build up unbounded latency.
    max_buffered = max(round(FRAME_MAX_AGE_S * SAMPLE_RATE), SAMPLES_PER_FRAME)

    # Track active audio streams: {track_sid: AudioStream}
    audio_streams: dict[str, rtc.AudioStream] = {}
    # Not-yet-mixed samples per track: {track_sid: float32 array}
    pending: dict[str, np.ndarray] = {}
    # asyncio holds only weak references to tasks; keep strong ones here
    # until each background task completes.
    background_tasks: set[asyncio.Task] = set()
    lock = asyncio.Lock()
    # Set on room disconnect (and during cleanup) to end the mixing loop.
    stop = asyncio.Event()

    def spawn(coro) -> asyncio.Task:
        """Schedule *coro* as a task and keep it referenced until it finishes."""
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return task

    # Set up audio output
    source = rtc.AudioSource(SAMPLE_RATE, NUM_CHANNELS)
    track = rtc.LocalAudioTrack.create_audio_track("choir_mix", source)
    options = rtc.TrackPublishOptions()
    options.source = rtc.TrackSource.SOURCE_MICROPHONE

    async def read_track(track_sid: str, stream: rtc.AudioStream) -> None:
        """Append one track's audio to its pending buffer until the stream ends."""
        try:
            async for event in stream:
                # Convert int16 PCM to float32 in [-1.0, 1.0)
                pcm = (
                    np.frombuffer(event.frame.data, dtype=np.int16).astype(np.float32)
                    / 32768.0
                )
                async with lock:
                    previous = pending.get(track_sid)
                    if previous is None:
                        buffered = pcm
                    else:
                        buffered = np.concatenate((previous, pcm))
                    # Keep only the newest max_buffered samples.
                    pending[track_sid] = buffered[-max_buffered:]
        except Exception as e:
            logger.warning("Stream read error for %s: %s", track_sid, e)
        finally:
            async with lock:
                pending.pop(track_sid, None)
                audio_streams.pop(track_sid, None)
                logger.info("Stream ended: %s (active: %d)", track_sid, len(audio_streams))

    # Handlers are registered before connecting. Otherwise a track_subscribed
    # event fired while connect()/publish_track() are awaited would be missed.
    @room.on("track_subscribed")
    def on_track_subscribed(
        subscribed_track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        if subscribed_track.kind != rtc.TrackKind.KIND_AUDIO:
            return
        sid = subscribed_track.sid
        if sid in audio_streams:
            # A second reader would orphan the first stream and feed this
            # track's audio into its buffer twice.
            return
        if len(audio_streams) >= MAX_STREAMS:
            logger.info("At max streams (%d), ignoring track from %s",
                        MAX_STREAMS, participant.identity)
            return

        stream = rtc.AudioStream(
            subscribed_track,
            sample_rate=SAMPLE_RATE,
            num_channels=NUM_CHANNELS,
        )
        audio_streams[sid] = stream
        spawn(read_track(sid, stream))
        logger.info("Subscribed to %s from %s (active: %d)",
                    sid, participant.identity, len(audio_streams))

    @room.on("track_unsubscribed")
    def on_track_unsubscribed(
        unsubscribed_track: rtc.Track,
        publication: rtc.RemoteTrackPublication,
        participant: rtc.RemoteParticipant,
    ):
        sid = unsubscribed_track.sid
        stream = audio_streams.get(sid)
        if stream is not None:
            # Closing the stream is expected to end read_track's `async for`.
            # Its finally block then removes this track's state.
            spawn(stream.aclose())
        logger.info("Unsubscribed from %s (%s)", sid, participant.identity)

    @room.on("disconnected")
    def on_disconnected(*_args) -> None:
        stop.set()

    async def mixing_loop() -> None:
        """Publish one mixed frame every FRAME_DURATION_MS until `stop` is set."""
        next_tick = time.monotonic()
        while not stop.is_set():
            async with lock:
                # Take exactly one frame from every track with a full frame
                # buffered. Tracks that are still short wait for a later tick.
                frames = []
                for sid, buffered in list(pending.items()):
                    if len(buffered) >= SAMPLES_PER_FRAME:
                        frames.append(buffered[:SAMPLES_PER_FRAME])
                        pending[sid] = buffered[SAMPLES_PER_FRAME:]

            if frames:
                # Normalize each stream, mix, and limit
                normalized = [normalize_rms(f, TARGET_DBFS) for f in frames]
                limited = soft_limit(mix_streams(normalized))
            else:
                limited = np.zeros(SAMPLES_PER_FRAME, dtype=np.float32)

            # Clip before scaling so an out-of-range sample saturates
            # instead of wrapping around in the int16 cast.
            pcm_int16 = (np.clip(limited, -1.0, 1.0) * 32767).astype(np.int16)
            audio_frame = rtc.AudioFrame(
                data=pcm_int16.tobytes(),
                sample_rate=SAMPLE_RATE,
                num_channels=NUM_CHANNELS,
                samples_per_channel=len(pcm_int16),
            )
            await source.capture_frame(audio_frame)

            # Schedule against an absolute deadline. Sleeping a fixed period
            # after the work would let processing time accumulate as drift.
            next_tick += frame_period_s
            delay = next_tick - time.monotonic()
            if delay > 0:
                await asyncio.sleep(delay)
            elif delay < -frame_period_s:
                # More than a frame behind (e.g. an event-loop stall): resync
                # rather than emitting a burst of catch-up frames.
                next_tick = time.monotonic()

    await ctx.connect(auto_subscribe=agents.AutoSubscribe.AUDIO_ONLY)
    await room.local_participant.publish_track(track, options)
    logger.info("Choir mixer joined room: %s", room.name)

    mix_task = asyncio.create_task(mixing_loop())
    try:
        # Runs until the room disconnects. If this coroutine is cancelled,
        # asyncio forwards the cancellation to mix_task. The CancelledError is
        # re-raised after the cleanup below rather than swallowed.
        await mix_task
    finally:
        stop.set()
        mix_task.cancel()  # no-op once the task has already finished
        for stream in list(audio_streams.values()):
            await stream.aclose()
        # Give reader/close tasks a bounded chance to run their own cleanup.
        leftover = set(background_tasks)
        if leftover:
            await asyncio.wait(leftover, timeout=1.0)
        logger.info("Choir mixer exiting room: %s", room.name)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the agent server to the LiveKit agents CLI runner.
    agents.cli.run_app(server)
|