"""Gradio demo that classifies an uploaded image with Marqo's NSFW detector."""

import gradio as gr
import torch
import timm
from PIL import Image

# Download the pretrained classifier from the Hugging Face Hub and switch it
# to evaluation mode for inference.
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384", pretrained=True
).eval()

# Build the preprocessing pipeline from the model's own data config so inputs
# match what the model was configured for (evaluation, not training, transforms).
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Human-readable label for each output index of the classifier head.
class_names = model.pretrained_cfg["label_names"]


@torch.inference_mode()
def predict(image: Image.Image) -> tuple[str, dict[str, float]]:
    """Classify *image* and report the top label and all class probabilities.

    Args:
        image: The uploaded picture, as delivered by ``gr.Image(type="pil")``.
            Gradio passes ``None`` when the user submits without an image.

    Returns:
        A ``(top_label, probabilities)`` pair, where ``probabilities`` maps
        every class name to its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a transform crash.
        raise gr.Error("Please upload an image first.")

    # The transforms expect 3-channel input; RGBA, grayscale ("L") or palette
    # ("P") uploads would otherwise yield a tensor with the wrong channel count.
    tensor = transforms(image.convert("RGB")).unsqueeze(0)  # add batch dim
    probs = model(tensor).softmax(dim=-1).cpu().flatten()

    top_label = class_names[int(probs.argmax())]
    probs_dict = {class_names[i]: float(p) for i, p in enumerate(probs)}
    return top_label, probs_dict


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Top prediction"),
        gr.Label(label="All probabilities", num_top_classes=len(class_names)),
    ],
    title="NSFW Image Detection",
    description="Drag & drop an image to see the predicted class",
)

if __name__ == "__main__":
    # `show_error=True` displays exceptions (e.g. the gr.Error above) in the UI.
    demo.launch(show_error=True)