"""Gradio demo: draw a sketch and classify it with a pretrained CNN."""
import json

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T

from CNN import CNN

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
n_classes = 345
params = {
    'n_filters': 30,
    'hidden_dim': 100,
    'n_layers': 2,
    'n_classes': n_classes,
}

model = CNN(**params)
# Load the checkpoint onto the CPU so the app also runs on machines without a GPU.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()

# ---------------------------------------------------------------------------
# Class labels
# ---------------------------------------------------------------------------
labels_path = 'labels.json'
with open(labels_path, 'r', encoding='utf-8') as f:
    names = json.load(f)

# predict() looks up names[i] for every model output index; fail at startup with a
# clear message instead of an IndexError on the first drawing.
if len(names) < n_classes:
    raise ValueError(
        f"{labels_path} has {len(names)} labels but the model outputs {n_classes} classes"
    )

# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
transform = T.Compose([
    T.ToTensor(),  # (1, H, W), values in [0, 1], white=1 black=0
    T.Lambda(lambda x: 1.0 - x),  # invert -> white=0, black=1
    # antialias=True low-pass filters before shrinking; without it, thin strokes on
    # a large canvas can be partly skipped when downsampling to 28x28.
    T.Resize((28, 28), interpolation=T.InterpolationMode.BILINEAR, antialias=True),
    # T.Normalize((0.5,), (0.5,))  # optional if your model expects [-1, 1]
])


def predict(input_image):
    """Classify the current sketch.

    Args:
        input_image: the value passed by the Sketchpad component. Only its
            'composite' entry (the flattened drawing) is read. It may be None,
            e.g. after the canvas is cleared in live mode (presumably; confirm
            against the installed Gradio version).

    Returns:
        A dict mapping each label to its softmax probability (gr.Label shows the
        top 5), or {"No drawing detected": 1.0} when there is nothing to classify.
    """
    img = input_image.get('composite') if input_image is not None else None
    if img is None:
        return {"No drawing detected": 1.0}

    img = transform(img)
    # A uniform canvas has no strokes; classifying it would only produce noise.
    if img.max() == img.min():
        return {"No drawing detected": 1.0}

    img = img.unsqueeze(0).to(torch.float32)  # add batch dimension -> (1, 1, 28, 28)
    with torch.no_grad():
        out = model(img)
    probs = F.softmax(out, dim=1).squeeze(0)
    return {names[i]: proba for i, proba in enumerate(probs.tolist())}


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        label="Draw a sketch",
        image_mode='L',
        brush=gr.Brush(
            default_size=15,
            default_color='black',
            colors=['black'],
            color_mode='fixed',
        ),
    ),
    outputs=gr.Label(num_top_classes=5),
    title="Sketch Recognition model",
    clear_btn=gr.ClearButton(),
    live=True,
)

if __name__ == "__main__":
    demo.launch()