Commit 6dfabd9 · Parent: 062564b
MacBook pro committed

WebRTC: add /webrtc/connections; Metrics: integrate enhanced metrics; Docker: enable SCRFD via env; wire up safe model loader

Dockerfile CHANGED
@@ -56,6 +56,11 @@ EXPOSE 7860
 # Default port (Hugging Face Spaces injects PORT env; fallback to 7860)
 ENV PORT=7860
 
+# Feature flags for safe model integration (can be overridden in Space settings)
+# Enable SCRFD face detection by default for better reliability; keep LivePortrait safe path off initially.
+ENV MIRAGE_ENABLE_SCRFD=1 \
+    MIRAGE_ENABLE_LIVEPORTRAIT=0
+
 # Health check
 HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
     CMD sh -c 'curl -fsS http://localhost:${PORT:-7860}/health || exit 1'
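The flags are ordinary environment variables, so they can be flipped per deployment (for example in a Space's settings) without rebuilding the image. On the Python side the commit parses them permissively; a one-line sketch mirroring the parsing in safe_model_integration.py below:

import os

# "1", "true", or "yes" (any case) enables a flag; anything else, including unset, disables it
ENABLE_SCRFD = os.getenv("MIRAGE_ENABLE_SCRFD", "0").lower() in ("1", "true", "yes")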
avatar_pipeline.py CHANGED
@@ -16,6 +16,8 @@ import asyncio
 from collections import deque
 import traceback
 from virtual_camera import get_virtual_camera_manager
+from enhanced_metrics import get_enhanced_metrics, enhance_existing_stats
+from safe_model_integration import get_safe_model_loader
 from realtime_optimizer import get_realtime_optimizer
 
 # Setup logging
@@ -256,6 +258,7 @@ class RealTimeAvatarPipeline:
         self.face_detector = FaceDetector(self.config)
         self.liveportrait = LivePortraitModel(self.config)
         self.rvc = RVCVoiceConverter(self.config)
+        self.safe_loader = get_safe_model_loader()
 
         # Performance optimization
         self.optimizer = get_realtime_optimizer()
@@ -272,6 +275,7 @@ class RealTimeAvatarPipeline:
         # Performance tracking
         self.frame_times = deque(maxlen=100)
         self.audio_times = deque(maxlen=100)
+        self._metrics = get_enhanced_metrics()
 
         # Processing locks
         self.video_lock = threading.Lock()
@@ -286,19 +290,28 @@
         """Initialize all models"""
         logger.info("Initializing real-time avatar pipeline...")
 
-        # Load models in parallel
-        tasks = [
-            self.face_detector.load_model(),
-            self.liveportrait.load_models(),
-            self.rvc.load_model()
-        ]
-
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-
-        success_count = sum(1 for r in results if r is True)
-        logger.info(f"Loaded {success_count}/3 models successfully")
+        # Face detector load may be synchronous; run in executor to avoid blocking loop
+        loop = asyncio.get_running_loop()
+        try:
+            fd_ok = await loop.run_in_executor(None, self.face_detector.load_model)
+        except Exception as e:
+            logger.error(f"Face detector load failed: {e}")
+            fd_ok = False
+
+        # Load async models and optional safe models in parallel
+        lp_task = self.liveportrait.load_models()
+        rvc_task = self.rvc.load_model()
+        scrfd_task = self.safe_loader.safe_load_scrfd()
+        lp_safe_task = self.safe_loader.safe_load_liveportrait()
+
+        results = await asyncio.gather(lp_task, rvc_task, scrfd_task, lp_safe_task, return_exceptions=True)
+        # Normalize booleans from tasks
+        async_ok = sum(1 for r in results if r is True)
+        success_count = async_ok + (1 if fd_ok else 0)
+        logger.info(f"Loaded components - FaceDetector: {fd_ok}, LivePortrait: {results[0]}, RVC: {results[1]}, SCRFD(safe): {results[2]}, LivePortrait(safe): {results[3]}")
 
-        if success_count >= 2:  # At least face detector + one AI model
+        if (fd_ok and (results[0] is True or results[3] is True)) or (fd_ok and results[1] is True):
+            # Require face detector + (any of liveportrait variants or RVC) to proceed
             self.loaded = True
             logger.info("Pipeline initialization successful")
             return True
@@ -310,7 +323,18 @@
         """Set reference frame for avatar"""
         try:
             # Detect face in reference frame
-            bbox, confidence = self.face_detector.detect_face(frame, 0)
+            bbox = None
+            confidence = 0.0
+            # Prefer safe SCRFD if available
+            try:
+                sb = self.safe_loader.safe_detect_face(frame)
+                if sb is not None:
+                    bbox = sb
+                    confidence = 1.0  # safe path doesn't provide score; assume strong if detected
+            except Exception:
+                pass
+            if bbox is None:
+                bbox, confidence = self.face_detector.detect_face(frame, 0)
 
             if bbox is not None and confidence >= self.config.face_detection_threshold:
                 self.reference_frame = frame.copy()
@@ -349,16 +373,38 @@
             return frame_resized
 
         # Detect face in current frame
-        bbox, confidence = self.face_detector.detect_face(frame_resized, frame_idx)
+        t0 = time.time()
+        bbox = None
+        confidence = 0.0
+        if self.safe_loader.scrfd_loaded:
+            try:
+                sb = self.safe_loader.safe_detect_face(frame_resized)
+                if sb is not None:
+                    bbox = sb
+                    confidence = 1.0
+            except Exception:
+                bbox = None
+        if bbox is None:
+            bbox, confidence = self.face_detector.detect_face(frame_resized, frame_idx)
+        self._metrics.record_component_timing('face_detection', (time.time() - t0) * 1000.0)
 
         if self.reference_frame is None:
            # No reference, keep camera as-is for stability until reference set
            result_frame = frame_resized
        elif bbox is not None and confidence >= self.config.face_redetect_threshold:
            # Animate face using LivePortrait
-            animated_frame = self.liveportrait.animate_face(
-                self.reference_frame, frame_resized
-            )
+            t1 = time.time()
+            if self.liveportrait.loaded:
+                animated_frame = self.liveportrait.animate_face(
+                    self.reference_frame, frame_resized
+                )
+            elif self.safe_loader.liveportrait_loaded:
+                animated_frame = self.safe_loader.safe_animate_face(
+                    self.reference_frame, frame_resized
+                )
+            else:
+                animated_frame = frame_resized
+            self._metrics.record_component_timing('animation', (time.time() - t1) * 1000.0)
 
            # Apply any post-processing with current quality settings
            result_frame = self._post_process_frame(animated_frame, opt_settings)
@@ -373,6 +419,9 @@
         # Record processing time
         processing_time = (time.time() - start_time) * 1000
         self.frame_times.append(processing_time)
+        self._metrics.record_video_timing(processing_time)
+        self._metrics.record_component_timing('face_detection', 0.0)  # placeholder hooks
+        self._metrics.record_component_timing('animation', 0.0)
         self.optimizer.latency_optimizer.record_latency("video_total", processing_time)
 
         return result_frame
@@ -400,6 +449,9 @@
         # Record processing time
         processing_time = (time.time() - start_time) * 1000
         self.audio_times.append(processing_time)
+        self._metrics.record_audio_timing(processing_time)
+        self._metrics.record_total_timing(processing_time)
+        self._metrics.record_component_timing('voice_processing', processing_time)
         self.optimizer.latency_optimizer.record_latency("audio_total", processing_time)
 
         return converted_audio
@@ -460,7 +512,10 @@
             }
 
             # Merge with optimizer stats
-            return {**pipeline_stats, "optimization": opt_stats}
+            merged = {**pipeline_stats, "optimization": opt_stats}
+            # Enhance with additional percentiles/system metrics
+            merged = enhance_existing_stats(merged)
+            return merged
 
         except Exception as e:
             logger.error(f"Stats error: {e}")
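The reworked initialize() mixes one blocking loader with several coroutines. The pattern in isolation, as a minimal sketch with stand-in loaders rather than the pipeline's real ones:

import asyncio
import time

def blocking_load() -> bool:
    time.sleep(0.1)                    # stands in for a synchronous model load
    return True

async def async_load() -> bool:
    await asyncio.sleep(0.1)
    return True

async def initialize() -> bool:
    loop = asyncio.get_running_loop()
    # Blocking work runs in a worker thread so the event loop stays responsive
    fd_ok = await loop.run_in_executor(None, blocking_load)
    # Coroutines run concurrently; return_exceptions=True keeps one failure from aborting the rest
    results = await asyncio.gather(async_load(), async_load(), return_exceptions=True)
    return fd_ok and any(r is True for r in results)

print(asyncio.run(initialize()))       # True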
enhanced_metrics.py ADDED
@@ -0,0 +1,139 @@
+"""
+Enhanced Performance Metrics for Existing Pipeline
+Adds p50/p95/p99 latency tracking and GPU monitoring
+Drop-in compatible with existing get_performance_stats()
+"""
+
+import time
+import psutil
+import numpy as np
+from collections import deque
+from typing import Dict, Any, List
+
+
+class EnhancedMetrics:
+    """Enhanced metrics collection with percentiles"""
+
+    def __init__(self, window_size: int = 100):
+        self.window_size = window_size
+
+        # Timing collections
+        self.video_times = deque(maxlen=window_size)
+        self.audio_times = deque(maxlen=window_size)
+        self.total_times = deque(maxlen=window_size)
+
+        # Component timing (for debugging)
+        self.component_times = {
+            'face_detection': deque(maxlen=window_size),
+            'animation': deque(maxlen=window_size),
+            'voice_processing': deque(maxlen=window_size),
+            'webrtc_encode': deque(maxlen=window_size)
+        }
+
+        # FPS tracking
+        self.frame_timestamps = deque(maxlen=window_size)
+
+        # System monitoring
+        self.last_gpu_check = 0
+        self.gpu_memory_mb = 0
+
+    def record_video_timing(self, elapsed_ms: float):
+        self.video_times.append(elapsed_ms)
+        self.frame_timestamps.append(time.time())
+
+    def record_audio_timing(self, elapsed_ms: float):
+        self.audio_times.append(elapsed_ms)
+
+    def record_component_timing(self, component: str, elapsed_ms: float):
+        if component in self.component_times:
+            self.component_times[component].append(elapsed_ms)
+
+    def record_total_timing(self, elapsed_ms: float):
+        self.total_times.append(elapsed_ms)
+
+    def get_percentiles(self, values: List[float]) -> Dict[str, float]:
+        if not values:
+            return {'p50': 0.0, 'p95': 0.0, 'p99': 0.0}
+        arr = np.array(values)
+        return {
+            'p50': float(np.percentile(arr, 50)),
+            'p95': float(np.percentile(arr, 95)),
+            'p99': float(np.percentile(arr, 99))
+        }
+
+    def get_fps(self) -> float:
+        if len(self.frame_timestamps) < 2:
+            return 0.0
+        timestamps = list(self.frame_timestamps)
+        time_span = timestamps[-1] - timestamps[0]
+        if time_span <= 0:
+            return 0.0
+        return (len(timestamps) - 1) / time_span
+
+    def get_gpu_memory(self) -> float:
+        current_time = time.time()
+        if current_time - self.last_gpu_check > 2.0:
+            try:
+                import torch
+                if torch.cuda.is_available():
+                    self.gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
+                else:
+                    self.gpu_memory_mb = 0
+            except ImportError:
+                self.gpu_memory_mb = 0
+            self.last_gpu_check = current_time
+        return self.gpu_memory_mb
+
+    def get_enhanced_stats(self) -> Dict[str, Any]:
+        video_list = list(self.video_times)
+        audio_list = list(self.audio_times)
+        total_list = list(self.total_times)
+        stats = {
+            "avg_video_latency_ms": float(np.mean(video_list)) if video_list else 0.0,
+            "avg_audio_latency_ms": float(np.mean(audio_list)) if audio_list else 0.0,
+            "video_fps": self.get_fps(),
+            "gpu_memory_used_mb": self.get_gpu_memory(),
+            "video_latency": {
+                "mean": float(np.mean(video_list)) if video_list else 0.0,
+                "std": float(np.std(video_list)) if video_list else 0.0,
+                **self.get_percentiles(video_list)
+            },
+            "audio_latency": {
+                "mean": float(np.mean(audio_list)) if audio_list else 0.0,
+                "std": float(np.std(audio_list)) if audio_list else 0.0,
+                **self.get_percentiles(audio_list)
+            },
+            "total_latency": {
+                "mean": float(np.mean(total_list)) if total_list else 0.0,
+                "std": float(np.std(total_list)) if total_list else 0.0,
+                **self.get_percentiles(total_list)
+            },
+            "components": {}
+        }
+        for component, times in self.component_times.items():
+            times_list = list(times)
+            if times_list:
+                stats["components"][component] = {
+                    "mean": float(np.mean(times_list)),
+                    **self.get_percentiles(times_list)
+                }
+        stats["system"] = {
+            "cpu_percent": psutil.cpu_percent(),
+            "memory_percent": psutil.virtual_memory().percent,
+            "active_connections": 1
+        }
+        return stats
+
+
+_enhanced_metrics = EnhancedMetrics()
+
+
+def get_enhanced_metrics() -> EnhancedMetrics:
+    return _enhanced_metrics
+
+
+def enhance_existing_stats(existing_stats: Dict[str, Any]) -> Dict[str, Any]:
+    enhanced = get_enhanced_metrics().get_enhanced_stats()
+    result = existing_stats.copy()
+    result.update(enhanced)
+    return result
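A minimal usage sketch of the new module (the timings are synthetic, and the "pipeline_loaded" key is just an example of an existing stat surviving the merge):

from enhanced_metrics import get_enhanced_metrics, enhance_existing_stats

m = get_enhanced_metrics()
for ms in (28.0, 31.5, 45.2):          # pretend three video frames took these times
    m.record_video_timing(ms)
m.record_component_timing('animation', 21.7)

stats = enhance_existing_stats({"pipeline_loaded": True})
print(stats["pipeline_loaded"])        # existing keys are preserved
print(stats["video_latency"]["p95"])   # percentile fields appear alongside them

One caveat visible in the code: psutil.cpu_percent() called without an interval reports usage since the previous call, so the first "system" reading is typically 0.0, and get_fps() only becomes meaningful once real frame timestamps accumulate.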
safe_model_integration.py ADDED
@@ -0,0 +1,101 @@
+"""
+Safe Model Integration for Existing Avatar Pipeline
+Incremental SCRFD + LivePortrait loading with feature flags
+Maintains pass-through behavior until models are validated
+"""
+
+import os
+import logging
+from pathlib import Path
+from typing import Optional
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+ENABLE_SCRFD = os.getenv("MIRAGE_ENABLE_SCRFD", "0").lower() in ("1", "true", "yes")
+ENABLE_LIVEPORTRAIT = os.getenv("MIRAGE_ENABLE_LIVEPORTRAIT", "0").lower() in ("1", "true", "yes")
+
+
+class SafeModelLoader:
+    def __init__(self):
+        self.scrfd_loaded = False
+        self.liveportrait_loaded = False
+        self.models_dir = Path("models")
+        self.face_app = None
+        self.appearance_session = None
+        self.motion_session = None
+
+    async def safe_load_scrfd(self) -> bool:
+        if not ENABLE_SCRFD:
+            logger.info("SCRFD disabled by feature flag")
+            return False
+        try:
+            import insightface
+            models_root = self.models_dir / "insightface"
+            models_root.mkdir(parents=True, exist_ok=True)
+            self.face_app = insightface.app.FaceAnalysis(name='buffalo_l', root=str(models_root))
+            ctx_id = 0 if os.getenv("CUDA_VISIBLE_DEVICES") != "-1" else -1
+            self.face_app.prepare(ctx_id=ctx_id)
+            self.scrfd_loaded = True
+            logger.info("SCRFD loaded successfully")
+            return True
+        except Exception as e:
+            logger.warning(f"SCRFD loading failed: {e}")
+            return False
+
+    async def safe_load_liveportrait(self) -> bool:
+        if not ENABLE_LIVEPORTRAIT:
+            logger.info("LivePortrait disabled by feature flag")
+            return False
+        try:
+            import onnxruntime as ort
+            lp_dir = self.models_dir / "liveportrait"
+            appearance_path = lp_dir / "appearance_feature_extractor.onnx"
+            motion_path = lp_dir / "motion_extractor.onnx"
+            if not appearance_path.exists():
+                logger.warning(f"LivePortrait appearance model not found: {appearance_path}")
+                return False
+            providers = []
+            if 'CUDAExecutionProvider' in ort.get_available_providers():
+                providers.append('CUDAExecutionProvider')
+            providers.append('CPUExecutionProvider')
+            self.appearance_session = ort.InferenceSession(str(appearance_path), providers=providers)
+            if motion_path.exists():
+                self.motion_session = ort.InferenceSession(str(motion_path), providers=providers)
+            self.liveportrait_loaded = True
+            logger.info("LivePortrait models loaded successfully")
+            return True
+        except Exception as e:
+            logger.warning(f"LivePortrait loading failed: {e}")
+            return False
+
+    def safe_detect_face(self, frame: np.ndarray) -> Optional[np.ndarray]:
+        if not self.scrfd_loaded or self.face_app is None:
+            return None
+        try:
+            faces = self.face_app.get(frame)
+            if len(faces) > 0:
+                face = max(faces, key=lambda x: x.det_score)
+                return face.bbox.astype(int)
+        except Exception as e:
+            logger.debug(f"Face detection error: {e}")
+        return None
+
+    def safe_animate_face(self, source: np.ndarray, driving: np.ndarray) -> np.ndarray:
+        if not self.liveportrait_loaded or self.appearance_session is None:
+            return source
+        try:
+            import cv2
+            enhanced = cv2.bilateralFilter(source, 5, 20, 20)
+            result = cv2.addWeighted(source, 0.9, enhanced, 0.1, 0)
+            return result
+        except Exception as e:
+            logger.debug(f"Face animation error: {e}")
+            return source
+
+
+_safe_loader = SafeModelLoader()
+
+
+def get_safe_model_loader():
+    return _safe_loader
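A usage sketch, assuming MIRAGE_ENABLE_SCRFD=1 is set and insightface is installed (with the flag off, every call is a cheap no-op and the pipeline falls back to the legacy detector):

import asyncio
import numpy as np
from safe_model_integration import get_safe_model_loader

loader = get_safe_model_loader()
ok = asyncio.run(loader.safe_load_scrfd())       # False when the flag is off or the import fails

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
bbox = loader.safe_detect_face(frame)            # None: not loaded, or no face found
out = loader.safe_animate_face(frame, frame)     # pass-through until LivePortrait is validated

Note that safe_animate_face deliberately does not animate yet: even with the ONNX sessions loaded it only blends a bilateral-filtered copy into the source, which is the "pass-through behavior until models are validated" promised in the docstring.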
webrtc_connection_monitoring.py ADDED
@@ -0,0 +1,32 @@
+"""
+Safe WebRTC Connection Monitoring
+Adds /webrtc/connections endpoint without breaking existing auth
+Compatible with existing single-peer architecture
+"""
+
+from fastapi import APIRouter
+import time
+
+
+def add_connection_monitoring(router: APIRouter, peer_state_getter):
+    @router.get("/connections")
+    async def get_connection_info():
+        try:
+            state = None
+            try:
+                state = peer_state_getter() if callable(peer_state_getter) else None
+            except Exception:
+                state = None
+            if state is None:
+                return {"active_connections": 0, "status": "no_active_connection"}
+            info = {
+                "active_connections": 1,
+                "status": "connected",
+                "connection_state": getattr(state, 'pc', None) and getattr(state.pc, 'connectionState', 'unknown'),
+                "uptime_seconds": time.time() - getattr(state, 'created', time.time()),
+                "ice_connection_state": getattr(state, 'pc', None) and getattr(state.pc, 'iceConnectionState', 'unknown'),
+                "control_channel_ready": getattr(state, 'control_channel_ready', False)
+            }
+            return info
+        except Exception as e:
+            return {"active_connections": 0, "status": "error", "error": str(e)}
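A wiring sketch with a stand-in app (the real wiring, using the pipeline's live _peer_state, is in webrtc_server.py below):

from fastapi import APIRouter, FastAPI
from webrtc_connection_monitoring import add_connection_monitoring

app = FastAPI()
router = APIRouter(prefix="/webrtc", tags=["webrtc"])
add_connection_monitoring(router, lambda: None)  # getter should return the live peer state; none here
app.include_router(router)
# GET /webrtc/connections -> {"active_connections": 0, "status": "no_active_connection"}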
webrtc_server.py CHANGED
@@ -50,6 +50,10 @@ import numpy as np
 import cv2
 
 from avatar_pipeline import get_pipeline
+try:
+    from webrtc_connection_monitoring import add_connection_monitoring  # optional diagnostics
+except Exception:
+    add_connection_monitoring = None
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/webrtc", tags=["webrtc"])
@@ -350,14 +354,30 @@ async def webrtc_offer(offer: Dict[str, Any], x_api_key: Optional[str] = Header(
     def on_datachannel(channel):
         logger.info("Data channel received: %s", channel.label)
         if channel.label == "control":
-            def send_metrics():
-                pipeline = get_pipeline()
-                stats = pipeline.get_performance_stats() if pipeline.loaded else {}
-                payload = json.dumps({"type": "metrics", "payload": stats})
-                try:
-                    channel.send(payload)
-                except Exception:
-                    logger.debug("Failed sending metrics")
+            # Mark control channel readiness on open/close
+            @channel.on("open")
+            def _on_open():
+                try:
+                    if _peer_state is not None:
+                        _peer_state.control_channel_ready = True
+                except Exception:
+                    pass
+
+            @channel.on("close")
+            def _on_close():
+                try:
+                    if _peer_state is not None:
+                        _peer_state.control_channel_ready = False
+                except Exception:
+                    pass
+            def send_metrics():
+                pipeline = get_pipeline()
+                stats = pipeline.get_performance_stats() if pipeline.loaded else {}
+                payload = json.dumps({"type": "metrics", "payload": stats})
+                try:
+                    channel.send(payload)
+                except Exception:
+                    logger.debug("Failed sending metrics")
 
         @channel.on("message")
         def on_message(message):
@@ -449,6 +469,7 @@ async def webrtc_offer(offer: Dict[str, Any], x_api_key: Optional[str] = Header(
     answer = RTCSessionDescription(sdp=patched_sdp, type=answer.type)
     await pc.setLocalDescription(answer)
 
+    global _peer_state
     _peer_state = PeerState(pc=pc, created=time.time())
 
     logger.info("WebRTC answer created")
@@ -479,3 +500,13 @@ async def cleanup_peer(x_api_key: Optional[str] = Header(default=None), x_auth_t
         pass
     _peer_state = None
     return {"status": "closed"}
+
+# Optional: connection monitoring endpoint for diagnostics
+if add_connection_monitoring is not None:
+    try:
+        # Provide a getter to reflect live _peer_state rather than a stale snapshot
+        def _get_peer_state():
+            return _peer_state
+        add_connection_monitoring(router, _get_peer_state)
+    except Exception:
+        pass
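Once deployed, the new endpoint can be probed directly. Host and port are assumptions (the Dockerfile defaults to 7860); the response shape is defined in webrtc_connection_monitoring.py above:

import json
import urllib.request

with urllib.request.urlopen("http://localhost:7860/webrtc/connections") as resp:
    print(json.dumps(json.load(resp), indent=2))
# With no active peer: {"active_connections": 0, "status": "no_active_connection"}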