MacBook pro committed
Commit 6dfabd9 · 1 parent: 062564b

WebRTC: add /webrtc/connections; Metrics: integrate enhanced; Docker: enable SCRFD via env; Safe model loader wiring

Files changed:
- Dockerfile                          +5   -0
- avatar_pipeline.py                  +73  -18
- enhanced_metrics.py                 +139 -0
- safe_model_integration.py           +101 -0
- webrtc_connection_monitoring.py     +32  -0
- webrtc_server.py                    +39  -8
Dockerfile
CHANGED
@@ -56,6 +56,11 @@ EXPOSE 7860
 # Default port (Hugging Face Spaces injects PORT env; fallback to 7860)
 ENV PORT=7860
 
+# Feature flags for safe model integration (can be overridden in Space settings)
+# Enable SCRFD face detection by default for better reliability; keep LivePortrait safe path off initially.
+ENV MIRAGE_ENABLE_SCRFD=1 \
+    MIRAGE_ENABLE_LIVEPORTRAIT=0
+
 # Health check
 HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
     CMD sh -c 'curl -fsS http://localhost:${PORT:-7860}/health || exit 1'
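For local verification outside the container, a minimal Python equivalent of the HEALTHCHECK probe above (illustrative script, not part of this commit; assumes the server is running with the default PORT):

# health_probe.py - illustrative stand-in for the Dockerfile HEALTHCHECK
import os
import sys
import urllib.request

port = os.getenv("PORT", "7860")
try:
    with urllib.request.urlopen(f"http://localhost:{port}/health", timeout=10) as resp:
        sys.exit(0 if resp.status == 200 else 1)  # exit 0 = healthy
except Exception:
    sys.exit(1)  # non-zero exit marks the container unhealthy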
avatar_pipeline.py
CHANGED
@@ -16,6 +16,8 @@ import asyncio
 from collections import deque
 import traceback
 from virtual_camera import get_virtual_camera_manager
+from enhanced_metrics import get_enhanced_metrics, enhance_existing_stats
+from safe_model_integration import get_safe_model_loader
 from realtime_optimizer import get_realtime_optimizer
 
 # Setup logging
@@ -256,6 +258,7 @@ class RealTimeAvatarPipeline:
         self.face_detector = FaceDetector(self.config)
         self.liveportrait = LivePortraitModel(self.config)
         self.rvc = RVCVoiceConverter(self.config)
+        self.safe_loader = get_safe_model_loader()
 
         # Performance optimization
         self.optimizer = get_realtime_optimizer()
@@ -272,6 +275,7 @@ class RealTimeAvatarPipeline:
         # Performance tracking
         self.frame_times = deque(maxlen=100)
         self.audio_times = deque(maxlen=100)
+        self._metrics = get_enhanced_metrics()
 
         # Processing locks
         self.video_lock = threading.Lock()
@@ -286,19 +290,28 @@ class RealTimeAvatarPipeline:
         """Initialize all models"""
         logger.info("Initializing real-time avatar pipeline...")
 
+        # Face detector load may be synchronous; run in executor to avoid blocking loop
+        loop = asyncio.get_running_loop()
+        try:
+            fd_ok = await loop.run_in_executor(None, self.face_detector.load_model)
+        except Exception as e:
+            logger.error(f"Face detector load failed: {e}")
+            fd_ok = False
+
+        # Load async models and optional safe models in parallel
+        lp_task = self.liveportrait.load_models()
+        rvc_task = self.rvc.load_model()
+        scrfd_task = self.safe_loader.safe_load_scrfd()
+        lp_safe_task = self.safe_loader.safe_load_liveportrait()
+
+        results = await asyncio.gather(lp_task, rvc_task, scrfd_task, lp_safe_task, return_exceptions=True)
+        # Normalize booleans from tasks
+        async_ok = sum(1 for r in results if r is True)
+        success_count = async_ok + (1 if fd_ok else 0)
+        logger.info(f"Loaded components - FaceDetector: {fd_ok}, LivePortrait: {results[0]}, RVC: {results[1]}, SCRFD(safe): {results[2]}, LivePortrait(safe): {results[3]}")
 
+        if (fd_ok and (results[0] is True or results[3] is True)) or (fd_ok and results[1] is True):
+            # Require face detector + (any of liveportrait variants or RVC) to proceed
             self.loaded = True
             logger.info("Pipeline initialization successful")
             return True
@@ -310,7 +323,18 @@ class RealTimeAvatarPipeline:
         """Set reference frame for avatar"""
         try:
             # Detect face in reference frame
-            bbox, confidence = self.face_detector.detect_face(frame, 0)
+            bbox = None
+            confidence = 0.0
+            # Prefer safe SCRFD if available
+            try:
+                sb = self.safe_loader.safe_detect_face(frame)
+                if sb is not None:
+                    bbox = sb
+                    confidence = 1.0  # safe path doesn't provide score; assume strong if detected
+            except Exception:
+                pass
+            if bbox is None:
+                bbox, confidence = self.face_detector.detect_face(frame, 0)
 
             if bbox is not None and confidence >= self.config.face_detection_threshold:
                 self.reference_frame = frame.copy()
@@ -349,16 +373,38 @@ class RealTimeAvatarPipeline:
             return frame_resized
 
         # Detect face in current frame
-        bbox, confidence = self.face_detector.detect_face(frame_resized, frame_idx)
+        t0 = time.time()
+        bbox = None
+        confidence = 0.0
+        if self.safe_loader.scrfd_loaded:
+            try:
+                sb = self.safe_loader.safe_detect_face(frame_resized)
+                if sb is not None:
+                    bbox = sb
+                    confidence = 1.0
+            except Exception:
+                bbox = None
+        if bbox is None:
+            bbox, confidence = self.face_detector.detect_face(frame_resized, frame_idx)
+        self._metrics.record_component_timing('face_detection', (time.time() - t0) * 1000.0)
 
         if self.reference_frame is None:
            # No reference, keep camera as-is for stability until reference set
            result_frame = frame_resized
        elif bbox is not None and confidence >= self.config.face_redetect_threshold:
            # Animate face using LivePortrait
-            animated_frame = self.liveportrait.animate_face(
-                self.reference_frame, frame_resized
-            )
+            t1 = time.time()
+            if self.liveportrait.loaded:
+                animated_frame = self.liveportrait.animate_face(
+                    self.reference_frame, frame_resized
+                )
+            elif self.safe_loader.liveportrait_loaded:
+                animated_frame = self.safe_loader.safe_animate_face(
+                    self.reference_frame, frame_resized
+                )
+            else:
+                animated_frame = frame_resized
+            self._metrics.record_component_timing('animation', (time.time() - t1) * 1000.0)
 
            # Apply any post-processing with current quality settings
            result_frame = self._post_process_frame(animated_frame, opt_settings)
@@ -373,6 +419,9 @@ class RealTimeAvatarPipeline:
         # Record processing time
         processing_time = (time.time() - start_time) * 1000
         self.frame_times.append(processing_time)
+        self._metrics.record_video_timing(processing_time)
+        self._metrics.record_component_timing('face_detection', 0.0)  # placeholder hooks
+        self._metrics.record_component_timing('animation', 0.0)
         self.optimizer.latency_optimizer.record_latency("video_total", processing_time)
 
         return result_frame
@@ -400,6 +449,9 @@ class RealTimeAvatarPipeline:
         # Record processing time
         processing_time = (time.time() - start_time) * 1000
         self.audio_times.append(processing_time)
+        self._metrics.record_audio_timing(processing_time)
+        self._metrics.record_total_timing(processing_time)
+        self._metrics.record_component_timing('voice_processing', processing_time)
         self.optimizer.latency_optimizer.record_latency("audio_total", processing_time)
 
         return converted_audio
@@ -460,7 +512,10 @@ class RealTimeAvatarPipeline:
             }
 
             # Merge with optimizer stats
+            merged = {**pipeline_stats, "optimization": opt_stats}
+            # Enhance with additional percentiles/system metrics
+            merged = enhance_existing_stats(merged)
+            return merged
 
         except Exception as e:
             logger.error(f"Stats error: {e}")
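The initialization above leans on asyncio.gather(..., return_exceptions=True): a failing loader yields an exception object in the results list instead of aborting the batch, and only a literal True counts as success. A self-contained sketch of the pattern (loader names are illustrative, not the pipeline's real methods):

import asyncio

async def loader_ok():
    return True  # stands in for a loader that succeeds

async def loader_broken():
    raise RuntimeError("model file missing")  # stands in for a loader that raises

async def main():
    results = await asyncio.gather(loader_ok(), loader_broken(), return_exceptions=True)
    ok = sum(1 for r in results if r is True)  # exceptions and False both count as failure
    print(results)   # [True, RuntimeError('model file missing')]
    print(f"{ok}/{len(results)} loaders succeeded")

asyncio.run(main())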
enhanced_metrics.py
ADDED
@@ -0,0 +1,139 @@
"""
Enhanced Performance Metrics for Existing Pipeline
Adds p50/p95/p99 latency tracking and GPU monitoring
Drop-in compatible with existing get_performance_stats()
"""

import time
import psutil
import numpy as np
from collections import deque
from typing import Dict, Any, List


class EnhancedMetrics:
    """Enhanced metrics collection with percentiles"""

    def __init__(self, window_size: int = 100):
        self.window_size = window_size

        # Timing collections
        self.video_times = deque(maxlen=window_size)
        self.audio_times = deque(maxlen=window_size)
        self.total_times = deque(maxlen=window_size)

        # Component timing (for debugging)
        self.component_times = {
            'face_detection': deque(maxlen=window_size),
            'animation': deque(maxlen=window_size),
            'voice_processing': deque(maxlen=window_size),
            'webrtc_encode': deque(maxlen=window_size)
        }

        # FPS tracking
        self.frame_timestamps = deque(maxlen=window_size)

        # System monitoring
        self.last_gpu_check = 0
        self.gpu_memory_mb = 0

    def record_video_timing(self, elapsed_ms: float):
        self.video_times.append(elapsed_ms)
        self.frame_timestamps.append(time.time())

    def record_audio_timing(self, elapsed_ms: float):
        self.audio_times.append(elapsed_ms)

    def record_component_timing(self, component: str, elapsed_ms: float):
        if component in self.component_times:
            self.component_times[component].append(elapsed_ms)

    def record_total_timing(self, elapsed_ms: float):
        self.total_times.append(elapsed_ms)

    def get_percentiles(self, values: List[float]) -> Dict[str, float]:
        if not values:
            return {'p50': 0.0, 'p95': 0.0, 'p99': 0.0}
        arr = np.array(values)
        return {
            'p50': float(np.percentile(arr, 50)),
            'p95': float(np.percentile(arr, 95)),
            'p99': float(np.percentile(arr, 99))
        }

    def get_fps(self) -> float:
        if len(self.frame_timestamps) < 2:
            return 0.0
        timestamps = list(self.frame_timestamps)
        time_span = timestamps[-1] - timestamps[0]
        if time_span <= 0:
            return 0.0
        return (len(timestamps) - 1) / time_span

    def get_gpu_memory(self) -> float:
        current_time = time.time()
        if current_time - self.last_gpu_check > 2.0:
            try:
                import torch
                if torch.cuda.is_available():
                    self.gpu_memory_mb = torch.cuda.memory_allocated() / (1024 * 1024)
                else:
                    self.gpu_memory_mb = 0
            except ImportError:
                self.gpu_memory_mb = 0
            self.last_gpu_check = current_time
        return self.gpu_memory_mb

    def get_enhanced_stats(self) -> Dict[str, Any]:
        video_list = list(self.video_times)
        audio_list = list(self.audio_times)
        total_list = list(self.total_times)
        stats = {
            "avg_video_latency_ms": float(np.mean(video_list)) if video_list else 0.0,
            "avg_audio_latency_ms": float(np.mean(audio_list)) if audio_list else 0.0,
            "video_fps": self.get_fps(),
            "gpu_memory_used_mb": self.get_gpu_memory(),
            "video_latency": {
                "mean": float(np.mean(video_list)) if video_list else 0.0,
                "std": float(np.std(video_list)) if video_list else 0.0,
                **self.get_percentiles(video_list)
            },
            "audio_latency": {
                "mean": float(np.mean(audio_list)) if audio_list else 0.0,
                "std": float(np.std(audio_list)) if audio_list else 0.0,
                **self.get_percentiles(audio_list)
            },
            "total_latency": {
                "mean": float(np.mean(total_list)) if total_list else 0.0,
                "std": float(np.std(total_list)) if total_list else 0.0,
                **self.get_percentiles(total_list)
            },
            "components": {}
        }
        for component, times in self.component_times.items():
            times_list = list(times)
            if times_list:
                stats["components"][component] = {
                    "mean": float(np.mean(times_list)),
                    **self.get_percentiles(times_list)
                }
        stats["system"] = {
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent,
            "active_connections": 1
        }
        return stats


_enhanced_metrics = EnhancedMetrics()


def get_enhanced_metrics() -> EnhancedMetrics:
    return _enhanced_metrics


def enhance_existing_stats(existing_stats: Dict[str, Any]) -> Dict[str, Any]:
    enhanced = get_enhanced_metrics().get_enhanced_stats()
    result = existing_stats.copy()
    result.update(enhanced)
    return result
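A minimal usage sketch for the module above (timings are synthetic; numpy and psutil must be installed):

import time
from enhanced_metrics import get_enhanced_metrics, enhance_existing_stats

metrics = get_enhanced_metrics()
for _ in range(30):
    start = time.time()
    time.sleep(0.01)  # stand-in for per-frame processing work
    metrics.record_video_timing((time.time() - start) * 1000.0)

stats = enhance_existing_stats({"pipeline": "demo"})  # merges into an existing stats dict
print(stats["video_latency"])        # mean/std plus p50/p95/p99
print(round(stats["video_fps"], 1))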
safe_model_integration.py
ADDED
@@ -0,0 +1,101 @@
"""
Safe Model Integration for Existing Avatar Pipeline
Incremental SCRFD + LivePortrait loading with feature flags
Maintains pass-through behavior until models are validated
"""

import os
import logging
from pathlib import Path
from typing import Optional
import numpy as np

logger = logging.getLogger(__name__)

ENABLE_SCRFD = os.getenv("MIRAGE_ENABLE_SCRFD", "0").lower() in ("1", "true", "yes")
ENABLE_LIVEPORTRAIT = os.getenv("MIRAGE_ENABLE_LIVEPORTRAIT", "0").lower() in ("1", "true", "yes")


class SafeModelLoader:
    def __init__(self):
        self.scrfd_loaded = False
        self.liveportrait_loaded = False
        self.models_dir = Path("models")
        self.face_app = None
        self.appearance_session = None
        self.motion_session = None

    async def safe_load_scrfd(self) -> bool:
        if not ENABLE_SCRFD:
            logger.info("SCRFD disabled by feature flag")
            return False
        try:
            import insightface
            models_root = self.models_dir / "insightface"
            models_root.mkdir(parents=True, exist_ok=True)
            self.face_app = insightface.app.FaceAnalysis(name='buffalo_l', root=str(models_root))
            ctx_id = 0 if os.getenv("CUDA_VISIBLE_DEVICES") != "-1" else -1
            self.face_app.prepare(ctx_id=ctx_id)
            self.scrfd_loaded = True
            logger.info("SCRFD loaded successfully")
            return True
        except Exception as e:
            logger.warning(f"SCRFD loading failed: {e}")
            return False

    async def safe_load_liveportrait(self) -> bool:
        if not ENABLE_LIVEPORTRAIT:
            logger.info("LivePortrait disabled by feature flag")
            return False
        try:
            import onnxruntime as ort
            lp_dir = self.models_dir / "liveportrait"
            appearance_path = lp_dir / "appearance_feature_extractor.onnx"
            motion_path = lp_dir / "motion_extractor.onnx"
            if not appearance_path.exists():
                logger.warning(f"LivePortrait appearance model not found: {appearance_path}")
                return False
            providers = []
            if 'CUDAExecutionProvider' in ort.get_available_providers():
                providers.append('CUDAExecutionProvider')
            providers.append('CPUExecutionProvider')
            self.appearance_session = ort.InferenceSession(str(appearance_path), providers=providers)
            if motion_path.exists():
                self.motion_session = ort.InferenceSession(str(motion_path), providers=providers)
            self.liveportrait_loaded = True
            logger.info("LivePortrait models loaded successfully")
            return True
        except Exception as e:
            logger.warning(f"LivePortrait loading failed: {e}")
            return False

    def safe_detect_face(self, frame: np.ndarray) -> Optional[np.ndarray]:
        if not self.scrfd_loaded or self.face_app is None:
            return None
        try:
            faces = self.face_app.get(frame)
            if len(faces) > 0:
                face = max(faces, key=lambda x: x.det_score)
                return face.bbox.astype(int)
        except Exception as e:
            logger.debug(f"Face detection error: {e}")
        return None

    def safe_animate_face(self, source: np.ndarray, driving: np.ndarray) -> np.ndarray:
        if not self.liveportrait_loaded or self.appearance_session is None:
            return source
        try:
            import cv2
            enhanced = cv2.bilateralFilter(source, 5, 20, 20)
            result = cv2.addWeighted(source, 0.9, enhanced, 0.1, 0)
            return result
        except Exception as e:
            logger.debug(f"Face animation error: {e}")
            return source


_safe_loader = SafeModelLoader()


def get_safe_model_loader():
    return _safe_loader
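With both MIRAGE_* flags unset in the environment before import, the loaders return False and the safe_* calls fall through to pass-through behavior; a quick sketch:

import asyncio
import numpy as np
from safe_model_integration import get_safe_model_loader

async def main():
    loader = get_safe_model_loader()
    scrfd_ok = await loader.safe_load_scrfd()     # False unless MIRAGE_ENABLE_SCRFD was set at import
    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    bbox = loader.safe_detect_face(frame)         # None while SCRFD is not loaded
    out = loader.safe_animate_face(frame, frame)  # returns source unchanged while not loaded
    print(scrfd_ok, bbox, out.shape)

asyncio.run(main())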
webrtc_connection_monitoring.py
ADDED
@@ -0,0 +1,32 @@
"""
Safe WebRTC Connection Monitoring
Adds /webrtc/connections endpoint without breaking existing auth
Compatible with existing single-peer architecture
"""

from fastapi import APIRouter
import time


def add_connection_monitoring(router: APIRouter, peer_state_getter):
    @router.get("/connections")
    async def get_connection_info():
        try:
            state = None
            try:
                state = peer_state_getter() if callable(peer_state_getter) else None
            except Exception:
                state = None
            if state is None:
                return {"active_connections": 0, "status": "no_active_connection"}
            info = {
                "active_connections": 1,
                "status": "connected",
                "connection_state": getattr(state, 'pc', None) and getattr(state.pc, 'connectionState', 'unknown'),
                "uptime_seconds": time.time() - getattr(state, 'created', time.time()),
                "ice_connection_state": getattr(state, 'pc', None) and getattr(state.pc, 'iceConnectionState', 'unknown'),
                "control_channel_ready": getattr(state, 'control_channel_ready', False)
            }
            return info
        except Exception as e:
            return {"active_connections": 0, "status": "error", "error": str(e)}
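A minimal wiring sketch (the app and peer-state names here are illustrative; webrtc_server.py below does the real wiring against its module-level _peer_state):

from fastapi import APIRouter, FastAPI
from webrtc_connection_monitoring import add_connection_monitoring

app = FastAPI()
router = APIRouter(prefix="/webrtc")

_peer_state = None  # illustrative stand-in for the server's live peer state

def _get_peer_state():
    return _peer_state  # a getter keeps the endpoint reading live state, not a snapshot

add_connection_monitoring(router, _get_peer_state)
app.include_router(router)
# GET /webrtc/connections -> {"active_connections": 0, "status": "no_active_connection"}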
webrtc_server.py
CHANGED
@@ -50,6 +50,10 @@ import numpy as np
 import cv2
 
 from avatar_pipeline import get_pipeline
+try:
+    from webrtc_connection_monitoring import add_connection_monitoring  # optional diagnostics
+except Exception:
+    add_connection_monitoring = None
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/webrtc", tags=["webrtc"])
@@ -350,14 +354,30 @@ async def webrtc_offer(offer: Dict[str, Any], x_api_key: Optional[str] = Header(
     def on_datachannel(channel):
         logger.info("Data channel received: %s", channel.label)
         if channel.label == "control":
+            # Mark control channel readiness on open/close
+            @channel.on("open")
+            def _on_open():
+                try:
+                    if _peer_state is not None:
+                        _peer_state.control_channel_ready = True
+                except Exception:
+                    pass
+
+            @channel.on("close")
+            def _on_close():
+                try:
+                    if _peer_state is not None:
+                        _peer_state.control_channel_ready = False
+                except Exception:
+                    pass
+            def send_metrics():
+                pipeline = get_pipeline()
+                stats = pipeline.get_performance_stats() if pipeline.loaded else {}
+                payload = json.dumps({"type": "metrics", "payload": stats})
+                try:
+                    channel.send(payload)
+                except Exception:
+                    logger.debug("Failed sending metrics")
 
             @channel.on("message")
             def on_message(message):
@@ -449,6 +469,7 @@ async def webrtc_offer(offer: Dict[str, Any], x_api_key: Optional[str] = Header(
     answer = RTCSessionDescription(sdp=patched_sdp, type=answer.type)
     await pc.setLocalDescription(answer)
 
+    global _peer_state
     _peer_state = PeerState(pc=pc, created=time.time())
 
     logger.info("WebRTC answer created")
@@ -479,3 +500,13 @@ async def cleanup_peer(x_api_key: Optional[str] = Header(default=None), x_auth_t
         pass
     _peer_state = None
     return {"status": "closed"}
+
+
+# Optional: connection monitoring endpoint for diagnostics
+if add_connection_monitoring is not None:
+    try:
+        # Provide a getter to reflect live _peer_state rather than a stale snapshot
+        def _get_peer_state():
+            return _peer_state
+        add_connection_monitoring(router, _get_peer_state)
+    except Exception:
+        pass
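Once deployed, the endpoint can be exercised directly; a sketch using the requests library against a local instance on the default port (the route itself takes no auth headers):

import requests

resp = requests.get("http://localhost:7860/webrtc/connections", timeout=5)
print(resp.json())
# idle:      {"active_connections": 0, "status": "no_active_connection"}
# connected: {"active_connections": 1, "status": "connected", ...}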