import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


def _image_height_width(image):
    """Return ``(height, width)`` of an image-like object (array, PIL image, ...)."""
    shape = np.asarray(image).shape
    if len(shape) < 2:
        raise ValueError(
            f"Expected an image with at least 2 dimensions, got shape {shape}"
        )
    return int(shape[0]), int(shape[1])


def _upsample_to_image(grid, height, width):
    """Bilinearly resize a 2D tensor to ``(height, width)``; return a NumPy array."""
    resized = F.interpolate(
        # detach(): .numpy() below raises on tensors that require grad.
        # float(): bilinear interpolation may not support half precision on CPU.
        grid.detach().float()[None, None],  # (1, 1, H_map, W_map)
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )
    # Index the two leading singleton dims explicitly; a bare .squeeze() would
    # also drop a spatial dimension of size 1.
    return resized[0, 0].cpu().numpy()


def visualize_single_feature_map(feature_tensor, original_image_shape, title):
    """Overlay a channel-averaged feature map on the image in the current axes.

    Args:
        feature_tensor: Tensor laid out as ``(1, C, H_map, W_map)``; the batch
            dimension is squeezed and the channels are averaged.
        original_image_shape: Despite its name, this is the *image itself*
            (anything ``plt.imshow`` accepts with height and width as its first
            two dimensions). It is drawn as the background and its size is the
            upsampling target. The name is kept for caller compatibility.
        title: Title of the current axes.

    Returns:
        None. Draws on the current matplotlib axes.
    """
    # (1, C, H_map, W_map) -> (H_map, W_map)
    feature_map = feature_tensor.detach().float().squeeze(0).mean(dim=0)

    # Min-max normalize to [0, 1]. The clamp keeps a constant map from
    # dividing by zero (which would fill the heatmap with NaN).
    lowest = feature_map.min()
    value_range = (feature_map.max() - lowest).clamp_min(1e-8)
    feature_map = (feature_map - lowest) / value_range

    height, width = _image_height_width(original_image_shape)
    heatmap = _upsample_to_image(feature_map, height, width)

    plt.title(title)
    plt.imshow(original_image_shape)
    plt.imshow(heatmap, cmap='jet', alpha=0.5)
    plt.axis('off')


def visualize_clip_attention(attentions, original_image_shape):
    """Overlay the last-layer CLS-token attention on the image in the current axes.

    This is a simplified stand-in for attention rollout: only the last layer is
    used, averaged over heads.

    Args:
        attentions: Sequence of per-layer attention tensors; the last entry is
            expected to be ``(1, Num_Heads, N, N)`` with ``N = Num_Patches + 1``
            and the CLS token at index 0.
        original_image_shape: Despite its name, this is the *image itself*
            (see ``visualize_single_feature_map``).

    Returns:
        None. Draws on the current matplotlib axes, or prints a warning and
        draws nothing when the patch tokens do not form a square grid.
    """
    # (1, Num_Heads, N, N) -> (Num_Heads, N, N) -> (N, N)
    attention_matrix = attentions[-1].detach().float().squeeze(0).mean(dim=0)

    # Attention paid by the CLS token (row 0) to every patch token (columns 1..N-1).
    cls_attention_to_patches = attention_matrix[0, 1:]

    # The patch tokens must tile a square grid (e.g. 14x14 for a 224x224 input
    # with 16x16 patches).
    num_patches = cls_attention_to_patches.size(0)
    grid_size = int(np.sqrt(num_patches))
    if num_patches == 0 or grid_size * grid_size != num_patches:
        print(
            f"Warning: Unexpected patch count {num_patches}. "
            "Skipping CLIP visualization."
        )
        return

    attention_grid = cls_attention_to_patches.reshape(grid_size, grid_size)

    # Upsample to the image size. No explicit normalization is applied:
    # imshow scales the colormap to the data range on its own.
    height, width = _image_height_width(original_image_shape)
    heatmap = _upsample_to_image(attention_grid, height, width)

    plt.title("CLIP: CLS Token Attention")
    plt.imshow(original_image_shape)
    plt.imshow(heatmap, cmap='viridis', alpha=0.6)
    plt.axis('off')


def plot_all_visualizations(model, image_path):
    """Run the model on one image and plot its feature visualizations.

    Args:
        model: Object exposing ``get_visualization_outputs(image_path)`` that
            returns a mapping with the keys ``'original_image'``,
            ``'segformer_features'`` and ``'dpt_features'``.
        image_path: Passed through to ``model.get_visualization_outputs``.

    Returns:
        None. Shows a matplotlib figure, or prints an error message and returns
        early when the model call raises.
    """
    try:
        outputs = model.get_visualization_outputs(image_path)
    except Exception as e:
        # Top-level boundary: report the failure and bail out instead of crashing.
        print(f"Error during model inference/loading: {e}")
        print("Ensure you have all required Hugging Face models and libraries installed.")
        return

    original_image = outputs['original_image']

    fig, axes = plt.subplots(1, 4, figsize=(20, 6))
    fig.suptitle("Building Feature Visualizations", fontsize=16)

    # 1. Original image
    axes[0].imshow(original_image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # 2. CLIP visualization is currently disabled; hide the panel so it does
    #    not show up as an empty framed plot.
    # NOTE: to re-enable, select axes[1] with plt.sca and call
    #    visualize_clip_attention(outputs['clip_attentions'], original_image).
    axes[1].axis('off')

    # 3. SegFormer visualization
    plt.sca(axes[2])
    visualize_single_feature_map(
        outputs['segformer_features'],
        original_image,
        "SegFormer: Semantic Features",
    )

    # 4. DPT visualization
    plt.sca(axes[3])
    visualize_single_feature_map(
        outputs['dpt_features'],
        original_image,
        "DPT: Depth/Structural Features",
    )

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()