DocLayout-YOLO / visualization.py
enpaiva's picture
Update visualization.py
afd0c7e verified
import numpy as np
import cv2
from PIL import Image
def colormap(N=256, normalized=False):
    """
    Build a table of distinct RGB colors, one row per label index.

    Each label's color is derived from the label's own bits: the label is
    consumed three bits at a time (one bit each for red, green and blue),
    and every consumed triple fills the next most-significant bit of the
    three channels.

    Args:
        N (int): Number of labels, i.e. rows in the table (default is 256).
        normalized (bool): If True, return float32 colors scaled to [0, 1].
            Otherwise, return uint8 colors in [0, 255].

    Returns:
        np.ndarray: Color map array of shape (N, 3), channels ordered R, G, B.
    """
    palette = np.zeros((N, 3), dtype=np.uint8)

    for label in range(N):
        channels = [0, 0, 0]  # accumulated R, G, B values for this label
        remaining = label
        # Fill channel bits from the most significant (7) down to the least (0).
        for shift in range(7, -1, -1):
            for channel in range(3):
                # Bit `channel` of the remaining label feeds channel R/G/B.
                channels[channel] |= ((remaining >> channel) & 1) << shift
            remaining >>= 3
        palette[label] = channels

    if not normalized:
        return palette
    return palette.astype(np.float32) / 255.0
def _load_bgr_image(image_source):
    """Return image_source as a freshly allocated BGR array that is safe to draw on."""
    if isinstance(image_source, Image.Image):
        # Grayscale / palette / CMYK images are not 3-channel RGB, which would
        # make the RGB -> BGR conversion below fail; normalize the mode first.
        if image_source.mode != "RGB":
            image_source = image_source.convert("RGB")
        return cv2.cvtColor(np.array(image_source), cv2.COLOR_RGB2BGR)

    if isinstance(image_source, np.ndarray):
        # np.array() copies, so the caller's array is never drawn on.
        pixels = np.array(image_source)
        if pixels.ndim == 2 or (pixels.ndim == 3 and pixels.shape[2] == 1):
            # Single-channel input: expand to three channels instead of crashing.
            return cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
        # Multi-channel arrays are treated as RGB(A), as in the original code.
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    image = cv2.imread(image_source)
    if image is None:
        raise ValueError(f"Could not load image from path: {image_source}")
    return image


def visualize_bbox(image_path, bboxes, classes, scores, id_to_names, alpha=0.3):
    """
    Draw labelled, semi-transparent detection boxes on an image.

    Every box is filled on an overlay copy and outlined (with a
    "name:score" label) on the image itself; the two are then blended so the
    fill shows through with weight ``alpha``.

    Args:
        image_path (str | PIL.Image.Image | np.ndarray): Path of the image to
            load, or an already loaded image. PIL images and multi-channel
            arrays are interpreted as RGB; single-channel arrays as grayscale.
        bboxes: Sequence of boxes, each ``(x_min, y_min, x_max, y_max)``.
            ``None`` or an empty sequence means "nothing detected".
        classes: Sequence of class ids, aligned with ``bboxes``.
        scores: Sequence of confidence scores, aligned with ``bboxes``.
        id_to_names (dict): Maps class id to display name. Ids that are
            missing are labelled ``unknown_<id>``.
        alpha (float): Blend weight of the filled overlay (default 0.3).

    Returns:
        np.ndarray: The annotated image in BGR channel order. When there are
        no boxes the unmodified (BGR) image is returned.

    Raises:
        ValueError: If ``image_path`` is a path that cannot be read.
    """
    image = _load_bgr_image(image_path)

    if bboxes is None or len(bboxes) == 0:
        print("No bounding boxes to display.")
        return image  # Return original image if nothing detected

    # Only allocate drawing resources once we know there is something to draw.
    overlay = image.copy()
    # Keep at least one color so the modulo below cannot divide by zero when
    # id_to_names is empty (previously that skipped every single box).
    cmap = colormap(N=max(len(id_to_names), 1), normalized=False)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    thickness = 2

    for i, bbox in enumerate(bboxes):
        # Best effort: a malformed detection is reported and skipped so the
        # remaining boxes are still rendered.
        try:
            x_min, y_min, x_max, y_max = map(int, bbox)
            class_id = int(classes[i])
            class_name = id_to_names.get(class_id, f"unknown_{class_id}")
            score = scores[i]
            text = f"{class_name}:{score:.3f}"
            color = tuple(int(c) for c in cmap[class_id % len(cmap)])

            # Filled box goes on the overlay (blended later); outline on the image.
            cv2.rectangle(overlay, (x_min, y_min), (x_max, y_max), color, -1)
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, thickness)

            (text_width, text_height), baseline = cv2.getTextSize(
                text, font, font_scale, thickness
            )
            label_height = text_height + baseline
            if y_min - label_height >= 0:
                # Usual placement: label sits directly above the box.
                label_bottom = y_min
            else:
                # No room above (box touches the top edge): move the label
                # just inside the image/box so it stays visible.
                label_bottom = max(y_min, 0) + label_height
            label_top = label_bottom - label_height

            cv2.rectangle(
                image,
                (x_min, label_top),
                (x_min + text_width, label_bottom),
                color,
                -1,
            )
            cv2.putText(
                image,
                text,
                (x_min, label_bottom - 5),
                font,
                font_scale,
                (255, 255, 255),
                thickness,
            )
        except Exception as e:
            print(f"Skipping box {i} due to error: {e}")

    # Blend in place: image = alpha * overlay + (1 - alpha) * image.
    cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
    return image