# Source: text-to-3d-backend/models/depth_processor.py (author: aniket47)
# Last commit 3c28cb2: "Final update to use Intel DPT BEIT Large 512 model as requested"
# File size at that commit: 8.36 kB
"""
Depth processing module for converting 2D images to depth maps and 3D models
"""
import os
import logging
import tempfile
import numpy as np
import torch
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import trimesh
import matplotlib.pyplot as plt
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class DepthProcessor:
    """Estimate depth with a DPT model and turn it into a coloured 3D mesh.

    Typical use: call :meth:`load_model` once, then :meth:`process_image_to_3d`
    for each image. Generated artifacts (a depth PNG and an OBJ mesh per job)
    are written into a private temporary directory created in ``__init__``;
    this class never deletes that directory, because callers receive paths
    into it.
    """

    # Hugging Face checkpoint loaded by load_model().
    MODEL_NAME = "Intel/dpt-beit-large-512"

    def __init__(self):
        # Use CUDA when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Filled in by load_model(); None means "not loaded yet".
        self.processor = None
        self.model = None
        # Scratch directory for per-job output files (not cleaned up here).
        self.temp_dir = tempfile.mkdtemp()

    def load_model(self):
        """Load the DPT image processor and depth model onto ``self.device``.

        Raises:
            Exception: anything raised by ``from_pretrained`` or ``.to()``
                (e.g. download or device errors); it is logged and re-raised.
        """
        try:
            logger.info(f"πŸ”„ Loading DPT model on {self.device}...")
            model_name = self.MODEL_NAME
            self.processor = DPTImageProcessor.from_pretrained(model_name)
            self.model = DPTForDepthEstimation.from_pretrained(model_name)
            self.model.to(self.device)
            # Inference only: switch layers such as dropout to eval behaviour.
            self.model.eval()
            if self.device.type == "cuda":
                logger.info(f"βœ… DPT model loaded on GPU: {torch.cuda.get_device_name(0)}")
            else:
                logger.info("βœ… DPT model loaded on CPU")
        except Exception as e:
            logger.error(f"❌ Failed to load DPT model: {str(e)}")
            raise

    def generate_depth_map(self, image: Image.Image) -> np.ndarray:
        """Predict a relative depth map for *image*.

        The raw prediction comes out at the processor's working resolution, so
        it is upsampled to the image's own size; afterwards ``depth[y, x]``
        corresponds to pixel ``(x, y)`` of *image*. Values are then min-max
        normalized to ``[0, 1]``.

        Args:
            image: input picture (callers in this class pass an RGB image).

        Returns:
            Float array of shape ``(image.height, image.width)`` with values in
            ``[0, 1]``. A constant prediction yields all zeros instead of NaNs.

        Raises:
            RuntimeError: if :meth:`load_model` has not completed successfully.
        """
        try:
            if self.processor is None or self.model is None:
                raise RuntimeError("DPT model is not loaded; call load_model() first")

            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                # (batch, H', W') -> (batch, 1, H, W) at the input image's size so
                # that later colour sampling lines up pixel-for-pixel.
                predicted_depth = torch.nn.functional.interpolate(
                    outputs.predicted_depth.unsqueeze(1),
                    size=(image.height, image.width),
                    mode="bicubic",
                    align_corners=False,
                )

            # Index instead of squeeze(): squeeze would also drop a height or
            # width of 1 and return a 1-D array.
            depth = predicted_depth[0, 0].float().cpu().numpy()

            depth_min = float(depth.min())
            depth_range = float(depth.max()) - depth_min
            if depth_range > 0:
                depth_normalized = (depth - depth_min) / depth_range
            else:
                # Flat (or non-finite) prediction: avoid 0/0 -> NaN vertices.
                depth_normalized = np.zeros_like(depth)
            return depth_normalized
        except Exception as e:
            logger.error(f"❌ Error generating depth map: {str(e)}")
            raise

    def save_depth_map_image(self, depth_map: np.ndarray, job_id: str) -> str:
        """Render *depth_map* with the ``plasma`` colormap and save it as a PNG.

        Args:
            depth_map: 2-D array to visualise.
            job_id: used to build the file name ``depth_{job_id}.png``.
                NOTE(review): it is not sanitised; assumes callers pass a safe
                identifier (e.g. a UUID) — confirm.

        Returns:
            Path of the written PNG inside ``self.temp_dir``.
        """
        try:
            depth_path = os.path.join(self.temp_dir, f"depth_{job_id}.png")
            fig = plt.figure(figsize=(10, 10))
            try:
                ax = fig.add_subplot(1, 1, 1)
                ax.imshow(depth_map, cmap='plasma')
                ax.axis('off')
                fig.tight_layout()
                fig.savefig(depth_path, bbox_inches='tight', pad_inches=0, dpi=150)
            finally:
                # pyplot keeps every open figure alive; close it even when
                # rendering/saving fails so a long-running process does not leak.
                plt.close(fig)
            return depth_path
        except Exception as e:
            logger.error(f"❌ Error saving depth map image: {str(e)}")
            raise

    def create_3d_model(
        self,
        image: Image.Image,
        depth_map: np.ndarray,
        job_id: str,
        depth_scale: float = 30.0,
        max_depth: float = 25.0,
        target_resolution: int = 100,
    ) -> str:
        """Build a vertex-coloured relief mesh from *depth_map* and save it as OBJ.

        The depth map is sampled on a regular grid; each grid cell becomes two
        triangles whose z coordinate is ``(1 - depth) * depth_scale``. Cells with
        any corner whose z exceeds *max_depth* are dropped.

        Args:
            image: source picture used for vertex colours. It may differ in size
                from *depth_map*; grid coordinates are mapped proportionally.
            depth_map: 2-D array, expected normalized to ``[0, 1]``.
            job_id: used to build the file name ``model_{job_id}.obj``.
            depth_scale: multiplier applied to ``1 - depth`` to get z.
            max_depth: cells with any corner z above this are skipped.
            target_resolution: approximate number of samples along the shorter
                side of the depth map.

        Returns:
            Path of the exported OBJ file inside ``self.temp_dir``.

        Raises:
            ValueError: if every cell was skipped, so no geometry was produced.
        """
        try:
            img_array = np.array(image)
            h, w = depth_map.shape
            img_h, img_w = img_array.shape[:2]
            has_color = img_array.ndim == 3

            # Grid stride so the shorter side gets ~target_resolution samples.
            step = max(1, min(h, w) // max(1, int(target_resolution)))

            vertices = []
            faces = []
            colors = []

            for y in range(0, h - step, step):
                for x in range(0, w - step, step):
                    # Corners as (row, col): top-left, top-right, bottom-left, bottom-right.
                    corners = ((y, x), (y, x + step), (y + step, x), (y + step, x + step))
                    # Normalized depth 1.0 gives z = 0; smaller values give larger z.
                    quad_depths = [(1.0 - depth_map[py, px]) * depth_scale for py, px in corners]
                    if any(z > max_depth for z in quad_depths):
                        continue

                    base = len(vertices)
                    for (py, px), z in zip(corners, quad_depths):
                        vertices.append([px / w - 0.5, (h - py) / h - 0.5, z])
                        if has_color:
                            # Map depth-grid coordinates onto the image grid so the
                            # lookup stays aligned and in bounds when sizes differ.
                            colors.append(img_array[py * img_h // h, px * img_w // w] / 255.0)
                        else:
                            colors.append([0.7, 0.7, 0.7])

                    # Two triangles per cell, same winding for both.
                    faces.append([base, base + 1, base + 2])
                    faces.append([base + 1, base + 3, base + 2])

            if not vertices:
                raise ValueError("No valid vertices generated for 3D model")

            mesh = trimesh.Trimesh(
                vertices=np.array(vertices),
                faces=np.array(faces),
                vertex_colors=np.array(colors),
            )

            # Drop zero-area and repeated faces, then weld coincident vertices
            # (every cell above emitted its own copy of its four corners).
            if hasattr(mesh, "nondegenerate_faces") and hasattr(mesh, "unique_faces"):
                # trimesh >= 4 API; the remove_* helpers are deprecated there.
                mesh.update_faces(mesh.nondegenerate_faces())
                mesh.update_faces(mesh.unique_faces())
            else:
                mesh.remove_degenerate_faces()
                mesh.remove_duplicate_faces()
            mesh.merge_vertices()

            obj_path = os.path.join(self.temp_dir, f"model_{job_id}.obj")
            mesh.export(obj_path)

            logger.info(f"βœ… 3D model created: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
            return obj_path
        except Exception as e:
            logger.error(f"❌ Error creating 3D model: {str(e)}")
            raise

    def process_image_to_3d(self, image: Image.Image, job_id: str, max_size: int = 512) -> dict:
        """Run the full pipeline: image -> depth map -> depth PNG + OBJ mesh.

        Args:
            image: input picture; converted to RGB and downscaled so its longer
                side is at most *max_size* pixels.
            job_id: identifier embedded in the output file names.
            max_size: longest allowed image side before depth estimation.

        Returns:
            On success ``{'depth_map', 'depth_map_path', 'obj_path', 'success': True}``;
            on any failure ``{'success': False, 'error': <message>}``. This method
            does not raise; errors are logged with their traceback.
        """
        try:
            logger.info(f"πŸ”„ Processing image to 3D model (Job: {job_id})")

            # Convert before resizing: Pillow falls back to nearest-neighbour for
            # "P" and "1" mode images, which would silently ignore LANCZOS.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            if max(image.size) > max_size:
                ratio = max_size / max(image.size)
                new_size = (
                    max(1, int(image.size[0] * ratio)),
                    max(1, int(image.size[1] * ratio)),
                )
                image = image.resize(new_size, Image.Resampling.LANCZOS)
                logger.info(f"πŸ“ Resized image to {new_size}")

            depth_map = self.generate_depth_map(image)
            depth_map_path = self.save_depth_map_image(depth_map, job_id)
            obj_path = self.create_3d_model(image, depth_map, job_id)

            return {
                'depth_map': depth_map,
                'depth_map_path': depth_map_path,
                'obj_path': obj_path,
                'success': True
            }
        except Exception as e:
            # The exception is converted into a result dict, so keep the traceback
            # in the log; it is otherwise lost.
            logger.exception(f"❌ Error in image-to-3D pipeline: {str(e)}")
            return {
                'success': False,
                'error': str(e)
            }