import gradio as gr # --- Local image paths --- CAT_WITHOUT_GLASSES_PATH = "cat_without_glasses.png" CAT_WITH_GLASSES_PATH = "cat_with_glasses.png" UNICORN_WITH_HORN_PATH = "unicorn_with_horn.png" UNICORN_NO_HORN_PATH = "unicorn_without_horn.png" # --- Model / head info --- NUM_HEADS = 8 # SDXL UNet attention heads per layer LAYER_CHOICES = [ 'unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2', 'unet.down_blocks.0.attentions.1.transformer_blocks.0.attn2', 'unet.down_blocks.1.attentions.0.transformer_blocks.0.attn2', 'unet.down_blocks.1.attentions.1.transformer_blocks.0.attn2', 'unet.down_blocks.2.attentions.0.transformer_blocks.0.attn2', 'unet.down_blocks.2.attentions.1.transformer_blocks.0.attn2', 'unet.mid_block.attentions.0.transformer_blocks.0.attn2', 'unet.up_blocks.1.attentions.0.transformer_blocks.0.attn2', 'unet.up_blocks.1.attentions.1.transformer_blocks.0.attn2', 'unet.up_blocks.1.attentions.2.transformer_blocks.0.attn2', 'unet.up_blocks.2.attentions.0.transformer_blocks.0.attn2', 'unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2', 'unet.up_blocks.2.attentions.2.transformer_blocks.0.attn2', 'unet.up_blocks.3.attentions.0.transformer_blocks.0.attn2', 'unet.up_blocks.3.attentions.1.transformer_blocks.0.attn2', 'unet.up_blocks.3.attentions.2.transformer_blocks.0.attn2' ] HEAD_CHOICES = [f"head_{i}" for i in range(NUM_HEADS)] # --- Callbacks --- def steer_spectacles(strength: int): """ Simple placeholder: - For strength ~0, show no glasses on both sides. - For strength > 0, show original on the left, glasses on the right. """ print(f"Steering strength was {strength}") if strength == 0: print(f"Returning {CAT_WITHOUT_GLASSES_PATH} twice") return CAT_WITHOUT_GLASSES_PATH, CAT_WITHOUT_GLASSES_PATH else: CAT_WITH_GLASSES_PATH = f"./cat_steering/Cat_step_{strength}.png.png" print(f"Returning {CAT_WITH_GLASSES_PATH}") return CAT_WITHOUT_GLASSES_PATH, CAT_WITH_GLASSES_PATH def run_unicorn_ablation(selected_layer, selected_heads): """ Given selected head labels (e.g., ["head_3", "head_17"]), return: - Unicorn with horn (original) - Unicorn without horn (example ablation outcome) Enforce max of 3 heads. """ # layer = selected_layer.split('blocks.')[1].split('.attentions')[0] if selected_heads is None or len(selected_heads) ==0: selected_heads = [] return UNICORN_WITH_HORN_PATH if len(selected_heads) > 3: selected_heads = selected_heads[:3] selected_heads.sort() if len(selected_heads) ==1: path = f"unicorn_steering/single_heads/{selected_layer}_h{selected_heads[0].replace('head_', '')}" elif len(selected_heads) ==2: path = f"unicorn_steering/head_pairs/{selected_layer}_h{selected_heads[0].replace('head_', '')}_h{selected_heads[1].replace('head_', '')}" elif len(selected_heads) ==3: path = f"unicorn_steering/head_triples/{selected_layer}_h{selected_heads[0].replace('head_', '')}_h{selected_heads[1].replace('head_', '')}_h{selected_heads[2].replace('head_', '')}" path = path.replace('.', '_') + ".png" UNICORN_NO_HORN_PATH = path print(f"Unicorn no horn path was {UNICORN_NO_HORN_PATH}") # In a real experiment you'd use `selected_heads` to ablate SDXL heads. return UNICORN_NO_HORN_PATH with gr.Blocks() as demo: # Global CSS, including 😼 slider thumb gr.HTML(""" """) # -------- 1. Cat Steering (CAA) -------- gr.HTML("""
We steer a normal cat image using contrastive activation addition (CAA): nudging hidden activations along a learned “wear spectacles” direction while keeping other visual features as stable as possible.
Choose up to three attention heads (out of 64) to ablate. In a real SDXL experiment, those heads would be zeroed; here we show a unicorn with its horn intact alongside an example where the horn is removed by ablation.