import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from collections import Counter
import math
from gradio import processing_utils
from typing import Optional
import warnings
from datetime import datetime

import torch
from diffusers import StableDiffusionInpaintPipeline
from accelerate.utils import set_seed

clevr_all_objects = [
    'blue metal cube',
    'blue metal cylinder',
    'blue metal sphere',
    'blue rubber cube',
    'blue rubber cylinder',
    'blue rubber sphere',
    'brown metal cube',
    'brown metal cylinder',
    'brown metal sphere',
    'brown rubber cube',
    'brown rubber cylinder',
    'brown rubber sphere',
    'cyan metal cube',
    'cyan metal cylinder',
    'cyan metal sphere',
    'cyan rubber cube',
    'cyan rubber cylinder',
    'cyan rubber sphere',
    'gray metal cube',
    'gray metal cylinder',
    'gray metal sphere',
    'gray rubber cube',
    'gray rubber cylinder',
    'gray rubber sphere',
    'green metal cube',
    'green metal cylinder',
    'green metal sphere',
    'green rubber cube',
    'green rubber cylinder',
    'green rubber sphere',
    'purple metal cube',
    'purple metal cylinder',
    'purple metal sphere',
    'purple rubber cube',
    'purple rubber cylinder',
    'purple rubber sphere',
    'red metal cube',
    'red metal cylinder',
    'red metal sphere',
    'red rubber cube',
    'red rubber cylinder',
    'red rubber sphere',
    'yellow metal cube',
    'yellow metal cylinder',
    'yellow metal sphere',
    'yellow rubber cube',
    'yellow rubber cylinder',
    'yellow rubber sphere',
]

all_clevr_colors = ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']
all_clevr_materials = ['metal', 'rubber']
all_clevr_shapes = ['cube', 'cylinder', 'sphere']
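
# Sanity check: the 48 names above are exactly the color x material x shape
# product, so the hand-written list can be verified against an equivalent
# programmatic construction (the explicit list is kept for readability).
from itertools import product
assert clevr_all_objects == [
    ' '.join(combo)
    for combo in product(all_clevr_colors, all_clevr_materials, all_clevr_shapes)
]
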
class Instance:
    def __init__(self, capacity=2):
        self.model_type = 'base'
        self.loaded_model_list = {}
        self.counter = Counter()
        self.global_counter = Counter()
        self.capacity = capacity
        self.loaded_model = None
        # Initialize the pipeline slot so get_model() can test it before the
        # first load_model() call.
        self.pipe = None

    def _log(self, model_type, batch_size, instruction, phrase_list):
        self.counter[model_type] += 1
        self.global_counter[model_type] += 1
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
            current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
        ))

    def get_model(self):
        if self.pipe is None:
            self.pipe = self.load_model()
        if torch.cuda.is_available():
            self.pipe.to("cuda")
            print("Loaded model to GPU")
        return self.pipe

    def load_model(self, model_id='j-min/IterInpaint-CLEVR'):
        pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)

        # Disable the safety checker (the model only generates CLEVR scenes).
        def dummy(images, **kwargs):
            return images, False
        pipe.safety_checker = dummy
        print("Disabled safety checker")

        print("Loaded model")

        if torch.cuda.is_available():
            pipe.to("cuda")
            print("Loaded model to GPU")

        # # This command loads the individual model components on GPU on-demand, so we
        # # wouldn't need to explicitly call pipe.to("cuda").
        # pipe.enable_model_cpu_offload()

        # # xformers
        # pipe.enable_xformers_memory_efficient_attention()

        self.pipe = pipe
        # Return the pipeline so get_model() receives it when loading lazily.
        return pipe


instance = Instance()
instance.load_model()
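
# Note: get_model() can also load the pipeline lazily on the first request;
# the eager instance.load_model() call above just warms it at startup.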

from gen_utils import encode_from_custom_annotation, iterinpaint_sample_diffusers

class ImageMask(gr.components.Image):
    """
    Image component configured for sketching: source="upload", tool="sketch"
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        if x is None:
            return x
        if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
            decode_image = processing_utils.decode_base64_to_image(x)
            width, height = decode_image.size
            mask = np.zeros((height, width, 4), dtype=np.uint8)
            mask[..., -1] = 255
            mask = self.postprocess(mask)
            x = {'image': x, 'mask': mask}
        return super().preprocess(x)

class Blocks(gr.Blocks):
    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        self.extra_configs = {
            'thumbnail': kwargs.pop('thumbnail', ''),
            'url': kwargs.pop('url', 'https://gradio.app/'),
            'creator': kwargs.pop('creator', '@teamGradio'),
        }

        super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        config = super(Blocks, self).get_config_file()
        for k, v in self.extra_configs.items():
            config[k] = v
        return config

def draw_box(boxes=[], texts=[], img=None):
    if len(boxes) == 0 and img is None:
        return None

    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
    for bid, box in enumerate(boxes):
        draw.rectangle([box[0], box[1], box[2], box[3]],
                       outline=colors[bid % len(colors)], width=4)
        anno_text = texts[bid]
        draw.rectangle(
            [box[0], box[3] - int(font.size * 1.2),
             box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]],
            outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
        draw.text(
            [box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)],
            anno_text, font=font, fill=(255, 255, 255))
    return img
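
# Minimal draw_box usage sketch (hypothetical coordinates, xyxy pixels on the
# default 512x512 canvas):
# preview = draw_box(boxes=[(50, 60, 200, 230)], texts=['blue metal cube'])
# preview.save('layout_preview.png')
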
def get_concat(ims):
    if len(ims) == 1:
        n_col = 1
    else:
        n_col = 2
    n_row = math.ceil(len(ims) / 2)
    dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
    for i, im in enumerate(ims):
        row_id = i // n_col
        col_id = i % n_col
        dst.paste(im, (im.width * col_id, im.height * row_id))
    return dst
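
# get_concat tiles frames row-major into a 2-column grid (e.g. 5 images ->
# 3 rows x 2 columns), leaving any empty cell as the white background.
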
def inference(language_instruction, grounding_texts, boxes, guidance_scale):
    # custom_annotations = [
    #     {'x': 19,
    #      'y': 61,
    #      'width': 158,
    #      'height': 169,
    #      'label': 'blue metal cube'},
    #     {'x': 183,
    #      'y': 94,
    #      'width': 103,
    #      'height': 109,
    #      'label': 'brown rubber sphere'},
    # ]

    # # boxes - normalized -> unnormalized
    # boxes = np.array(boxes) * 512

    # Convert xyxy boxes into the {'x', 'y', 'width', 'height', 'label'}
    # format expected by encode_from_custom_annotation.
    custom_annotations = []
    for i in range(len(boxes)):
        box = boxes[i]
        custom_annotations.append({'x': box[0],
                                   'y': box[1],
                                   'width': box[2] - box[0],
                                   'height': box[3] - box[1],
                                   'label': grounding_texts[i]})

    # # 1) convert xywh to xyxy
    # # 2) normalize coordinates
    scene = encode_from_custom_annotation(custom_annotations, size=512)
    print(scene['box_captions'])
    print(scene['boxes_normalized'])

    pipe = instance.get_model()
    out = iterinpaint_sample_diffusers(
        pipe, scene, paste=True, verbose=True, size=512, guidance_scale=guidance_scale)

    final_image = out['generated_images'][-1].copy()

    # Create generation GIF
    prompts = out['prompts']
    fps = 4

    def create_gif_source_images(images, prompts):
        """Create source images for the GIF.

        Each frame consists of an intermediate image with its prompt as title.
        Doesn't change the size of the original images.
        """
        step_images = []
        font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
        for i, img in enumerate(images):
            draw = ImageDraw.Draw(img)
            draw.text((0, 0), prompts[i], (255, 255, 255), font=font)
            step_images.append(img)
        return step_images

    import imageio
    step_images = create_gif_source_images(out['generated_images'], prompts)
    print("Number of frames in GIF: {}".format(len(step_images)))

    # create temp path
    import tempfile
    import os
    gif_save_path = os.path.join(tempfile.gettempdir(), 'gen.gif')

    # create gif
    imageio.mimsave(gif_save_path, step_images, fps=fps)
    print('GIF saved to {}'.format(gif_save_path))

    out_images = [
        final_image,
        gif_save_path
    ]
    return out_images
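
# inference() returns [final composited image, path to the step-by-step GIF];
# generate() below unpacks these into the two output panes of the UI.
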
def generate(task, language_instruction, grounding_texts, sketch_pad,
             alpha_sample, guidance_scale, batch_size,
             fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
             state):
    if 'boxes' not in state:
        state['boxes'] = []
    boxes = state['boxes']
    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    # assert len(boxes) == len(grounding_texts)

    # Check that every object query is within clevr_all_objects.
    for grounding_text in grounding_texts:
        if grounding_text not in clevr_all_objects:
            raise ValueError("The grounding object {} is not in the CLEVR dataset.".format(grounding_text))

    if len(boxes) != len(grounding_texts):
        if len(boxes) < len(grounding_texts):
            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))

    # # normalize boxes
    # boxes = (np.asarray(boxes) / 512).tolist()

    print('input boxes: ', boxes)
    print('input grounding_texts: ', grounding_texts)
    print('input language instruction: ', language_instruction)

    if fix_seed:
        set_seed(rand_seed)
        print('seed set to: ', rand_seed)

    gen_image, gen_animation = inference(
        language_instruction, grounding_texts, boxes,
        guidance_scale=guidance_scale,
    )

    # for idx, gen_image in enumerate(gen_images):
    #     if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
    #         hw = min(*state['original_image'].shape[:2])
    #         gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
    #         gen_image = Image.fromarray(gen_image)
    #         gen_images[idx] = gen_image

    # blank_samples = batch_size % 2 if batch_size > 1 else 0
    # gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
    #     + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
    #     + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]

    gen_images = [
        gr.Image.update(value=gen_image, visible=True),
        gr.Image.update(value=gen_animation, visible=True)
    ]

    return gen_images + [state]

def binarize(x):
    return (x != 0).astype('uint8') * 255


def sized_center_crop(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]


def sized_center_fill(img, fill, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    img[starty:starty + cropy, startx:startx + cropx] = fill
    return img


def sized_center_mask(img, cropx, cropy):
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
    img = (img * 0.2).astype('uint8')
    img[starty:starty + cropy, startx:startx + cropx] = center_region
    return img


def center_crop(img, HW=None, tgt_size=(512, 512)):
    if HW is None:
        H, W = img.shape[:2]
        HW = min(H, W)
    img = sized_center_crop(img, HW, HW)
    img = Image.fromarray(img)
    img = img.resize(tgt_size)
    return np.array(img)
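
# All helpers above share the same centered-window convention; e.g. for a
# 600x800 (HxW) image and cropx = cropy = 512, the window starts at
# (startx, starty) = (800//2 - 256, 600//2 - 256) = (144, 44).
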
def draw(task, input, grounding_texts, new_image_trigger, state):
    if type(input) == dict:
        image = input['image']
        mask = input['mask']
    else:
        mask = input
    if mask.ndim == 3:
        mask = mask[..., 0]

    image_scale = 1.0

    # resize trigger
    if task == "Grounded Inpainting":
        mask_cond = mask.sum() == 0
        # size_cond = mask.shape != (512, 512)
        if mask_cond and 'original_image' not in state:
            image = Image.fromarray(image)
            width, height = image.size
            scale = 600 / min(width, height)
            image = image.resize((int(width * scale), int(height * scale)))
            state['original_image'] = np.array(image).copy()
            image_scale = float(height / width)
            return [None, new_image_trigger + 1, image_scale, state]
        else:
            original_image = state['original_image']
            H, W = original_image.shape[:2]
            image_scale = float(H / W)

    mask = binarize(mask)
    if mask.shape != (512, 512):
        # assert False, "should not receive any non-512x512 masks."
        if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
            mask = center_crop(mask, state['inpaint_hw'])
            image = center_crop(state['original_image'], state['inpaint_hw'])
        else:
            mask = np.zeros((512, 512), dtype=np.uint8)
    # mask = center_crop(mask)
    mask = binarize(mask)

    if type(mask) != np.ndarray:
        mask = np.array(mask)

    if mask.sum() == 0 and task != "Grounded Inpainting":
        state = {}

    if task != 'Grounded Inpainting':
        image = None
    else:
        image = Image.fromarray(image)

    if 'boxes' not in state:
        state['boxes'] = []
    if 'masks' not in state or len(state['masks']) == 0:
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if type(mask) == np.ndarray and mask.size > 1:
        diff_mask = mask - last_mask
    else:
        diff_mask = np.zeros([])

    if diff_mask.sum() > 0:
        x1x2 = np.where(diff_mask.max(0) != 0)[0]
        y1y2 = np.where(diff_mask.max(1) != 0)[0]
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()

        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    grounding_texts = [x for x in grounding_texts if len(x) > 0]
    if len(grounding_texts) < len(state['boxes']):
        grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]

    box_image = draw_box(state['boxes'], grounding_texts, image)

    if box_image is not None and state.get('inpaint_hw', None):
        inpaint_hw = state['inpaint_hw']
        box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
        original_image = state['original_image'].copy()
        box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)

    return [box_image, new_image_trigger, image_scale, state]
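
# Box parsing in draw(): each sketch edit is diffed against the last stored
# mask, and the bounding box of the newly painted pixels becomes a new region;
# strokes smaller than 5 px in either dimension are ignored as noise.
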
def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
    if task != 'Grounded Inpainting':
        sketch_pad_trigger = sketch_pad_trigger + 1
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    # out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
    #     + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
    #     + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
    out_images = [gr.Image.update(value=None, visible=True) for i in range(1)] \
        + [gr.Image.update(value=None, visible=True) for _ in range(1)]
    state = {}
    return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]

css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
    height: var(--height) !important;
    max-height: var(--height) !important;
    min-height: var(--height) !important;
}
#paper-info a {
    color: #008AD7;
    text-decoration: none;
}
#paper-info a:hover {
    cursor: pointer;
    text-decoration: none;
}
"""
rescale_js = """
function(x) {
    const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
    let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
    const image_width = root.querySelector('#img2img_image').clientWidth;
    const target_height = parseInt(image_width * image_scale);
    document.body.style.setProperty('--height', `${target_height}px`);
    root.querySelectorAll('button.justify-center.rounded')[0].style.display = 'none';
    root.querySelectorAll('button.justify-center.rounded')[1].style.display = 'none';
    return x;
}
"""
with Blocks(
    # css=css,
    analytics_enabled=False,
    title="IterInpaint demo",
) as main:
    description = """
<p style="text-align: center; font-weight: bold;">
    <span style="font-size: 28px">IterInpaint CLEVR Demo</span>
    <br>
    <span style="font-size: 18px" id="paper-info">
        [<a href="https://layoutbench.github.io" target="_blank">Project Page</a>]
        [<a href="https://arxiv.org/abs/2304.06671" target="_blank">Paper</a>]
        [<a href="https://github.com/j-min/IterInpaint" target="_blank">GitHub</a>]
    </span>
</p>
<span style="font-size: 14px">
    <b>IterInpaint</b> is a new baseline for layout-guided image generation.
    Unlike previous methods that generate all objects in a single step, IterInpaint decomposes the image generation process into multiple steps and uses an inpainting model to update regions step by step.
</span>
<br>
<br>
<span style="font-size: 18px" id="instruction">
    Instructions:
</span>
<p>
(1) ⌨️ Enter the object names in <em>Region Captions</em>.
<br>
Since the model is trained on the <a href="https://cs.stanford.edu/people/jcjohns/clevr/" target="_blank">CLEVR</a> dataset, use object names of the form <b>"[color] [material] [shape]"</b> (e.g., <em>blue metal sphere</em>):
<br>
<ul>
    <li>color: <em><span style="color: red">red</span>, <span style="color: cyan">cyan</span>, <span style="color: green">green</span>, <span style="color: blue">blue</span>, <span style="color: yellow">yellow</span>, <span style="color: purple">purple</span>, <span style="color: brown">brown</span>, <span style="color: gray">gray</span></em></li>
    <li>material: <em>metal, rubber</em></li>
    <li>shape: <em>cylinder, cube, sphere</em></li>
</ul>
(2) 🖱️ Draw the corresponding bounding boxes one by one on the <em>Sketch Pad</em>; the parsed boxes will be displayed automatically.
<br>
For faster inference without waiting in the queue, you may duplicate the space and upgrade to a GPU in the settings. <a href="https://huggingface.co/spaces/j-min/iterinpaint-CLEVR?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
</p>
"""
    gr.HTML(description)
    with gr.Row():
        with gr.Column(scale=4):
            sketch_pad_trigger = gr.Number(value=0, visible=False)
            sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
            init_white_trigger = gr.Number(value=0, visible=False)
            image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
            new_image_trigger = gr.Number(value=0, visible=False)

            # task = gr.Radio(
            #     choices=["Grounded Generation", 'Grounded Inpainting'],
            #     type="value",
            #     value="Grounded Generation",
            #     label="Task",
            # )
            task = gr.State("Grounded Generation")

            # language_instruction = gr.Textbox(
            #     label="Language instruction",
            # )
            language_instruction = gr.State("")

            grounding_instruction = gr.Textbox(
                label="""
                Region Captions (separated by semicolons),
                e.g., "blue metal cube; red rubber cylinder"
                """,
            )
            with gr.Row():
                sketch_pad = ImageMask(label="Draw bounding boxes", elem_id="img2img_image")
                out_imagebox = gr.Image(type="pil", label="Parsed Layout")
            with gr.Row():
                clear_btn = gr.Button(value='Clear')
                gen_btn = gr.Button(value='Generate')
            with gr.Accordion("Advanced Options", open=False):
                with gr.Column():
                    # alpha_sample = gr.Slider(
                    #     minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
                    alpha_sample = gr.State(0.3)

                    guidance_scale = gr.Slider(
                        minimum=0, maximum=50, step=0.5, value=4.0, label="Guidance Scale")

                    # batch_size = gr.Slider(
                    #     minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
                    # batch_size = gr.Slider(
                    #     minimum=1, maximum=1, step=1, value=1, label="Number of Samples")
                    batch_size = gr.State(1)

                    # append_grounding = gr.Checkbox(
                    #     value=True, label="Append grounding instructions to the caption")
                    append_grounding = gr.State(False)

                    # use_actual_mask = gr.Checkbox(
                    #     value=False, label="Use actual mask for inpainting", visible=False)
                    use_actual_mask = gr.State(False)

                    with gr.Row():
                        # fix_seed = gr.Checkbox(value=True, label="Fixed seed")
                        fix_seed = gr.State(True)
                        rand_seed = gr.Slider(
                            minimum=0, maximum=1000, step=1, value=0, label="Seed")
                    with gr.Row():
                        # use_style_cond = gr.Checkbox(
                        #     value=False, label="Enable Style Condition")
                        # style_cond_image = gr.Image(
                        #     type="pil", label="Style Condition", visible=False, interactive=True)
                        use_style_cond = gr.State(False)
                        style_cond_image = gr.State(None)

        with gr.Column(scale=3):
            gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated Image</span>')
            # with gr.Row():
            out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)

            gr.HTML('<span style="font-size: 20px; font-weight: bold">Step-by-Step Animation</span>')
            out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)

            # with gr.Row():
            #     out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
            #     out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)

    state = gr.State({})
    class Controller:
        def __init__(self):
            self.calls = 0
            self.tracks = 0
            self.resizes = 0
            self.scales = 0

        def init_white(self, init_white_trigger):
            self.calls += 1
            return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1

        # def change_n_samples(self, n_samples):
        #     blank_samples = n_samples % 2 if n_samples > 1 else 0
        #     return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
        #         + [gr.Image.update(visible=False)
        #            for _ in range(4 - n_samples - blank_samples)]

        def resize_centercrop(self, state):
            self.resizes += 1
            image = state['original_image'].copy()
            inpaint_hw = int(0.9 * min(*image.shape[:2]))
            state['inpaint_hw'] = inpaint_hw
            image_cc = center_crop(image, inpaint_hw)
            # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
            return image_cc, state

        def resize_masked(self, state):
            self.resizes += 1
            image = state['original_image'].copy()
            inpaint_hw = int(0.9 * min(*image.shape[:2]))
            state['inpaint_hw'] = inpaint_hw
            image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
            state['masked_image'] = image_mask.copy()
            # print(f'mask triggered {self.resizes}')
            return image_mask, state

        def switch_task_hide_cond(self, task):
            cond = False
            if task == "Grounded Generation":
                cond = True
            return (gr.Checkbox.update(visible=cond, value=False),
                    gr.Image.update(value=None, visible=False),
                    gr.Slider.update(visible=cond),
                    gr.Checkbox.update(visible=(not cond), value=False))

    controller = Controller()
    main.load(
        lambda x: x + 1,
        inputs=sketch_pad_trigger,
        outputs=sketch_pad_trigger,
        queue=False)

    sketch_pad.edit(
        draw,
        inputs=[task, sketch_pad, grounding_instruction,
                sketch_pad_resize_trigger, state],
        outputs=[out_imagebox, sketch_pad_resize_trigger,
                 image_scale, state],
        queue=False,
    )
    grounding_instruction.change(
        draw,
        inputs=[task, sketch_pad, grounding_instruction,
                sketch_pad_resize_trigger, state],
        outputs=[out_imagebox, sketch_pad_resize_trigger,
                 image_scale, state],
        queue=False,
    )
    clear_btn.click(
        clear,
        inputs=[task, sketch_pad_trigger, batch_size, state],
        outputs=[sketch_pad, sketch_pad_trigger, out_imagebox,
                 # image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
                 image_scale, out_gen_1, out_gen_2, state],
        queue=False)

    # task.change(
    #     partial(clear, switch_task=True),
    #     inputs=[task, sketch_pad_trigger, batch_size, state],
    #     outputs=[sketch_pad, sketch_pad_trigger, out_imagebox,
    #              image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
    #     queue=False)

    sketch_pad_trigger.change(
        controller.init_white,
        inputs=[init_white_trigger],
        outputs=[sketch_pad, image_scale, init_white_trigger],
        queue=False)
    sketch_pad_resize_trigger.change(
        controller.resize_masked,
        inputs=[state],
        outputs=[sketch_pad, state],
        queue=False)

    # batch_size.change(
    #     controller.change_n_samples,
    #     inputs=[batch_size],
    #     outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
    #     queue=False)

    gen_btn.click(
        generate,
        inputs=[
            task, language_instruction, grounding_instruction, sketch_pad,
            alpha_sample, guidance_scale, batch_size,
            fix_seed, rand_seed,
            use_actual_mask,
            append_grounding, style_cond_image,
            state,
        ],
        # outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
        outputs=[out_gen_1, out_gen_2, state],
        queue=True
    )

    sketch_pad_resize_trigger.change(
        None,
        None,
        sketch_pad_resize_trigger,
        _js=rescale_js,
        queue=False)
    init_white_trigger.change(
        None,
        None,
        init_white_trigger,
        _js=rescale_js,
        queue=False)

    # use_style_cond.change(
    #     lambda cond: gr.Image.update(visible=cond),
    #     use_style_cond,
    #     style_cond_image,
    #     queue=False)
    # task.change(
    #     controller.switch_task_hide_cond,
    #     inputs=task,
    #     outputs=[use_style_cond, style_cond_image,
    #              alpha_sample, use_actual_mask],
    #     queue=False)
    with gr.Column():
        gr.Examples(
            examples=[
                [
                    "images/blank.png",
                    "blue metal cube; red rubber sphere",
                ],
                [
                    "images/blank.png",
                    "green metal cube; red metal sphere; brown rubber cube",
                ],
                [
                    "images/blank.png",
                    "blue metal cube; brown rubber sphere; gray metal sphere; yellow rubber cylinder; gray metal cylinder; cyan rubber sphere; green rubber cube; red metal cylinder",
                ]
            ],
            inputs=[
                sketch_pad,
                grounding_instruction
            ],
            outputs=None,
            fn=None,
            cache_examples=False,
        )

    thank_desc = """
    Thanks to the <a href="https://huggingface.co/spaces/gligen/demo" target="_blank">GLIGEN demo</a> for providing the bounding box parsing module.
    """
    gr.HTML(thank_desc)
main.queue(concurrency_count=1, api_open=False)
main.launch(share=False, show_api=False, show_error=True)

# main.launch(
#     server_name="0.0.0.0",
#     share=True,
#     server_port=7899,
#     show_api=False, show_error=True
# )