# Visualization utilities for 2D human/animal pose estimation
# (keypoint dictionaries, skeleton drawing, debug image grids).
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torchvision | |
| import ffmpeg | |
| __all__ = ["joints_dict", "draw_points_and_skeleton"] | |
def _animal_joints():
    """Return a fresh keypoints/skeleton layout shared by AP-10K and APT-36K.

    Both datasets use the same 17-keypoint animal pose definition; returning a
    new dict per call keeps the two entries independent (mutating one does not
    affect the other).
    """
    return {
        "keypoints": {
            0: "L_Eye",
            1: "R_Eye",
            2: "Nose",
            3: "Neck",
            4: "Root of tail",
            5: "L_Shoulder",
            6: "L_Elbow",
            7: "L_F_Paw",
            8: "R_Shoulder",
            9: "R_Elbow",
            10: "R_F_Paw",
            11: "L_Hip",
            12: "L_Knee",
            13: "L_B_Paw",
            14: "R_Hip",
            15: "R_Knee",
            16: "R_B_Paw"
        },
        "skeleton": [
            [0, 1], [0, 2], [1, 2], [2, 3], [3, 4], [3, 5], [5, 6], [6, 7],
            [3, 8], [8, 9], [9, 10], [4, 11], [11, 12], [12, 13], [4, 14],
            [14, 15], [15, 16]
        ]
    }


def _hand_keypoints(side, offset):
    """Keypoint names for one hand in COCO-WholeBody order.

    Args:
        side: "left" or "right"; prefixed to every joint name.
        offset: index of the hand-root keypoint (91 for left, 112 for right).

    Returns:
        dict mapping keypoint index -> name for the 21 hand keypoints.
    """
    names = {offset: side + "_hand_root"}
    index = offset + 1
    for finger in ("thumb", "forefinger", "middle_finger", "ring_finger", "pinky_finger"):
        for joint in range(1, 5):
            names[index] = "%s_%s%d" % (side, finger, joint)
            index += 1
    return names


def _hand_skeleton(offset):
    """Skeleton edges for one hand: root to each finger base, then along each finger."""
    edges = []
    for finger in range(5):
        first = offset + 1 + 4 * finger
        edges.append([offset, first])
        for j in range(3):
            edges.append([first + j, first + j + 1])
    return edges


def joints_dict():
    """Return keypoint names and skeleton connectivity for the supported datasets.

    Keys: "coco_25", "coco", "mpii", "ap10k", "apt36k", "aic", "wholebody".
    Each entry maps:
        "keypoints": dict of keypoint index -> human-readable name
        "skeleton": list of [point_a, point_b] index pairs (edges to draw)

    Returns:
        dict: a freshly built dictionary on every call (safe to mutate).
    """
    joints = {
        "coco_25": {
            "keypoints": {
                0: "nose",
                1: "left_eye",
                2: "right_eye",
                3: "left_ear",
                4: "right_ear",
                5: "neck",
                6: "left_shoulder",
                7: "right_shoulder",
                8: "left_elbow",
                9: "right_elbow",
                10: "left_wrist",
                11: "right_wrist",
                12: "left_hip",
                13: "right_hip",
                14: "hip",
                15: "left_knee",
                16: "right_knee",
                17: "left_ankle",
                18: "right_ankle",
                # fixed: was "left_big toe" (space), inconsistent with every
                # sibling name ("left_small_toe", "right_big_toe", ...)
                19: "left_big_toe",
                20: "left_small_toe",
                21: "left_heel",
                22: "right_big_toe",
                23: "right_small_toe",
                24: "right_heel",
            },
            "skeleton": [
                [17, 15], [15, 12], [18, 16], [16, 13], [12, 14], [13, 14], [5, 14],
                [6, 5], [7, 5], [6, 8], [7, 9], [8, 10], [9, 11], [1, 2], [0, 1], [0, 2],
                [1, 3], [2, 4], [17, 21], [18, 24], [19, 20], [22, 23], [19, 21],
                [22, 24], [5, 0]
            ]
        },
        "coco": {
            "keypoints": {
                0: "nose",
                1: "left_eye",
                2: "right_eye",
                3: "left_ear",
                4: "right_ear",
                5: "left_shoulder",
                6: "right_shoulder",
                7: "left_elbow",
                8: "right_elbow",
                9: "left_wrist",
                10: "right_wrist",
                11: "left_hip",
                12: "right_hip",
                13: "left_knee",
                14: "right_knee",
                15: "left_ankle",
                16: "right_ankle"
            },
            "skeleton": [
                [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
                [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1],
                [0, 2], [1, 3], [2, 4], [0, 5], [0, 6]
            ]
        },
        "mpii": {
            "keypoints": {
                0: "right_ankle",
                1: "right_knee",
                2: "right_hip",
                3: "left_hip",
                4: "left_knee",
                5: "left_ankle",
                6: "pelvis",
                7: "thorax",
                8: "upper_neck",
                9: "head top",
                10: "right_wrist",
                11: "right_elbow",
                12: "right_shoulder",
                13: "left_shoulder",
                14: "left_elbow",
                15: "left_wrist"
            },
            "skeleton": [
                [5, 4], [4, 3], [0, 1], [1, 2], [3, 2], [3, 6], [2, 6], [6, 7],
                [7, 8], [8, 9], [13, 7], [12, 7], [13, 14], [12, 11], [14, 15],
                [11, 10],
            ]
        },
        # AP-10K and APT-36K share the same animal-pose layout; each gets an
        # independent copy from the helper.
        "ap10k": _animal_joints(),
        "apt36k": _animal_joints(),
        "aic": {
            "keypoints": {
                0: "right_shoulder",
                1: "right_elbow",
                2: "right_wrist",
                3: "left_shoulder",
                4: "left_elbow",
                5: "left_wrist",
                6: "right_hip",
                7: "right_knee",
                8: "right_ankle",
                9: "left_hip",
                10: "left_knee",
                11: "left_ankle",
                12: "head_top",
                13: "neck"
            },
            "skeleton": [
                [2, 1], [1, 0], [0, 13], [13, 3], [3, 4], [4, 5], [8, 7],
                [7, 6], [6, 9], [9, 10], [10, 11], [12, 13], [0, 6], [3, 9]
            ]
        },
        "wholebody": {
            # 133 keypoints: 23 body/feet + 68 face + 21 per hand.
            "keypoints": {
                **{
                    0: "nose",
                    1: "left_eye",
                    2: "right_eye",
                    3: "left_ear",
                    4: "right_ear",
                    5: "left_shoulder",
                    6: "right_shoulder",
                    7: "left_elbow",
                    8: "right_elbow",
                    9: "left_wrist",
                    10: "right_wrist",
                    11: "left_hip",
                    12: "right_hip",
                    13: "left_knee",
                    14: "right_knee",
                    15: "left_ankle",
                    16: "right_ankle",
                    17: "left_big_toe",
                    18: "left_small_toe",
                    19: "left_heel",
                    20: "right_big_toe",
                    21: "right_small_toe",
                    22: "right_heel",
                },
                # 68 face landmarks "face-0" .. "face-67" at indices 23..90
                **{i: "face-%d" % (i - 23) for i in range(23, 91)},
                **_hand_keypoints("left", 91),
                **_hand_keypoints("right", 112),
            },
            "skeleton": [
                [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
                [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2],
                [1, 3], [2, 4], [3, 5], [4, 6], [15, 17], [15, 18], [15, 19],
                [16, 20], [16, 21], [16, 22]
            ] + _hand_skeleton(91) + _hand_skeleton(112)
        }
    }
    return joints
def draw_points(image, points, color_palette='tab20', palette_samples=16, confidence_threshold=0.5):
    """
    Draw every sufficiently confident point of `points` onto `image`.

    Args:
        image: image in opencv format
        points: list of points to be drawn.
            Shape: (nof_points, 3)
            Format: each point should contain (y, x, confidence)
        color_palette: name of a matplotlib color palette
            Default: 'tab20'
        palette_samples: number of different colors sampled from the `color_palette`
            Default: 16
        confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
            Default: 0.5

    Returns:
        A new image with overlaid points
    """
    cmap = plt.get_cmap(color_palette)
    try:
        # Qualitative palettes expose a fixed color list; scale to 0-255 and
        # reverse the channel order (RGB -> BGR) for OpenCV.
        palette = np.round(np.array(cmap.colors) * 255).astype(np.uint8)[:, ::-1].tolist()
    except AttributeError:
        # Continuous colormap: sample it evenly; the [-2::-1] slice drops the
        # alpha channel while reversing RGB -> BGR.
        sampled = cmap(np.linspace(0, 1, palette_samples))
        palette = np.round(np.array(sampled) * 255).astype(np.uint8)[:, -2::-1].tolist()

    # Marker radius proportional to the smaller image side.
    # ToDo: size it from the detection bounding box instead.
    radius = max(1, min(image.shape[:2]) // 150)

    for idx, pt in enumerate(points):
        if pt[2] > confidence_threshold:
            center = (int(pt[1]), int(pt[0]))  # points are (y, x); cv2 wants (x, y)
            image = cv2.circle(image, center, radius, tuple(palette[idx % len(palette)]), -1)
    return image
def draw_skeleton(image, points, skeleton, color_palette='Set2', palette_samples=8, person_index=0,
                  confidence_threshold=0.5):
    """
    Draw a `skeleton` on `image`, one line per joint whose two endpoints are confident enough.

    Args:
        image: image in opencv format
        points: list of points to be drawn.
            Shape: (nof_points, 3)
            Format: each point should contain (y, x, confidence)
        skeleton: list of joints to be drawn
            Shape: (nof_joints, 2)
            Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points`
        color_palette: name of a matplotlib color palette
            Default: 'Set2'
        palette_samples: number of different colors sampled from the `color_palette`
            Default: 8
        person_index: index of the person in `image`
            Default: 0
        confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
            Default: 0.5

    Returns:
        A new image with overlaid joints
    """
    cmap = plt.get_cmap(color_palette)
    try:
        # Qualitative palettes expose a fixed color list; scale and flip RGB -> BGR.
        palette = np.round(np.array(cmap.colors) * 255).astype(np.uint8)[:, ::-1].tolist()
    except AttributeError:
        # Continuous colormap: sample evenly, drop alpha, flip RGB -> BGR.
        sampled = cmap(np.linspace(0, 1, palette_samples))
        palette = np.round(np.array(sampled) * 255).astype(np.uint8)[:, -2::-1].tolist()

    # All bones of one person share a single color keyed on the person index.
    bone_color = tuple(palette[person_index % len(palette)])

    for joint in skeleton:
        pt1, pt2 = points[joint]
        if pt1[2] > confidence_threshold and pt2[2] > confidence_threshold:
            image = cv2.line(
                image, (int(pt1[1]), int(pt1[0])), (int(pt2[1]), int(pt2[0])),
                bone_color, 2
            )
    return image
def draw_points_and_skeleton(image, points, skeleton, points_color_palette='tab20', points_palette_samples=16,
                             skeleton_color_palette='Set2', skeleton_palette_samples=8, person_index=0,
                             confidence_threshold=0.5):
    """
    Draw `points` and `skeleton` on `image` (skeleton first, points on top).

    Args:
        image: image in opencv format
        points: list of points to be drawn.
            Shape: (nof_points, 3)
            Format: each point should contain (y, x, confidence)
        skeleton: list of joints to be drawn
            Shape: (nof_joints, 2)
            Format: each joint should contain (point_a, point_b) where `point_a` and `point_b` are an index in `points`
        points_color_palette: name of a matplotlib color palette
            Default: 'tab20'
        points_palette_samples: number of different colors sampled from the `color_palette`
            Default: 16
        skeleton_color_palette: name of a matplotlib color palette
            Default: 'Set2'
        skeleton_palette_samples: number of different colors sampled from the `color_palette`
            Default: 8
        person_index: index of the person in `image`
            Default: 0
        confidence_threshold: only points with a confidence higher than this threshold will be drawn. Range: [0, 1]
            Default: 0.5

    Returns:
        A new image with overlaid joints
    """
    # Bones go first so the point markers stay visible on top of them.
    with_bones = draw_skeleton(image, points, skeleton, color_palette=skeleton_color_palette,
                               palette_samples=skeleton_palette_samples, person_index=person_index,
                               confidence_threshold=confidence_threshold)
    return draw_points(with_bones, points, color_palette=points_color_palette,
                       palette_samples=points_palette_samples,
                       confidence_threshold=confidence_threshold)
def _denormalized_with_joints(images, joints_batch, joint_visibility):
    """Undo ImageNet normalization on a copy of `images` and paint a red
    2x2-pixel marker at every visible joint.

    Args:
        images (torch.Tensor): normalized images, shape (batch x channels x height x width).
        joints_batch (torch.Tensor): joints in heatmap coordinates, shape (batch x joints x 2).
        joint_visibility (torch.Tensor): per-joint visibility flags, shape (batch x joints).

    Returns:
        torch.Tensor: a detached copy of `images` with the markers drawn in.
    """
    images_ok = images.detach().clone()
    # Reverse the per-channel ImageNet normalization: x * std + mean.
    images_ok[:, 0].mul_(0.229).add_(0.485)
    images_ok[:, 1].mul_(0.224).add_(0.456)
    images_ok[:, 2].mul_(0.225).add_(0.406)
    for i in range(images_ok.shape[0]):
        # Heatmaps are at 1/4 of the input resolution, so scale joints up by 4.
        joints = joints_batch[i] * 4.
        for joint, joint_vis in zip(joints, joint_visibility[i]):
            if joint_vis[0]:
                a = int(joint[1].item())  # row (y)
                b = int(joint[0].item())  # column (x)
                # Red marker: R channel to 1, G and B channels to 0.
                images_ok[i][0, a - 1:a + 1, b - 1:b + 1] = 1
                images_ok[i][1:, a - 1:a + 1, b - 1:b + 1] = 0
    return images_ok


def save_images(images, target, joint_target, output, joint_output, joint_visibility, summary_writer=None, step=0,
                prefix=''):
    """
    Creates a grid of images with gt joints and a grid with predicted joints.
    This is a basic function for debugging purposes only.

    If summary_writer is not None, the grid will be written in that SummaryWriter with name "{prefix}_images" and
    "{prefix}_predictions".

    Args:
        images (torch.Tensor): a tensor of images with shape (batch x channels x height x width).
        target (torch.Tensor): a tensor of gt heatmaps with shape (batch x channels x height x width).
            NOTE(review): currently unused; kept for interface compatibility (see heatmap ToDo below).
        joint_target (torch.Tensor): a tensor of gt joints with shape (batch x joints x 2).
        output (torch.Tensor): a tensor of predicted heatmaps with shape (batch x channels x height x width).
            NOTE(review): currently unused; kept for interface compatibility (see heatmap ToDo below).
        joint_output (torch.Tensor): a tensor of predicted joints with shape (batch x joints x 2).
        joint_visibility (torch.Tensor): a tensor of joint visibility with shape (batch x joints).
        summary_writer (tb.SummaryWriter): a SummaryWriter where write the grids.
            Default: None
        step (int): summary_writer step.
            Default: 0
        prefix (str): summary_writer name prefix.
            Default: ""

    Returns:
        A pair of images which are built from torchvision.utils.make_grid
    """
    # Input images with ground-truth joints overlaid
    images_gt = _denormalized_with_joints(images, joint_target, joint_visibility)
    grid_gt = torchvision.utils.make_grid(images_gt, nrow=int(images_gt.shape[0] ** 0.5), padding=2, normalize=False)
    if summary_writer is not None:
        summary_writer.add_image(prefix + 'images', grid_gt, global_step=step)

    # Input images with predicted joints overlaid
    images_pred = _denormalized_with_joints(images, joint_output, joint_visibility)
    grid_pred = torchvision.utils.make_grid(images_pred, nrow=int(images_pred.shape[0] ** 0.5), padding=2,
                                            normalize=False)
    if summary_writer is not None:
        summary_writer.add_image(prefix + 'predictions', grid_pred, global_step=step)

    # ToDo: also log per-joint heatmap grids built from `target` / `output`.
    return grid_gt, grid_pred
def check_video_rotation(filename):
    """Return the cv2 rotation code needed to display `filename` upright, or None.

    Reads the video's `rotate` metadata tag via ffmpeg and maps it to the
    matching cv2.ROTATE_* constant. Returns None when the tag is absent,
    malformed, or holds an unhandled value (e.g. 0).

    Thanks to
    https://stackoverflow.com/questions/53097092/frame-from-video-is-upside-down-after-extracting/55747773#55747773

    Args:
        filename (str): path of the video file to probe.

    Returns:
        One of cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180,
        cv2.ROTATE_90_COUNTERCLOCKWISE, or None.
    """
    # ffmpeg.probe returns the video's metadata as a dictionary; the tag we
    # need lives at meta_dict['streams'][0]['tags']['rotate'].
    meta_dict = ffmpeg.probe(filename)
    rotation_map = {
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_COUNTERCLOCKWISE,
    }
    try:
        rotation = int(meta_dict['streams'][0]['tags']['rotate'])
    # Bug fix: the original raised a ValueError for unhandled rotation values
    # (and int() raises ValueError on a malformed tag) but only caught
    # KeyError, so the error escaped to the caller instead of yielding None.
    except (KeyError, ValueError):
        return None
    return rotation_map.get(rotation)