"""
Pre-download models for faster startup.

Fetches the text-to-image diffusion weights (SDXL, or SD v1.5 as a fallback)
and the DPT depth-estimation weights into a local cache directory. Exits with
status 1 if any required download fails.
"""
import os
import sys
from typing import Optional

# Hugging Face Hub repo ids fetched by this script. The application must load
# the same ids from the same cache directory for the pre-download to be reused.
SDXL_MODEL_ID = 'stabilityai/stable-diffusion-xl-base-1.0'
# NOTE(review): the runwayml repo has reportedly been removed from the Hub
# (a mirror is published as 'stable-diffusion-v1-5/stable-diffusion-v1-5').
# If so, this fallback cannot succeed — confirm, and change it together with
# the id the application loads.
FALLBACK_MODEL_ID = 'runwayml/stable-diffusion-v1-5'
DEPTH_MODEL_ID = 'Intel/dpt-beit-large-512'

# Cache directory used when neither a cache_dir argument nor HF_HOME is given.
DEFAULT_CACHE_DIR = '/app/model_cache'


def _download_diffusion_model(cache_dir):
    """Fetch SDXL into *cache_dir*, falling back to SD v1.5 if SDXL fails.

    Raises whatever the fallback ``from_pretrained`` call raises if the
    fallback download fails too.
    """
    from diffusers import StableDiffusionPipeline
    import torch

    try:
        # Imported here so a diffusers release without the XL class still
        # reaches the SD v1.5 fallback below instead of aborting the script.
        from diffusers import StableDiffusionXLPipeline

        # SDXL checkpoints must be loaded with the XL pipeline class (it also
        # expects the second text encoder/tokenizer). SDXL has no safety
        # checker component, so the safety-checker kwargs are not passed.
        # The returned pipeline is deliberately not kept: only the files left
        # in cache_dir matter, and dropping the reference lets the weights be
        # freed before the next model is loaded.
        StableDiffusionXLPipeline.from_pretrained(
            SDXL_MODEL_ID,
            torch_dtype=torch.float32,
            cache_dir=cache_dir,
        )
        print("✅ SDXL model downloaded successfully")
    except Exception as e:
        print(f"⚠️ SDXL download failed, downloading fallback: {e}",
              file=sys.stderr)
        StableDiffusionPipeline.from_pretrained(
            FALLBACK_MODEL_ID,
            torch_dtype=torch.float32,
            safety_checker=None,
            requires_safety_checker=False,
            cache_dir=cache_dir,
        )
        print("✅ SD v1.5 fallback model downloaded successfully")


def _download_depth_model(cache_dir):
    """Fetch the DPT depth-estimation processor and model into *cache_dir*."""
    from transformers import DPTForDepthEstimation, DPTImageProcessor

    DPTImageProcessor.from_pretrained(DEPTH_MODEL_ID, cache_dir=cache_dir)
    DPTForDepthEstimation.from_pretrained(DEPTH_MODEL_ID, cache_dir=cache_dir)


def preload_models(cache_dir: Optional[str] = None):
    """Pre-download all required models.

    Args:
        cache_dir: Directory to download into. When omitted, the ``HF_HOME``
            environment variable is used, or ``/app/model_cache`` if unset.

    A failed SDXL download does not abort the run; the SD v1.5 fallback is
    fetched instead. Any other failure prints an error to stderr and exits
    the process with status 1 via ``sys.exit``.
    """
    try:
        print("🔄 Pre-downloading Stability AI SDXL model...")
        if cache_dir is None:
            # NOTE(review): HF_HOME is the *root* of the Hugging Face cache;
            # the libraries' implicit hub cache is "$HF_HOME/hub". Passing
            # HF_HOME as cache_dir stores files directly under it, so the
            # application must pass the same explicit cache_dir to find them.
            cache_dir = os.environ.get('HF_HOME', DEFAULT_CACHE_DIR)
        _download_diffusion_model(cache_dir)

        print("🔄 Pre-downloading depth estimation model...")
        _download_depth_model(cache_dir)
        print("✅ Depth estimation model downloaded successfully")

        print("🎉 All models pre-loaded successfully!")
    except Exception as e:
        # Top-level boundary: report and exit non-zero so whatever invoked
        # this script sees the failure.
        print(f"❌ Error pre-loading models: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    preload_models()