aniket47 committed
Commit 6e0e6cc · 1 Parent(s): fa13587

Speed optimizations: pre-download models, optimized caching, model CPU offloading

Files changed (3)
  1. Dockerfile +18 -0
  2. models/image_generator.py +35 -10
  3. preload_models.py +55 -0
Dockerfile CHANGED
@@ -2,11 +2,29 @@ FROM python:3.9-slim
 
 WORKDIR /app
 
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements and install Python dependencies
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
+# Copy application code
 COPY . .
 
+# Set environment variables for HuggingFace caching
+ENV HF_HOME=/app/model_cache
+ENV TRANSFORMERS_CACHE=/app/model_cache
+ENV HF_DATASETS_CACHE=/app/model_cache
+
+# Create cache directory with proper permissions
+RUN mkdir -p /app/model_cache && chmod 755 /app/model_cache
+
+# Pre-download models during build time for faster startup
+RUN python preload_models.py
+
 EXPOSE 7860
 
 CMD ["python", "app.py"]
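Because HF_HOME is baked into the image, the weights that preload_models.py fetches during `docker build` are found in /app/model_cache at runtime instead of being re-downloaded. A minimal sanity-check sketch (hypothetical helper, not part of this commit; the only assumption is huggingface_hub's scan_cache_dir):

# check_cache.py — hypothetical sanity check, not part of this commit
import os
from huggingface_hub import scan_cache_dir

# preload_models.py passes cache_dir=$HF_HOME, so scan that same directory
info = scan_cache_dir(os.environ.get("HF_HOME", "/app/model_cache"))
for repo in info.repos:
    print(f"{repo.repo_id}: {repo.size_on_disk / 1e9:.2f} GB on disk")
print(f"total: {info.size_on_disk / 1e9:.2f} GB")

Running this inside the container should list the SDXL (or fallback) and depth-estimation repos; an empty listing would mean startup will re-download everything.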
models/image_generator.py CHANGED
@@ -22,20 +22,27 @@ class ImageGenerator:
         self.temp_dir = tempfile.mkdtemp()
 
     def load_model(self):
-        """Load the Stable Diffusion model"""
+        """Load the Stable Diffusion model with optimized caching"""
         try:
             logger.info(f"🔄 Loading Stability AI model on {self.device}...")
 
             # Use Stability AI's SDXL model for highest quality
             model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
-            # Load pipeline
+            # Optimize caching for faster subsequent loads
+            cache_dir = os.environ.get("HF_HOME", "/tmp/huggingface_cache")
+
+            # Load pipeline with optimized settings
             self.pipeline = StableDiffusionPipeline.from_pretrained(
                 model_id,
                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                 safety_checker=None,  # Disable safety checker for faster inference
                 requires_safety_checker=False,
-                use_safetensors=True
+                use_safetensors=True,
+                cache_dir=cache_dir,
+                resume_download=True,  # Resume interrupted downloads
+                local_files_only=False,  # Allow downloads but prefer cache
+                variant="fp16" if self.device.type == "cuda" else None  # Use fp16 variant for GPU
             )
 
             self.pipeline.to(self.device)
@@ -52,14 +59,21 @@
             except:
                 logger.info("ℹ️ XFormers not available, using default attention")
 
-            # Only enable CPU offloading if CUDA is available but we want to save memory
-            # For pure CPU mode, keep everything on CPU
-            if self.device.type == "cuda":
-                # Enable model offloading to save GPU memory
+            # Enable model CPU offloading for memory optimization
+            if hasattr(self.pipeline, "enable_model_cpu_offload"):
+                try:
+                    self.pipeline.enable_model_cpu_offload()
+                    logger.info("✅ Model CPU offloading enabled for memory optimization")
+                except:
+                    logger.info("ℹ️ CPU offloading not available")
+
+            # Only enable sequential CPU offloading if model CPU offload fails
+            if self.device.type == "cuda" and not hasattr(self.pipeline, "enable_model_cpu_offload"):
                 self.pipeline.enable_sequential_cpu_offload()
+
+            if self.device.type == "cuda":
                 logger.info(f"✅ Stability AI SDXL loaded on GPU: {torch.cuda.get_device_name(0)}")
             else:
-                # For CPU-only mode, don't use offloading
                 logger.info("✅ Stability AI SDXL loaded on CPU")
 
         except Exception as e:
@@ -73,7 +87,9 @@
                 model_id,
                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                 safety_checker=None,
-                requires_safety_checker=False
+                requires_safety_checker=False,
+                cache_dir=cache_dir,
+                resume_download=True
             )
 
             self.pipeline.to(self.device)
@@ -81,8 +97,17 @@
             if hasattr(self.pipeline, "enable_attention_slicing"):
                 self.pipeline.enable_attention_slicing()
 
+            if hasattr(self.pipeline, "enable_xformers_memory_efficient_attention"):
+                try:
+                    self.pipeline.enable_xformers_memory_efficient_attention()
+                except:
+                    pass
+
             if self.device.type == "cuda":
-                self.pipeline.enable_sequential_cpu_offload()
+                if hasattr(self.pipeline, "enable_model_cpu_offload"):
+                    self.pipeline.enable_model_cpu_offload()
+                else:
+                    self.pipeline.enable_sequential_cpu_offload()
                 logger.info(f"✅ Fallback SD v1.5 loaded on GPU: {torch.cuda.get_device_name(0)}")
             else:
                 logger.info("✅ Fallback SD v1.5 loaded on CPU")
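For context on the two offload calls this diff switches between: diffusers' enable_model_cpu_offload moves one whole sub-model (text encoder, UNet, VAE) to the GPU at a time, giving moderate memory savings with a small speed cost, while enable_sequential_cpu_offload streams weights at the submodule level for maximal savings but much slower inference. Both require the accelerate package. A minimal standalone sketch (illustrative only, not this repository's code; model ID and prompt are placeholders):

# offload_demo.py — illustrative sketch, not part of this commit
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Option 1: per-component offload (requires accelerate).
# Note: do NOT also call pipe.to("cuda"); the hooks manage device placement.
pipe.enable_model_cpu_offload()

# Option 2, mutually exclusive with option 1: per-layer offload,
# lowest VRAM use but a large speed penalty.
# pipe.enable_sequential_cpu_offload()

image = pipe("a lighthouse at dusk", num_inference_steps=20).images[0]
image.save("out.png")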
preload_models.py ADDED
@@ -0,0 +1,55 @@
+"""
+Pre-download models for faster startup
+"""
+import os
+import sys
+
+def preload_models():
+    """Pre-download all required models"""
+    try:
+        print("🔄 Pre-downloading Stability AI SDXL model...")
+
+        from diffusers import StableDiffusionPipeline
+        import torch
+
+        # Set cache directory
+        cache_dir = os.environ.get('HF_HOME', '/app/model_cache')
+
+        # Download SDXL model
+        try:
+            pipeline = StableDiffusionPipeline.from_pretrained(
+                'stabilityai/stable-diffusion-xl-base-1.0',
+                torch_dtype=torch.float32,
+                safety_checker=None,
+                requires_safety_checker=False,
+                cache_dir=cache_dir
+            )
+            print("✅ SDXL model downloaded successfully")
+        except Exception as e:
+            print(f"⚠️ SDXL download failed, downloading fallback: {e}")
+            # Download fallback model
+            pipeline = StableDiffusionPipeline.from_pretrained(
+                'runwayml/stable-diffusion-v1-5',
+                torch_dtype=torch.float32,
+                safety_checker=None,
+                requires_safety_checker=False,
+                cache_dir=cache_dir
+            )
+            print("✅ SD v1.5 fallback model downloaded successfully")
+
+        # Also pre-download depth estimation model
+        print("🔄 Pre-downloading depth estimation model...")
+        from transformers import DPTImageProcessor, DPTForDepthEstimation
+
+        DPTImageProcessor.from_pretrained('Intel/dpt-beit-large-512', cache_dir=cache_dir)
+        DPTForDepthEstimation.from_pretrained('Intel/dpt-beit-large-512', cache_dir=cache_dir)
+        print("✅ Depth estimation model downloaded successfully")
+
+        print("🎉 All models pre-loaded successfully!")
+
+    except Exception as e:
+        print(f"❌ Error pre-loading models: {e}")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    preload_models()
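Instantiating the full pipeline just to warm the cache works, but it also constructs every sub-model in RAM during `docker build`. A lighter-weight alternative sketch using huggingface_hub's snapshot_download (hypothetical, not what this commit does) fetches the same files into the cache without loading any weights:

# preload_via_snapshot.py — hypothetical alternative, not part of this commit
import os
from huggingface_hub import snapshot_download

cache_dir = os.environ.get("HF_HOME", "/app/model_cache")

# Download repo files into the cache without building the pipeline,
# keeping build-time memory low; load_model() later hits the warm cache.
for repo_id in (
    "stabilityai/stable-diffusion-xl-base-1.0",
    "Intel/dpt-beit-large-512",
):
    snapshot_download(repo_id, cache_dir=cache_dir)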