#!/usr/bin/env python
"""Gradio demo for the GAIA prompt and image generation pipeline."""

from __future__ import annotations

import functools
import gc
import json
import logging
import os
import textwrap
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import torch
from PIL import Image

from src.gaia_inference.inference import create_pipeline
from src.gaia_inference.inference import run as run_pipeline
from src.gaia_inference.json_to_prompt import (
    DEFAULT_SAMPLING,
    SUPPORTED_TASKS,
    get_json_prompt,
    load_engine,
)

LOGGER = logging.getLogger(__name__)

TASK_LABEL_TO_KEY = {label: key for key, label in SUPPORTED_TASKS.items()}
DEFAULT_TASK_LABEL = SUPPORTED_TASKS["inspire"]
TASK_CHOICES = list(SUPPORTED_TASKS.values())

DEFAULT_VLM_MODEL = "briaai/vlm-processor"
DEFAULT_PIPELINE_NAME = "briaai/GAIA-Alpha"
DEFAULT_RESOLUTION = "1024 1024"
DEFAULT_GUIDANCE_SCALE = 5.0
DEFAULT_STEPS = 40
DEFAULT_SEED = -1
DEFAULT_NEGATIVE_PROMPT = ""

RESOLUTIONS_WH = [
    "832 1248",
    "896 1152",
    "960 1088",
    "1024 1024",
    "1088 960",
    "1152 896",
    "1216 832",
    "1280 800",
    "1344 768",
]

ROOT_DIR = Path(__file__).resolve().parents[2]
ASSETS_DIR = ROOT_DIR / "assets"
DEFAULT_PROMPT_PATH = ROOT_DIR / "default_json_caption.json"

try:
    REFINED_PROMPT_EXAMPLE = DEFAULT_PROMPT_PATH.read_text()
except FileNotFoundError:
    REFINED_PROMPT_EXAMPLE = ""

USAGE_EXAMPLES = [
    [
        SUPPORTED_TASKS["generate"],
        None,
        "a dog playing in the park",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
    [
        SUPPORTED_TASKS["inspire"],
        str((ASSETS_DIR / "zebra_balloons.jpeg").resolve()),
        "",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
    [
        SUPPORTED_TASKS["refine"],
        None,
        "",
        REFINED_PROMPT_EXAMPLE,
        "change the zebra to an elephant",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
]


def _current_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


# def get_engine(model_name: str = DEFAULT_VLM_MODEL):


@functools.lru_cache(maxsize=2)
def _load_pipeline(pipeline_name: str, device: str):
    return create_pipeline(pipeline_name=pipeline_name, device=device)


def get_pipeline(pipeline_name: str = DEFAULT_PIPELINE_NAME):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for image generation.")
    return _load_pipeline(pipeline_name, "cuda")


def _format_prompt_text(raw_prompt: str) -> Tuple[str, Dict[str, Any]]:
    try:
        prompt_dict = json.loads(raw_prompt)
    except json.JSONDecodeError as exc:
        LOGGER.exception("Model returned invalid JSON prompt.")
        raise gr.Error("The VLM returned invalid JSON. Please try again.") from exc
    formatted = json.dumps(prompt_dict, indent=2)
    return formatted, prompt_dict


def _ensure_task_key(task_value: str) -> str:
    if task_value in SUPPORTED_TASKS:
        return task_value
    task_key = TASK_LABEL_TO_KEY.get(task_value)
    if task_key is None:
        valid = ", ".join(TASK_CHOICES)
        raise gr.Error(f"Unsupported task selection '{task_value}'. Valid options: {valid}.")
    return task_key


@torch.inference_mode()
def _generate_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    model_name: str = DEFAULT_VLM_MODEL,
) -> Tuple[str, str, Dict[str, Any]]:
    task_key = _ensure_task_key(task)
    engine = load_engine(model_name=model_name)
    engine.model.to("cuda")
    # engine = get_engine(model_name=model_name)
    # device = _current_device()
    # moved_to_cuda = torch.cuda.is_available() and device == "cuda"
    generation = None
    try:
        # if moved_to_cuda:
        #     engine.to(device)
        generation = get_json_prompt(
            task=task_key,
            engine=engine,
            image=image_value,
            prompt=generate_value,
            structured_prompt=refine_prompt,
            editing_instructions=refine_instruction,
            temperature=float(temperature_value),
            top_p=float(top_p_value),
            max_tokens=int(max_tokens_value),
        )
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        LOGGER.exception("Unexpected error while creating JSON prompt.")
        raise gr.Error("Failed to create a JSON prompt. Check the logs for details.") from exc
    finally:
        del engine
        gc.collect()
        # if moved_to_cuda:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    if generation is None:
        raise gr.Error("Failed to create a JSON prompt.")

    formatted_prompt, prompt_dict = _format_prompt_text(generation.prompt)
    latency_report = generation.latency_report()
    return formatted_prompt, latency_report, prompt_dict


def _parse_resolution(raw_value: str) -> Tuple[int, int]:
    normalised = raw_value.replace(",", " ").replace("x", " ")
    parts = [part for part in normalised.split() if part]
    if len(parts) != 2:
        raise gr.Error("Resolution must contain exactly two integers, e.g. '1024 1024'.")
    try:
        width, height = (int(parts[0]), int(parts[1]))
    except ValueError as exc:
        raise gr.Error("Resolution values must be integers.") from exc
    if width <= 0 or height <= 0:
        raise gr.Error("Resolution values must be positive.")
    return width, height


def _prepare_negative_prompt(raw_value: Optional[str]):
    text = (raw_value or "").strip()
    if not text:
        return ""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text


def _run_image_generation(
    prompt_data: Dict[str, Any],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
    pipeline_name: str = DEFAULT_PIPELINE_NAME,
) -> Tuple[str, Image.Image]:
    if not torch.cuda.is_available():
        raise gr.Error("CUDA is required for image generation.")
    width, height = _parse_resolution(resolution_value)
    negative_prompt_payload = _prepare_negative_prompt(negative_prompt_value)
    seed = DEFAULT_SEED if seed_value is None else int(seed_value)
    try:
        pipeline = get_pipeline(pipeline_name=pipeline_name)
    except RuntimeError as exc:
        raise gr.Error(str(exc)) from exc
    start = time.perf_counter()
    try:
        image = run_pipeline(
            pipeline=pipeline,
            json_prompt=prompt_data,
            negative_prompt=negative_prompt_payload,
            width=width,
            height=height,
            seed=seed,
            num_steps=int(steps_value),
            guidance_scale=float(guidance_value),
        )
    except Exception as exc:
        LOGGER.exception("Failed to generate image.")
        raise gr.Error("Image generation failed. Check the logs for details.") from exc
    elapsed = time.perf_counter() - start
    status = f"Image generation time: {elapsed:.2f}s at {width}x{height}"
    return status, image


def _toggle_visibility(task_name: str):
    task_key = _ensure_task_key(task_name)
    return [
        gr.update(visible=task_key == "inspire"),
        gr.update(visible=task_key == "generate"),
        gr.update(visible=task_key == "refine"),
    ]


def _clear_inputs():
    return (
        None,
        "",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        "",
        "",
        None,
        "",
        None,
        gr.update(visible=False),
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        DEFAULT_SEED,
        DEFAULT_NEGATIVE_PROMPT,
    )


@torch.inference_mode()
def create_json_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
):
    formatted_prompt, latency_report, prompt_dict = _generate_prompt(
        task=task,
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    return (
        formatted_prompt,
        latency_report,
        prompt_dict,
        "",
        None,
        gr.update(visible=True),
    )


def generate_image_from_state(
    prompt_state: Optional[Dict[str, Any]],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    if not prompt_state:
        raise gr.Error("Create a JSON prompt first.")
    return _run_image_generation(
        prompt_data=prompt_state,
        resolution_value=resolution_value,
        steps_value=steps_value,
        guidance_value=guidance_value,
        seed_value=seed_value,
        negative_prompt_value=negative_prompt_value,
    )


def run_full_pipeline(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    task_key = _ensure_task_key(task)
    formatted_prompt, latency_report, prompt_dict = _generate_prompt(
        task=task_key,
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    status, image = _run_image_generation(
        prompt_data=prompt_dict,
        resolution_value=resolution_value,
        steps_value=steps_value,
        guidance_value=guidance_value,
        seed_value=seed_value,
        negative_prompt_value=negative_prompt_value,
    )
    return (
        formatted_prompt,
        latency_report,
        prompt_dict,
        status,
        image,
        gr.update(visible=True),
    )


def build_demo() -> gr.Blocks:
    hero_css = textwrap.dedent(
        """
        .hero-row { justify-content: center; gap: 0.5rem; }
        .hero-item { align-items: center; display: flex; flex-direction: column; gap: 0.25rem; }
        .hero-item .gr-image { max-width: 512px; }
        .hero-image img { height: 512px !important; width: 512px !important; object-fit: cover; }
        .hero-caption { text-align: center; width: 100%; margin: 0; }
        """
    )
    with gr.Blocks(title="GAIA Inference Demo", css=hero_css) as demo:
        hero_markdown = textwrap.dedent(
            """
            # GAIA Prompt & Image Generation by [Bria.AI](https://bria.ai)

            To access via API: [TODO](TODO).

            Choose a mode to craft a structured JSON prompt and optionally render an image.
            """
        )
        gr.Markdown(hero_markdown)

        hero_images = [
            (ASSETS_DIR / "zebra_balloons.jpeg", "Zebra with balloons"),
            (ASSETS_DIR / "face_portrait.jpeg", "Face portrait"),
        ]
        with gr.Row(equal_height=True, elem_classes=["hero-row"]):
            for image_path, caption in hero_images:
                with gr.Column(scale=0, min_width=512, elem_classes=["hero-item"]):
                    gr.Image(
                        value=str(image_path),
                        type="filepath",
                        show_label=False,
                        interactive=False,
                        elem_classes=["hero-image"],
                        height=512,
                        width=512,
                    )
                    gr.Markdown(caption, elem_classes=["hero-caption"])

        task = gr.Radio(
            choices=TASK_CHOICES,
            label="Task",
            value=DEFAULT_TASK_LABEL,
            interactive=True,
            info="Choose what you want the model to do.",
        )

        with gr.Row():
            with gr.Column(scale=1, min_width=320):
                inspire_group = gr.Group(visible=True)
                with inspire_group:
                    inspire_image = gr.Image(
                        label="Reference image",
                        type="pil",
                        image_mode="RGB",
                    )
                generate_group = gr.Group(visible=False)
                with generate_group:
                    generate_prompt = gr.Textbox(
                        label="Short prompt",
                        placeholder="e.g., cyberpunk city at sunrise",
                        lines=3,
                    )
                refine_group = gr.Group(visible=False)
                with refine_group:
                    refine_input = gr.TextArea(
                        label="Existing structured prompt",
                        placeholder="Paste the current structured prompt here.",
                        lines=12,
                    )
                    refine_edits = gr.TextArea(
                        label="Editing instructions",
                        placeholder="Describe the changes you want. One instruction per line works well.",
                        lines=6,
                    )
                with gr.Accordion("additional settings", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.2,
                        value=DEFAULT_SAMPLING.temperature,
                        step=0.05,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_SAMPLING.top_p,
                        step=0.05,
                        label="Top-p",
                    )
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=4096,
                        value=DEFAULT_SAMPLING.max_tokens,
                        step=64,
                        label="Max tokens",
                    )
            with gr.Column(scale=1, min_width=320):
                create_button = gr.Button("Create JSON prompt", variant="primary")
                generate_button = gr.Button("Generate image", variant="secondary", visible=False)
                full_pipeline_button = gr.Button("Run full pipeline")
                clear_button = gr.Button("Clear inputs")
                with gr.Accordion("image generation settings", open=False):
                    resolution = gr.Dropdown(
                        choices=RESOLUTIONS_WH,
                        value=DEFAULT_RESOLUTION,
                        label="Resolution (W H)",
                    )
                    steps = gr.Slider(
                        minimum=10,
                        maximum=150,
                        step=1,
                        value=DEFAULT_STEPS,
                        label="Steps",
                    )
                    guidance = gr.Slider(
                        minimum=0.1,
                        maximum=20.0,
                        step=0.1,
                        value=DEFAULT_GUIDANCE_SCALE,
                        label="Guidance scale",
                    )
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        precision=0,
                        label="Seed (-1 for random)",
                    )
                    negative_prompt = gr.TextArea(
                        label="Negative prompt (JSON)",
                        placeholder='Optional JSON string, e.g. ""',
                        lines=4,
                        value=DEFAULT_NEGATIVE_PROMPT,
                    )
                output = gr.TextArea(
                    label="Generated JSON prompt",
                    lines=18,
                    interactive=False,
                )
                latency = gr.Markdown("")
                pipeline_status = gr.Markdown("")
                result_image = gr.Image(label="Generated image", type="pil")

        prompt_state = gr.State()

        task.change(
            fn=_toggle_visibility,
            inputs=task,
            outputs=[inspire_group, generate_group, refine_group],
        )
        clear_button.click(
            fn=_clear_inputs,
            inputs=[],
            outputs=[
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
        )
        create_button.click(
            fn=create_json_prompt,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        generate_button.click(
            fn=generate_image_from_state,
            inputs=[
                prompt_state,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                pipeline_status,
                result_image,
            ],
        )
        full_pipeline_button.click(
            fn=run_full_pipeline,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        gr.Examples(
            label="Usage Examples",
            examples=USAGE_EXAMPLES,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
            fn=run_full_pipeline,
        )
    return demo


logging.basicConfig(level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO))


if __name__ == "__main__":
    demo = build_demo()
    demo.queue().launch()