# text-to-3d-backend/models/depth_processor.py
# Committed by aniket47 in 86e7db6: "Initial FastAPI backend for HF Spaces"
"""
Depth processing module for converting 2D images to depth maps and 3D models
"""
import os
import logging
import tempfile
import numpy as np
import torch
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import open3d as o3d
import matplotlib.pyplot as plt
# Module-level logger, named after this module's import path (__name__).
logger = logging.getLogger(__name__)
class DepthProcessor:
    """Handles depth estimation and 3D model generation.

    Wraps the ``Intel/dpt-beit-large-512`` DPT model to turn a 2D image into a
    normalized depth map, renders a colorized preview of that map, and builds a
    Poisson-reconstructed OBJ mesh from the image + depth pair. All generated
    files are written into a per-instance temporary directory.

    Call :meth:`load_model` once before any depth estimation.
    """

    # Hugging Face model id used for both the image processor and the weights.
    MODEL_NAME = "Intel/dpt-beit-large-512"

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = None  # populated by load_model()
        self.model = None  # populated by load_model()
        # Output directory for depth previews and OBJ files. This class never
        # deletes it; NOTE(review): confirm which component cleans it up.
        self.temp_dir = tempfile.mkdtemp()

    def load_model(self):
        """Load the DPT processor and model onto ``self.device``.

        Raises:
            Exception: any error raised while loading is logged and re-raised.
        """
        try:
            logger.info(f"🔄 Loading DPT model on {self.device}...")
            processor = DPTImageProcessor.from_pretrained(self.MODEL_NAME)
            model = DPTForDepthEstimation.from_pretrained(self.MODEL_NAME)
            model.to(self.device)
            model.eval()
            # Publish both only after everything succeeded, so a partial
            # failure never leaves a processor without a model (or vice versa).
            self.processor = processor
            self.model = model
            if self.device.type == "cuda":
                logger.info(f"✅ DPT model loaded on GPU: {torch.cuda.get_device_name(0)}")
            else:
                logger.info("✅ DPT model loaded on CPU")
        except Exception as e:
            logger.error(f"❌ Failed to load DPT model: {str(e)}")
            raise

    def generate_depth_map(self, image: Image.Image) -> np.ndarray:
        """Generate a normalized depth map for ``image``.

        The processor resizes the input to the model's working resolution, so
        the raw prediction is interpolated back to the image's own size. This
        keeps ``depth_map[y, x]`` aligned with pixel ``(x, y)`` of ``image``,
        which :meth:`create_3d_model` relies on when sampling colors.

        Args:
            image: Input PIL image.

        Returns:
            Float array of shape ``(image.height, image.width)`` min-max scaled
            to ``[0, 1]``. A constant prediction yields an all-zero map instead
            of NaNs.

        Raises:
            RuntimeError: if :meth:`load_model` has not been called.
        """
        try:
            if self.processor is None or self.model is None:
                raise RuntimeError("DPT model is not loaded; call load_model() first")

            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                predicted_depth = self.model(**inputs).predicted_depth
                # (batch, H', W') at model resolution -> (batch, 1, H, W) at
                # image resolution. Note PIL's size is (width, height).
                resized = torch.nn.functional.interpolate(
                    predicted_depth.unsqueeze(1),
                    size=(image.height, image.width),
                    mode="bicubic",
                    align_corners=False,
                )

            # Explicit indexing (not squeeze) so a 1-pixel-wide/tall image
            # still yields a 2D array.
            depth = resized[0, 0].float().cpu().numpy()

            depth_min = depth.min()
            depth_range = depth.max() - depth_min
            if depth_range == 0:
                # Flat prediction: avoid 0/0 -> NaN.
                return np.zeros_like(depth)
            return (depth - depth_min) / depth_range
        except Exception as e:
            logger.error(f"❌ Error generating depth map: {str(e)}")
            raise

    def save_depth_map_image(self, depth_map: np.ndarray, job_id: str) -> str:
        """Render ``depth_map`` with the 'plasma' colormap and save it as a PNG.

        Args:
            depth_map: 2D depth array (e.g. from :meth:`generate_depth_map`).
            job_id: Used to build the file name ``depth_<job_id>.png``.
                NOTE(review): presumably server-generated; it is joined into a
                path unsanitized, so confirm it never comes from user input.

        Returns:
            Path of the written PNG inside ``self.temp_dir``.
        """
        try:
            depth_path = os.path.join(self.temp_dir, f"depth_{job_id}.png")
            # Use an explicit figure object rather than pyplot's implicit
            # "current figure", and always close it so errors cannot leak
            # figures in a long-running server process.
            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                ax.imshow(depth_map, cmap='plasma')
                ax.axis('off')
                fig.tight_layout()
                fig.savefig(depth_path, bbox_inches='tight', pad_inches=0, dpi=150)
            finally:
                plt.close(fig)
            return depth_path
        except Exception as e:
            logger.error(f"❌ Error saving depth map image: {str(e)}")
            raise

    def create_3d_model(self, image: Image.Image, depth_map: np.ndarray, job_id: str) -> str:
        """Create a 3D OBJ mesh from an image and its normalized depth map.

        Samples a subsampled grid of pixels into a colored point cloud, with
        x and y normalized to about ``[-0.5, 0.5]`` and ``z = (1 - depth) * 50``.
        It then runs Poisson surface reconstruction, cleans and lightly smooths
        the mesh, and writes it as OBJ.

        Args:
            image: Source image used for vertex colors. It is resized to the
                depth map's size if the two differ.
            depth_map: 2D array with values in ``[0, 1]``.
            job_id: Used to build the file name ``model_<job_id>.obj``.

        Returns:
            Path of the written OBJ file inside ``self.temp_dir``.

        Raises:
            ValueError: if no usable points remain after filtering.
            OSError: if Open3D reports that writing the mesh failed.
        """
        try:
            h, w = depth_map.shape

            # Colors are sampled at depth-map coordinates, so both must share
            # one pixel grid.
            if image.size != (w, h):
                image = image.resize((w, h), Image.Resampling.LANCZOS)
            img_array = np.asarray(image)

            # Subsample for performance: roughly 200+ samples along the
            # shorter side.
            step = max(1, min(h, w) // 200)
            ys, xs = np.mgrid[0:h:step, 0:w:step]

            # Invert so larger normalized depth values get smaller z. The
            # original code comment calls this "proper 3D orientation";
            # presumably larger DPT values mean nearer.
            z = (1.0 - depth_map[::step, ::step]) * 50.0

            # Drop the farthest samples (z > 45, i.e. normalized depth < 0.1)
            # and any non-finite values, which would break reconstruction.
            keep = np.isfinite(z) & (z <= 45.0)
            if not keep.any():
                raise ValueError("No valid points generated for 3D model")
            ys, xs, z = ys[keep], xs[keep], z[keep]

            points = np.column_stack((xs / w - 0.5, (h - ys) / h - 0.5, z)).astype(np.float64)

            if img_array.ndim == 3 and img_array.shape[2] >= 3:
                # Keep RGB only: an alpha channel would not fit Open3D's
                # 3-component colors.
                colors = img_array[ys, xs, :3] / 255.0
            else:
                colors = np.full((len(points), 3), 0.7)  # Gray for grayscale

            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(points)
            pcd.colors = o3d.utility.Vector3dVector(np.ascontiguousarray(colors, dtype=np.float64))

            pcd.estimate_normals()
            # estimate_normals leaves each normal's sign arbitrary, while Poisson
            # reconstruction expects consistently oriented normals. Point them
            # toward a viewer placed in front of the smallest-z (nearest)
            # surface.
            pcd.orient_normals_towards_camera_location(np.array([0.0, 0.0, -1.0]))

            mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
                pcd, depth=8, width=0, scale=1.1, linear_fit=False
            )

            # Remove degenerate/duplicated geometry and non-manifold edges.
            mesh.remove_degenerate_triangles()
            mesh.remove_duplicated_triangles()
            mesh.remove_duplicated_vertices()
            mesh.remove_non_manifold_edges()

            mesh = mesh.filter_smooth_simple(number_of_iterations=2)

            obj_path = os.path.join(self.temp_dir, f"model_{job_id}.obj")
            # write_triangle_mesh reports failure via its return value, not
            # an exception.
            if not o3d.io.write_triangle_mesh(obj_path, mesh):
                raise OSError(f"Failed to write 3D model to {obj_path}")

            logger.info(f"✅ 3D model created: {len(mesh.vertices)} vertices, {len(mesh.triangles)} triangles")
            return obj_path
        except Exception as e:
            logger.error(f"❌ Error creating 3D model: {str(e)}")
            raise

    def process_image_to_3d(self, image: Image.Image, job_id: str, max_size: int = 512) -> dict:
        """Run the full pipeline: image -> depth map -> depth preview + OBJ mesh.

        Exceptions (subclasses of ``Exception``) are caught and reported in the
        returned dict rather than raised.

        Args:
            image: Input PIL image in any mode; it is converted to RGB.
            job_id: Identifier embedded in the output file names.
            max_size: Longest allowed side in pixels. Larger images are
                downscaled, preserving aspect ratio, for performance.

        Returns:
            On success, a dict with keys ``'depth_map'``, ``'depth_map_path'``,
            ``'obj_path'`` and ``'success': True``.
            On failure, ``{'success': False, 'error': <message>}``.
        """
        try:
            logger.info(f"🔄 Processing image to 3D model (Job: {job_id})")

            # Convert before resizing: Pillow resizes "1"/"P" mode images with
            # nearest-neighbour regardless of the requested filter.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            longest = max(image.size)
            if longest > max_size:
                ratio = max_size / longest
                # Clamp to 1 px so extremely thin images cannot collapse to 0.
                new_size = (
                    max(1, int(image.size[0] * ratio)),
                    max(1, int(image.size[1] * ratio)),
                )
                image = image.resize(new_size, Image.Resampling.LANCZOS)
                logger.info(f"📐 Resized image to {new_size}")

            depth_map = self.generate_depth_map(image)
            depth_map_path = self.save_depth_map_image(depth_map, job_id)
            obj_path = self.create_3d_model(image, depth_map, job_id)

            return {
                'depth_map': depth_map,
                'depth_map_path': depth_map_path,
                'obj_path': obj_path,
                'success': True
            }
        except Exception as e:
            logger.error(f"❌ Error in image-to-3D pipeline: {str(e)}")
            return {
                'success': False,
                'error': str(e)
            }