enpaiva commited on
Commit
b80b7af
·
verified ·
1 Parent(s): ed3111a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -30,26 +30,40 @@ id_to_names = {
30
  }
31
 
32
  def recognize_image(input_img, conf_threshold, iou_threshold):
33
- det_res = model.predict(
34
- input_img,
35
- imgsz=1024,
36
- conf=conf_threshold,
37
- device=device,
38
- )[0]
39
- boxes = det_res.__dict__['boxes'].xyxy
40
- classes = det_res.__dict__['boxes'].cls
41
- scores = det_res.__dict__['boxes'].conf
 
42
 
43
- indices = torchvision.ops.nms(boxes=torch.Tensor(boxes), scores=torch.Tensor(scores),iou_threshold=iou_threshold)
44
- boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
45
- if len(boxes.shape) == 1:
46
- boxes = np.expand_dims(boxes, 0)
47
- scores = np.expand_dims(scores, 0)
48
- classes = np.expand_dims(classes, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- vis_result = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
51
- return vis_result
52
-
53
  def gradio_reset():
54
  return gr.update(value=None), gr.update(value=None)
55
 
 
30
  }
31
 
32
  def recognize_image(input_img, conf_threshold, iou_threshold):
33
+ try:
34
+ det_res = model.predict(
35
+ input_img,
36
+ imgsz=1024,
37
+ conf=conf_threshold,
38
+ device=device,
39
+ )[0]
40
+ boxes = det_res.__dict__['boxes'].xyxy
41
+ classes = det_res.__dict__['boxes'].cls
42
+ scores = det_res.__dict__['boxes'].conf
43
 
44
+ indices = torchvision.ops.nms(
45
+ boxes=torch.Tensor(boxes),
46
+ scores=torch.Tensor(scores),
47
+ iou_threshold=iou_threshold
48
+ )
49
+
50
+ boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
51
+
52
+ if len(boxes.shape) == 1:
53
+ boxes = np.expand_dims(boxes, 0)
54
+ scores = np.expand_dims(scores, 0)
55
+ classes = np.expand_dims(classes, 0)
56
+
57
+ output = visualize_bbox(input_img, boxes, classes, scores, id_to_names)
58
+ if not isinstance(output, (np.ndarray, Image.Image)):
59
+ raise ValueError("Output is not a valid image")
60
+
61
+ return output
62
+ except Exception as e:
63
+ print(f"[ERROR] recognize_image failed: {e}")
64
+ # Return blank image or raise if debugging
65
+ return np.zeros((512, 512, 3), dtype=np.uint8)
66
 
 
 
 
67
  def gradio_reset():
68
  return gr.update(value=None), gr.update(value=None)
69