FIBO_WEBAPP / app.py
ember9327's picture
Update app.py
b26bb6b verified
#!/usr/bin/env python
"""Gradio demo for the GAIA prompt and image generation pipeline."""
from __future__ import annotations
import functools
import gc
import json
import logging
import os
import textwrap
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import gradio as gr
import torch
from PIL import Image
from src.gaia_inference.inference import create_pipeline
from src.gaia_inference.inference import run as run_pipeline
from src.gaia_inference.json_to_prompt import (
DEFAULT_SAMPLING,
SUPPORTED_TASKS,
get_json_prompt,
load_engine,
)
LOGGER = logging.getLogger(__name__)

# Task radio helpers. SUPPORTED_TASKS maps internal task key -> display label;
# the UI works with labels, so keep the reverse lookup too.
TASK_CHOICES = [*SUPPORTED_TASKS.values()]
TASK_LABEL_TO_KEY = dict(zip(SUPPORTED_TASKS.values(), SUPPORTED_TASKS))
DEFAULT_TASK_LABEL = SUPPORTED_TASKS["inspire"]

# Model checkpoints used for prompt creation (VLM) and image generation.
DEFAULT_VLM_MODEL = "briaai/vlm-processor"
DEFAULT_PIPELINE_NAME = "briaai/GAIA-Alpha"

# Image-generation defaults shown in the UI.
DEFAULT_RESOLUTION = "1024 1024"
DEFAULT_GUIDANCE_SCALE = 5.0
DEFAULT_STEPS = 40
DEFAULT_SEED = -1  # the seed widget labels -1 as "random"
DEFAULT_NEGATIVE_PROMPT = ""
# Width/height presets offered by the resolution dropdown, as "W H" strings.
RESOLUTIONS_WH = [
    f"{width} {height}"
    for width, height in (
        (832, 1248),
        (896, 1152),
        (960, 1088),
        (1024, 1024),
        (1088, 960),
        (1152, 896),
        (1216, 832),
        (1280, 800),
        (1344, 768),
    )
]
ROOT_DIR = Path(__file__).resolve().parents[2]
ASSETS_DIR = ROOT_DIR / "assets"
DEFAULT_PROMPT_PATH = ROOT_DIR / "default_json_caption.json"
try:
REFINED_PROMPT_EXAMPLE = DEFAULT_PROMPT_PATH.read_text()
except FileNotFoundError:
REFINED_PROMPT_EXAMPLE = ""
def _usage_example(
    task_label: str,
    image: Optional[str] = None,
    short_prompt: str = "",
    structured_prompt: str = "",
    edit_instructions: str = "",
) -> List[Any]:
    """Build one ``gr.Examples`` row with the shared default settings.

    The column order matches the ``inputs`` list given to ``gr.Examples`` in
    ``build_demo``: task, reference image, short prompt, structured prompt,
    editing instructions, temperature, top-p, max tokens, resolution, steps,
    guidance scale, seed and negative prompt.
    """
    return [
        task_label,
        image,
        short_prompt,
        structured_prompt,
        edit_instructions,
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,  # fixed seed for the examples
        DEFAULT_NEGATIVE_PROMPT,
    ]


USAGE_EXAMPLES = [
    _usage_example(SUPPORTED_TASKS["generate"], short_prompt="a dog playing in the park"),
    _usage_example(
        SUPPORTED_TASKS["inspire"],
        image=str((ASSETS_DIR / "zebra_balloons.jpeg").resolve()),
    ),
    _usage_example(
        SUPPORTED_TASKS["refine"],
        structured_prompt=REFINED_PROMPT_EXAMPLE,
        edit_instructions="change the zebra to an elephant",
    ),
]
def _current_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
# def get_engine(model_name: str = DEFAULT_VLM_MODEL):
@functools.lru_cache(maxsize=2)
def _load_pipeline(pipeline_name: str, device: str):
return create_pipeline(pipeline_name=pipeline_name, device=device)
def get_pipeline(pipeline_name: str = DEFAULT_PIPELINE_NAME):
    """Return the cached CUDA image pipeline for ``pipeline_name``.

    Raises:
        RuntimeError: when no CUDA device is available.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        return _load_pipeline(pipeline_name, "cuda")
    raise RuntimeError("CUDA is required for image generation.")
def _format_prompt_text(raw_prompt: str) -> Tuple[str, Dict[str, Any]]:
try:
prompt_dict = json.loads(raw_prompt)
except json.JSONDecodeError as exc:
LOGGER.exception("Model returned invalid JSON prompt.")
raise gr.Error("The VLM returned invalid JSON. Please try again.") from exc
formatted = json.dumps(prompt_dict, indent=2)
return formatted, prompt_dict
def _ensure_task_key(task_value: str) -> str:
    """Normalise a task selection to its internal key.

    Accepts either an internal key (a key of ``SUPPORTED_TASKS``) or its
    display label.

    Raises:
        gr.Error: when the value is neither a known key nor a known label.
    """
    if task_value not in SUPPORTED_TASKS:
        resolved = TASK_LABEL_TO_KEY.get(task_value)
        if resolved is None:
            options = ", ".join(TASK_CHOICES)
            raise gr.Error(f"Unsupported task selection '{task_value}'. Valid options: {options}.")
        return resolved
    return task_value
@torch.inference_mode()
def _generate_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    model_name: str = DEFAULT_VLM_MODEL,
) -> Tuple[str, str, Dict[str, Any]]:
    """Run the VLM to turn the task inputs into a structured JSON prompt.

    Args:
        task: Task key or display label ("inspire", "generate" or "refine").
        image_value: Reference image (used by the inspire task).
        generate_value: Short text prompt (used by the generate task).
        refine_prompt: Existing structured prompt (used by the refine task).
        refine_instruction: Editing instructions (used by the refine task).
        temperature_value, top_p_value, max_tokens_value: Sampling settings
            forwarded to ``get_json_prompt``.
        model_name: VLM checkpoint passed to ``load_engine``.

    Returns:
        ``(formatted_prompt, latency_report, prompt_dict)``.

    Raises:
        gr.Error: for an unknown task, a ``ValueError`` from the prompt
        builder (message passed through), any other failure (logged), a
        missing generation result, or invalid JSON output.
    """
    task_key = _ensure_task_key(task)
    engine = load_engine(model_name=model_name)
    # Only touch CUDA when it exists. Unconditional .to("cuda") and
    # torch.cuda.synchronize() crash on CPU-only hosts. On such hosts the
    # model stays wherever load_engine placed it.
    use_cuda = _current_device() == "cuda"
    generation = None
    try:
        if use_cuda:
            # Inside the try so a failed device move (e.g. OOM) is logged and
            # reported as a gr.Error like any other prompt-creation failure.
            engine.model.to("cuda")
        generation = get_json_prompt(
            task=task_key,
            engine=engine,
            image=image_value,
            prompt=generate_value,
            structured_prompt=refine_prompt,
            editing_instructions=refine_instruction,
            temperature=float(temperature_value),
            top_p=float(top_p_value),
            max_tokens=int(max_tokens_value),
        )
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        LOGGER.exception("Unexpected error while creating JSON prompt.")
        raise gr.Error("Failed to create a JSON prompt. Check the logs for details.") from exc
    finally:
        # Drops only this function's reference. Whether GPU memory is really
        # released depends on whether load_engine caches the engine (not
        # visible here).
        del engine
        gc.collect()
        if use_cuda:
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
    if generation is None:
        raise gr.Error("Failed to create a JSON prompt.")
    formatted_prompt, prompt_dict = _format_prompt_text(generation.prompt)
    latency_report = generation.latency_report()
    return formatted_prompt, latency_report, prompt_dict
def _parse_resolution(raw_value: str) -> Tuple[int, int]:
normalised = raw_value.replace(",", " ").replace("x", " ")
parts = [part for part in normalised.split() if part]
if len(parts) != 2:
raise gr.Error("Resolution must contain exactly two integers, e.g. '1024 1024'.")
try:
width, height = (int(parts[0]), int(parts[1]))
except ValueError as exc:
raise gr.Error("Resolution values must be integers.") from exc
if width <= 0 or height <= 0:
raise gr.Error("Resolution values must be positive.")
return width, height
def _prepare_negative_prompt(raw_value: Optional[str]):
text = (raw_value or "").strip()
if not text:
return ""
try:
return json.loads(text)
except json.JSONDecodeError:
return text
def _run_image_generation(
    prompt_data: Dict[str, Any],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
    pipeline_name: str = DEFAULT_PIPELINE_NAME,
) -> Tuple[str, Image.Image]:
    """Render an image from a structured JSON prompt.

    Returns:
        ``(status, image)``, where ``status`` reports the wall-clock time of
        the ``run_pipeline`` call and the resolution used.

    Raises:
        gr.Error: without CUDA, for a malformed resolution, when the pipeline
        cannot be obtained, or when generation fails (logged).
    """
    if not torch.cuda.is_available():
        raise gr.Error("CUDA is required for image generation.")
    width, height = _parse_resolution(resolution_value)
    negative_payload = _prepare_negative_prompt(negative_prompt_value)
    if seed_value is None:
        seed = DEFAULT_SEED
    else:
        seed = int(seed_value)

    try:
        pipeline = get_pipeline(pipeline_name=pipeline_name)
    except RuntimeError as exc:
        raise gr.Error(str(exc)) from exc

    # Timing starts after the pipeline is obtained, so loading is excluded.
    started_at = time.perf_counter()
    try:
        image = run_pipeline(
            pipeline=pipeline,
            json_prompt=prompt_data,
            negative_prompt=negative_payload,
            width=width,
            height=height,
            seed=seed,
            num_steps=int(steps_value),
            guidance_scale=float(guidance_value),
        )
    except Exception as exc:
        LOGGER.exception("Failed to generate image.")
        raise gr.Error("Image generation failed. Check the logs for details.") from exc
    duration = time.perf_counter() - started_at
    return f"Image generation time: {duration:.2f}s at {width}x{height}", image
def _toggle_visibility(task_name: str):
    """Show only the input group belonging to the selected task.

    Returns one ``gr.update`` per group in the order (inspire, generate,
    refine), matching the outputs wired to ``task.change``.
    """
    active_key = _ensure_task_key(task_name)
    return [
        gr.update(visible=active_key == group_key)
        for group_key in ("inspire", "generate", "refine")
    ]
def _clear_inputs():
    """Return reset values for every component wired to the "Clear inputs" button.

    The 18 values are grouped in the same order as that button's outputs.
    """
    # Reference image, short prompt, structured prompt, editing instructions.
    task_inputs = (None, "", "", "")
    # VLM sampling settings.
    sampling = (
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
    )
    # JSON output, latency, prompt state, status, result image, and the
    # "Generate image" button (hidden again).
    results = ("", "", None, "", None, gr.update(visible=False))
    # Image-generation settings.
    generation_settings = (
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        DEFAULT_SEED,
        DEFAULT_NEGATIVE_PROMPT,
    )
    return task_inputs + sampling + results + generation_settings
@torch.inference_mode()
def create_json_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
):
    """Handler for "Create JSON prompt".

    Returns the formatted prompt, latency report and parsed prompt (for the
    state). It also clears the previous status and image and reveals the
    "Generate image" button.
    """
    prompt_text, latency_text, prompt_payload = _generate_prompt(
        task=task,
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    cleared_status = ""
    cleared_image = None
    return (
        prompt_text,
        latency_text,
        prompt_payload,
        cleared_status,
        cleared_image,
        gr.update(visible=True),
    )
def generate_image_from_state(
    prompt_state: Optional[Dict[str, Any]],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    """Handler for "Generate image": render from the stored JSON prompt.

    Raises:
        gr.Error: when no prompt has been created yet (empty or ``None`` state).
    """
    if prompt_state:
        return _run_image_generation(
            prompt_data=prompt_state,
            resolution_value=resolution_value,
            steps_value=steps_value,
            guidance_value=guidance_value,
            seed_value=seed_value,
            negative_prompt_value=negative_prompt_value,
        )
    raise gr.Error("Create a JSON prompt first.")
def run_full_pipeline(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    """Create the JSON prompt and immediately render an image from it.

    Returns the same six values as ``create_json_prompt``, but with the
    generation status and image filled in.
    """
    prompt_kwargs = dict(
        task=_ensure_task_key(task),
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    formatted_prompt, latency_report, prompt_dict = _generate_prompt(**prompt_kwargs)

    generation_kwargs = dict(
        prompt_data=prompt_dict,
        resolution_value=resolution_value,
        steps_value=steps_value,
        guidance_value=guidance_value,
        seed_value=seed_value,
        negative_prompt_value=negative_prompt_value,
    )
    status, image = _run_image_generation(**generation_kwargs)

    return (
        formatted_prompt,
        latency_report,
        prompt_dict,
        status,
        image,
        gr.update(visible=True),
    )
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Layout, top to bottom:
    - a markdown header;
    - two hero images from ``ASSETS_DIR``;
    - the task selector;
    - a two-column row: task inputs and sampling settings on the left;
      action buttons, image-generation settings and outputs on the right;
    - a ``gr.Examples`` gallery.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Styling for the hero image row; passed to gr.Blocks via ``css=``.
    hero_css = textwrap.dedent(
        """
        .hero-row {
        justify-content: center;
        gap: 0.5rem;
        }
        .hero-item {
        align-items: center;
        display: flex;
        flex-direction: column;
        gap: 0.25rem;
        }
        .hero-item .gr-image {
        max-width: 512px;
        }
        .hero-image img {
        height: 512px !important;
        width: 512px !important;
        object-fit: cover;
        }
        .hero-caption {
        text-align: center;
        width: 100%;
        margin: 0;
        }
        """
    )
    with gr.Blocks(title="GAIA Inference Demo", css=hero_css) as demo:
        # Page header.
        # NOTE(review): the API link is still a "[TODO](TODO)" placeholder and
        # is rendered to users as-is.
        hero_markdown = textwrap.dedent(
            """
            # GAIA Prompt & Image Generation
            by [Bria.AI](https://bria.ai)
            To access via API: [TODO](TODO).
            Choose a mode to craft a structured JSON prompt and optionally render an image.
            """
        )
        gr.Markdown(hero_markdown)
        # Showcase images: each is shown non-interactive at 512x512 with a
        # caption underneath.
        hero_images = [
            (ASSETS_DIR / "zebra_balloons.jpeg", "Zebra with balloons"),
            (ASSETS_DIR / "face_portrait.jpeg", "Face portrait"),
        ]
        with gr.Row(equal_height=True, elem_classes=["hero-row"]):
            for image_path, caption in hero_images:
                with gr.Column(scale=0, min_width=512, elem_classes=["hero-item"]):
                    gr.Image(
                        value=str(image_path),
                        type="filepath",
                        show_label=False,
                        interactive=False,
                        elem_classes=["hero-image"],
                        height=512,
                        width=512,
                    )
                    gr.Markdown(caption, elem_classes=["hero-caption"])
        # Task selector. Its value is a display label from TASK_CHOICES, and
        # changing it toggles which input group below is visible.
        task = gr.Radio(
            choices=TASK_CHOICES,
            label="Task",
            value=DEFAULT_TASK_LABEL,
            interactive=True,
            info="Choose what you want the model to do.",
        )
        with gr.Row():
            # Left column: per-task inputs plus VLM sampling settings.
            with gr.Column(scale=1, min_width=320):
                # Only one of the three groups is visible at a time (see
                # _toggle_visibility). The inspire group starts visible to
                # match the default task.
                inspire_group = gr.Group(visible=True)
                with inspire_group:
                    inspire_image = gr.Image(
                        label="Reference image",
                        type="pil",
                        image_mode="RGB",
                    )
                generate_group = gr.Group(visible=False)
                with generate_group:
                    generate_prompt = gr.Textbox(
                        label="Short prompt",
                        placeholder="e.g., cyberpunk city at sunrise",
                        lines=3,
                    )
                refine_group = gr.Group(visible=False)
                with refine_group:
                    refine_input = gr.TextArea(
                        label="Existing structured prompt",
                        placeholder="Paste the current structured prompt here.",
                        lines=12,
                    )
                    refine_edits = gr.TextArea(
                        label="Editing instructions",
                        placeholder="Describe the changes you want. One instruction per line works well.",
                        lines=6,
                    )
                # Sampling parameters, forwarded to get_json_prompt.
                with gr.Accordion("additional settings", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.2,
                        value=DEFAULT_SAMPLING.temperature,
                        step=0.05,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_SAMPLING.top_p,
                        step=0.05,
                        label="Top-p",
                    )
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=4096,
                        value=DEFAULT_SAMPLING.max_tokens,
                        step=64,
                        label="Max tokens",
                    )
            # Right column: actions, image-generation settings and outputs.
            with gr.Column(scale=1, min_width=320):
                create_button = gr.Button("Create JSON prompt", variant="primary")
                # Hidden until a JSON prompt exists. The create and
                # full-pipeline handlers reveal it; "Clear inputs" hides it again.
                generate_button = gr.Button("Generate image", variant="secondary", visible=False)
                full_pipeline_button = gr.Button("Run full pipeline")
                clear_button = gr.Button("Clear inputs")
                with gr.Accordion("image generation settings", open=False):
                    resolution = gr.Dropdown(
                        choices=RESOLUTIONS_WH,
                        value=DEFAULT_RESOLUTION,
                        label="Resolution (W H)",
                    )
                    steps = gr.Slider(
                        minimum=10,
                        maximum=150,
                        step=1,
                        value=DEFAULT_STEPS,
                        label="Steps",
                    )
                    guidance = gr.Slider(
                        minimum=0.1,
                        maximum=20.0,
                        step=0.1,
                        value=DEFAULT_GUIDANCE_SCALE,
                        label="Guidance scale",
                    )
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        precision=0,
                        label="Seed (-1 for random)",
                    )
                    # Decoded as JSON when it parses, otherwise sent as plain
                    # text (see _prepare_negative_prompt).
                    negative_prompt = gr.TextArea(
                        label="Negative prompt (JSON)",
                        placeholder='Optional JSON string, e.g. ""',
                        lines=4,
                        value=DEFAULT_NEGATIVE_PROMPT,
                    )
                output = gr.TextArea(
                    label="Generated JSON prompt",
                    lines=18,
                    interactive=False,
                )
                latency = gr.Markdown("")
                pipeline_status = gr.Markdown("")
                result_image = gr.Image(label="Generated image", type="pil")
        # Parsed prompt from the last successful prompt creation; read by the
        # "Generate image" handler.
        prompt_state = gr.State()

        # ---- Event wiring ----
        task.change(
            fn=_toggle_visibility,
            inputs=task,
            outputs=[inspire_group, generate_group, refine_group],
        )
        # The output order must match the 18 values returned by _clear_inputs.
        clear_button.click(
            fn=_clear_inputs,
            inputs=[],
            outputs=[
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
        )
        create_button.click(
            fn=create_json_prompt,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        generate_button.click(
            fn=generate_image_from_state,
            inputs=[
                prompt_state,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                pipeline_status,
                result_image,
            ],
        )
        full_pipeline_button.click(
            fn=run_full_pipeline,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        # Each USAGE_EXAMPLES row supplies these 13 inputs in this order.
        # NOTE(review): fn/outputs let Gradio run run_full_pipeline for an
        # example. Whether that happens on click or via caching depends on
        # Gradio's cache_examples / run_on_click defaults, which are not set here.
        gr.Examples(
            label="Usage Examples",
            examples=USAGE_EXAMPLES,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
            fn=run_full_pipeline,
        )
    return demo
logging.basicConfig(level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO))
if __name__ == "__main__":
demo = build_demo()
demo.queue().launch()