import torch
import json
import base64
import io
from PIL import Image
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLInpaintPipeline

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("Need to run on GPU")


class EndpointHandler:
    def __init__(self, path="mrcuddle/URPM-Inpaint-Hyper-SDXL"):
        """Load the SDXL Inpainting model."""
        self.pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
            path,
            torch_dtype=torch.float16
        )
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config
        )
        self.pipeline = self.pipeline.to(device)

    def __call__(self, data: dict):
        """Custom call function for Hugging Face Inference Endpoints."""
        try:
            # Extract inputs from JSON payload
            inputs = data.get("inputs", "")
            encoded_image = data.get("image", None)
            encoded_mask_image = data.get("mask_image", None)

            # Extract optional parameters with default values
            num_inference_steps = data.get("num_inference_steps", 25)
            guidance_scale = data.get("guidance_scale", 7.5)
            negative_prompt = data.get("negative_prompt", None)
            height = data.get("height", None)
            width = data.get("width", None)

            # Ensure both images are provided
            if not encoded_image or not encoded_mask_image:
                raise ValueError("Both 'image' and 'mask_image' are required in base64 format.")

            # Decode base64 images
            image = self.decode_base64_image(encoded_image)
            mask_image = self.decode_base64_image(encoded_mask_image)

            print("\n--- Running Inference ---")
            print(f"Prompt: {inputs}")
            print(f"Steps: {num_inference_steps}, Guidance Scale: {guidance_scale}")
            print(f"Negative Prompt: {negative_prompt}")
            print(f"Image Size: {image.size}, Mask Size: {mask_image.size}")

            # Run inference
            output_image = self.pipeline(
                prompt=inputs,
                image=image,
                mask_image=mask_image,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                negative_prompt=negative_prompt,
                height=height,
                width=width
            ).images[0]

            # Return base64-encoded image
            return json.dumps({"output": self.encode_base64_image(output_image)})
        except Exception as e:
            return json.dumps({"error": str(e)})

    def decode_base64_image(self, image_string):
        """Decode base64-encoded image to a PIL Image."""
        try:
            base64_image = base64.b64decode(image_string)
            buffer = io.BytesIO(base64_image)
            return Image.open(buffer).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to decode base64 image: {e}")

    def encode_base64_image(self, image):
        """Encode PIL image to base64."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


# Create an instance of EndpointHandler
handler = EndpointHandler()


def handle(data: dict):
    return handler(data)
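

# --- Illustrative local smoke test (an assumption, not part of the endpoint contract) ---
# A minimal sketch of how a client payload for this handler can be built: a solid-colour
# test image and an all-white mask are generated with PIL, base64-encoded as PNG, and
# passed to handle(). The prompt, image sizes, and output filename below are made up for
# illustration. Running this actually loads the SDXL pipeline above, so it needs a CUDA
# GPU and the model weights.
if __name__ == "__main__":
    def _to_base64(img: Image.Image) -> str:
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    # Hypothetical test inputs: a grey 1024x1024 image and a full-white mask
    # (the white region is the area the pipeline repaints).
    test_image = Image.new("RGB", (1024, 1024), color=(128, 128, 128))
    test_mask = Image.new("RGB", (1024, 1024), color=(255, 255, 255))

    payload = {
        "inputs": "a red apple on a wooden table",
        "image": _to_base64(test_image),
        "mask_image": _to_base64(test_mask),
        "num_inference_steps": 25,
        "guidance_scale": 7.5,
    }

    # handle() returns a JSON string; decode it and inspect the result.
    result = json.loads(handle(payload))
    if "error" in result:
        print(f"Inference failed: {result['error']}")
    else:
        output = Image.open(io.BytesIO(base64.b64decode(result["output"])))
        output.save("inpaint_test_output.png")
        print("Saved inpaint_test_output.png")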