aniket47 committed
Commit fa13587
Parent: a09688b

Upgrade to Stability AI SDXL model for superior image quality

Files changed (2):
  1. models/image_generator.py +64 -19
  2. requirements.txt +2 -1
models/image_generator.py CHANGED
@@ -24,17 +24,18 @@ class ImageGenerator:
     def load_model(self):
         """Load the Stable Diffusion model"""
         try:
-            logger.info(f"πŸ”„ Loading Stable Diffusion model on {self.device}...")
+            logger.info(f"πŸ”„ Loading Stability AI model on {self.device}...")
 
-            # Use a smaller, faster model for better performance on free tier
-            model_id = "runwayml/stable-diffusion-v1-5"
+            # Use Stability AI's SDXL model for highest quality
+            model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
             # Load pipeline
             self.pipeline = StableDiffusionPipeline.from_pretrained(
                 model_id,
                 torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                 safety_checker=None,  # Disable safety checker for faster inference
-                requires_safety_checker=False
+                requires_safety_checker=False,
+                use_safetensors=True
             )
 
             self.pipeline.to(self.device)
@@ -43,39 +44,78 @@ class ImageGenerator:
             if hasattr(self.pipeline, "enable_attention_slicing"):
                 self.pipeline.enable_attention_slicing()
 
+            # Enable xformers for better performance if available
+            if hasattr(self.pipeline, "enable_xformers_memory_efficient_attention"):
+                try:
+                    self.pipeline.enable_xformers_memory_efficient_attention()
+                    logger.info("βœ… XFormers memory efficient attention enabled")
+                except:
+                    logger.info("ℹ️ XFormers not available, using default attention")
+
             # Only enable CPU offloading if CUDA is available but we want to save memory
             # For pure CPU mode, keep everything on CPU
             if self.device.type == "cuda":
                 # Enable model offloading to save GPU memory
                 self.pipeline.enable_sequential_cpu_offload()
-                logger.info(f"βœ… Stable Diffusion loaded on GPU: {torch.cuda.get_device_name(0)}")
+                logger.info(f"βœ… Stability AI SDXL loaded on GPU: {torch.cuda.get_device_name(0)}")
             else:
                 # For CPU-only mode, don't use offloading
-                logger.info("βœ… Stable Diffusion loaded on CPU")
+                logger.info("βœ… Stability AI SDXL loaded on CPU")
 
         except Exception as e:
-            logger.error(f"❌ Failed to load Stable Diffusion model: {str(e)}")
-            raise e
+            logger.error(f"❌ Failed to load Stability AI model: {str(e)}")
+            # Fallback to standard SD 1.5 if SDXL fails
+            logger.info("πŸ”„ Falling back to Stable Diffusion v1.5...")
+            try:
+                model_id = "runwayml/stable-diffusion-v1-5"
+
+                self.pipeline = StableDiffusionPipeline.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
+                    safety_checker=None,
+                    requires_safety_checker=False
+                )
+
+                self.pipeline.to(self.device)
+
+                if hasattr(self.pipeline, "enable_attention_slicing"):
+                    self.pipeline.enable_attention_slicing()
+
+                if self.device.type == "cuda":
+                    self.pipeline.enable_sequential_cpu_offload()
+                    logger.info(f"βœ… Fallback SD v1.5 loaded on GPU: {torch.cuda.get_device_name(0)}")
+                else:
+                    logger.info("βœ… Fallback SD v1.5 loaded on CPU")
+
+            except Exception as fallback_error:
+                logger.error(f"❌ Fallback model also failed: {str(fallback_error)}")
+                raise fallback_error
 
     def generate_image(self, prompt: str, negative_prompt: str = None) -> dict:
         """Generate image from text prompt"""
         try:
             logger.info(f"🎨 Generating image for prompt: '{prompt}'")
 
-            # Default negative prompt for better quality
+            # Enhanced negative prompt for Stability AI models
             if negative_prompt is None:
-                negative_prompt = "blurry, low quality, distorted, deformed, ugly, bad anatomy, worst quality, low res"
+                negative_prompt = "blurry, low quality, distorted, deformed, ugly, bad anatomy, worst quality, low res, jpeg artifacts, watermark, signature"
 
-            # Enhanced prompt for 3D-suitable images
-            enhanced_prompt = f"{prompt}, high quality, detailed, clear lighting, suitable for 3D modeling, photorealistic"
+            # Enhanced prompt for 3D-suitable images with Stability AI style
+            enhanced_prompt = f"{prompt}, masterpiece, best quality, highly detailed, sharp focus, professional photography, suitable for 3D modeling, photorealistic, 8k uhd"
 
-            # Generation parameters - optimized for quality
+            # Generation parameters - optimized for Stability AI models
             generator = torch.Generator(device=self.device).manual_seed(42)  # Fixed seed for consistency
 
-            # Higher quality parameters - even for CPU
-            num_steps = 25 if self.device.type == "cpu" else 50
-            width = 512  # Full resolution for better quality
-            height = 512
+            # SDXL optimized parameters
+            num_steps = 30 if self.device.type == "cpu" else 50  # SDXL works best with more steps
+            width = 1024 if self.device.type == "cuda" else 512  # SDXL native resolution is 1024x1024
+            height = 1024 if self.device.type == "cuda" else 512
+            guidance_scale = 7.0  # SDXL works best with lower guidance scale
+
+            # For CPU, use smaller resolution to manage memory
+            if self.device.type == "cpu":
+                width, height = 512, 512
+                num_steps = 25  # Fewer steps for CPU but still good quality
 
             logger.info(f"πŸ–ΌοΈ Generating {width}x{height} image with {num_steps} steps on {self.device}")
 
@@ -85,7 +125,7 @@ class ImageGenerator:
                 prompt=enhanced_prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_steps,
-                guidance_scale=8.5,  # Higher guidance for better quality
+                guidance_scale=guidance_scale,
                 width=width,
                 height=height,
                 generator=generator
@@ -93,6 +133,11 @@ class ImageGenerator:
 
             image = result.images[0]
 
+            # Resize to 512x512 for consistency if generated at higher resolution
+            if width > 512 or height > 512:
+                image = image.resize((512, 512), Image.Resampling.LANCZOS)
+                logger.info("πŸ”„ Resized image from 1024x1024 to 512x512 for processing")
+
             # Convert to bytes for storage
             img_bytes = io.BytesIO()
             image.save(img_bytes, format='PNG', quality=95)
@@ -103,7 +148,7 @@ class ImageGenerator:
             torch.cuda.empty_cache()
             gc.collect()
 
-            logger.info("βœ… Image generated successfully")
+            logger.info("βœ… Image generated successfully with Stability AI model")
 
             return {
                 'image_pil': image,
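
One thing worth flagging for anyone reproducing this change: stabilityai/stable-diffusion-xl-base-1.0 is an SDXL checkpoint, and the diff above keeps loading it through StableDiffusionPipeline. In the diffusers releases I am aware of, SDXL checkpoints are normally loaded through StableDiffusionXLPipeline or the auto-detecting DiffusionPipeline, and the resize step's Image.Resampling.LANCZOS call assumes the module already imports Image from PIL. Below is a minimal loading and generation sketch under those assumptions; it mirrors the parameters chosen in the commit (30 steps, guidance 7.0, 1024x1024 downsized to 512x512) but is illustrative rather than the repository's exact code.

# Sketch only: the usual way to load SDXL base 1.0 with diffusers.
# Assumes a recent diffusers release and torch install; not this repo's code.
import torch
from diffusers import DiffusionPipeline  # resolves to the SDXL pipeline class for this checkpoint
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    use_safetensors=True,
)
pipeline.to(device)

# SDXL's native resolution is 1024x1024; the commit downsizes to 512x512 afterwards.
image = pipeline(
    prompt="a ceramic vase on a wooden table, photorealistic",  # hypothetical example prompt
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    guidance_scale=7.0,
    width=1024,
    height=1024,
).images[0]
image = image.resize((512, 512), Image.Resampling.LANCZOS)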
requirements.txt CHANGED
@@ -17,4 +17,5 @@ safetensors==0.4.2
 huggingface_hub==0.20.2
 requests==2.31.0
 trimesh==4.0.5
-scipy==1.11.4
+scipy==1.11.4
+xformers==0.0.22
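
A note on the new xformers pin: xformers wheels are built against specific torch versions, so xformers==0.0.22 only provides the memory-efficient attention path if it matches the torch build the project already pins; on a mismatch the import fails and the try/except added in load_model quietly falls back to default attention. A quick sanity check, written as a standalone sketch rather than project code:

# Sketch: confirm the pinned xformers build imports cleanly against the installed torch.
import torch

try:
    import xformers
    print(f"torch {torch.__version__} / xformers {xformers.__version__}: OK")
except ImportError as err:
    print(f"xformers unavailable; the pipeline will use default attention ({err})")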