"""Gradio demo with two panels, both backed by image files on disk.

1. Cat steering: a slider selects a contrastive-activation-addition (CAA)
   strength and shows the matching cat image.
2. Unicorn head ablation: a layer and up to three attention heads are selected
   and the matching "ablated" unicorn image is shown.
"""
import gradio as gr

# --- Local image paths ---
CAT_WITHOUT_GLASSES_PATH = "cat_without_glasses.png"
CAT_WITH_GLASSES_PATH = "cat_with_glasses.png"
UNICORN_WITH_HORN_PATH = "unicorn_with_horn.png"
UNICORN_NO_HORN_PATH = "unicorn_without_horn.png"

# --- Model / head info ---
# Attention heads selectable per layer in the UI.
# NOTE(review): the UI labels the model as SDXL; confirm this head count and the
# layer names below against the model config actually used to render images.
NUM_HEADS = 8
LAYER_CHOICES = [
    "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn2",
    "unet.down_blocks.1.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.1.attentions.1.transformer_blocks.0.attn2",
    "unet.down_blocks.2.attentions.0.transformer_blocks.0.attn2",
    "unet.down_blocks.2.attentions.1.transformer_blocks.0.attn2",
    "unet.mid_block.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.1.attentions.2.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.2.attentions.2.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.0.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.1.transformer_blocks.0.attn2",
    "unet.up_blocks.3.attentions.2.transformer_blocks.0.attn2",
]
HEAD_CHOICES = [f"head_{i}" for i in range(NUM_HEADS)]

# Sub-directory of unicorn_steering/ holding the ablation images, keyed by the
# number of ablated heads. Paths are only built for these counts.
_ABLATION_SUBDIRS = {1: "single_heads", 2: "head_pairs", 3: "head_triples"}
# Maximum number of heads that can be ablated at once (enforced in the UI and
# in run_unicorn_ablation).
MAX_ABLATED_HEADS = len(_ABLATION_SUBDIRS)


# --- Callbacks ---
def steer_spectacles(strength: float):
    """Return the (left, right) image paths for a CAA steering strength.

    - For strength 0, show the cat without glasses on both sides.
    - Otherwise, show the original cat on the left and the image stored for
      that strength step under ./cat_steering/ on the right.

    Args:
        strength: Slider value. It may arrive as an int or a float, so it is
            rounded to an int before being embedded in a file name.

    Returns:
        Tuple ``(original_path, steered_path)``.
    """
    print(f"Steering strength was {strength}")
    # Normalise e.g. 35.0 -> 35 so the file name is "Cat_step_35", not "Cat_step_35.0".
    step = int(round(strength))
    if step == 0:
        print(f"Returning {CAT_WITHOUT_GLASSES_PATH} twice")
        return CAT_WITHOUT_GLASSES_PATH, CAT_WITHOUT_GLASSES_PATH
    # NOTE(review): the doubled ".png.png" extension is kept as-is because it
    # presumably matches the files on disk; confirm before changing it.
    steered_path = f"./cat_steering/Cat_step_{step}.png.png"
    print(f"Returning {steered_path}")
    return CAT_WITHOUT_GLASSES_PATH, steered_path


def run_unicorn_ablation(selected_layer, selected_heads):
    """Return the image path of the unicorn after ablating the chosen heads.

    Args:
        selected_layer: A layer name from LAYER_CHOICES, e.g.
            "unet.mid_block.attentions.0.transformer_blocks.0.attn2".
        selected_heads: Head labels from HEAD_CHOICES, e.g. ["head_3", "head_5"].
            Only the first MAX_ABLATED_HEADS selections are used. The list is
            not modified.

    Returns:
        UNICORN_WITH_HORN_PATH when no layer or no heads are selected.
        Otherwise, "unicorn_steering/<subdir>/<stem>.png", where <stem> is the
        layer name followed by "_h<index>" for each head (indices sorted) with
        every "." replaced by "_".
    """
    if not selected_layer or not selected_heads:
        return UNICORN_WITH_HORN_PATH

    # Keep the first picks (in selection order), then sort them numerically so
    # the same set of heads always maps to the same file name.
    head_ids = sorted(
        (label.replace("head_", "") for label in selected_heads[:MAX_ABLATED_HEADS]),
        key=int,
    )
    subdir = _ABLATION_SUBDIRS[len(head_ids)]
    stem = (selected_layer + "".join(f"_h{i}" for i in head_ids)).replace(".", "_")
    path = f"unicorn_steering/{subdir}/{stem}.png"
    print(f"Unicorn no horn path was {path}")
    # In a real experiment you'd use `selected_heads` to ablate the model's heads.
    return path


with gr.Blocks() as demo:
    # Global CSS, including 😼 slider thumb
    gr.HTML(""" """)

    # -------- 1. Cat Steering (CAA) --------
    gr.HTML("""

Cat Steering Console 😼

We steer a normal cat image using contrastive activation addition (CAA): nudging hidden activations along a learned “wear spectacles” direction while keeping other visual features as stable as possible.

""")
    with gr.Group():
        with gr.Row():
            cat_left = gr.Image(
                value=CAT_WITHOUT_GLASSES_PATH,
                label="Original cat (no glasses)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            cat_right = gr.Image(
                value=CAT_WITH_GLASSES_PATH,
                label="Steered cat",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        steer_slider = gr.Slider(
            minimum=0,
            maximum=35,
            value=35,
            step=5,
            label="Steer 😼 (CAA strength towards glasses)",
            info="Connect this to your actual CAA steering pipeline.",
            elem_id="cat_steer_slider",
        )
        steer_slider.input(
            fn=steer_spectacles,
            inputs=steer_slider,
            outputs=[cat_left, cat_right],
        )

    # -------- Transition text --------
    gr.Markdown(
        "### From steering to ablation\n"
        "Below, we move from **additive CAA steering** to **structured ablation**. "
        "Instead of pushing along a spectacles direction, we toggle specific SDXL attention "
        "heads on/off and interpret the unicorn without a horn as an ablation outcome."
    )

    # -------- 2. Unicorn Head Ablation --------
    # Head count comes from NUM_HEADS so the text matches the selectable heads.
    gr.HTML(f"""

Unicorn Head Ablation (SDXL)

Choose up to three attention heads (out of {NUM_HEADS}) to ablate. In a real SDXL experiment, those heads would be zeroed; here we show a unicorn with its horn intact alongside an example where the horn is removed by ablation.

""")
    with gr.Group():
        with gr.Row():
            unicorn_original = gr.Image(
                value=UNICORN_WITH_HORN_PATH,
                label="Original unicorn (all heads active)",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
            unicorn_ablated = gr.Image(
                value=UNICORN_NO_HORN_PATH,
                label="Ablated unicorn",
                interactive=False,
                show_label=True,
                elem_classes=["img-card"],
            )
        layer_selector = gr.Dropdown(
            choices=LAYER_CHOICES,
            multiselect=False,
            value=LAYER_CHOICES[0],
            label="Attention layer to ablate",
            info="Select an attention layer to ablate.",
        )
        head_selector = gr.Dropdown(
            choices=HEAD_CHOICES,
            multiselect=True,
            # Enforce the limit in the UI so the shown selection matches the image.
            max_choices=MAX_ABLATED_HEADS,
            value=["head_0", "head_1"],
            label="Attention heads to ablate (max 3)",
            info=(
                f"Select up to three head indices (0-{NUM_HEADS - 1}). "
                "In this demo, images are pre-rendered files rather than live model outputs."
            ),
        )
        # Both selectors feed the same callback: changing only the layer must
        # refresh the ablated image as well.
        for selector in (layer_selector, head_selector):
            selector.input(
                fn=run_unicorn_ablation,
                inputs=[layer_selector, head_selector],
                outputs=[unicorn_ablated],
            )

demo.launch()