# ProbMED / mixinhelpers.py
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import BatchEncoding, PreTrainedTokenizer
"""
Mixin for all modalities, each mixin has:
- preprocess function that takes in path or data and returns tensor
- construct_input function that takes in tensor and returns dict with batch
dimension for model input
- key string for model input dict
"""
class ECHO_Mixin:
LOWER_YELLOW: list[int] = [20, 50, 50]
UPPER_YELLOW: list[int] = [100, 255, 255]
IMAGE_SIZE: tuple[int, int] = (224, 224)
NORM_MEAN: tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
NORM_STD: tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)
ECHO_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(), # Scaling into [0, 1]
transforms.Resize(IMAGE_SIZE),
transforms.Normalize(
mean=NORM_MEAN,
std=NORM_STD,
),
]
)
ECHO_KEY: str = "echo"
    def grabimage(self, split: str, data: dict[str, np.ndarray]) -> np.ndarray:
        """Pick a frame: a random case and a random frame for "train",
        otherwise the first frame of a randomly chosen case."""
        if split == "train":
            caseofinterest = random.choice(list(data.keys()))
            frame_index = random.randrange(data[caseofinterest].shape[0])
        else:
            caseofinterest = random.choice(list(data.keys()))
            frame_index = 0
        video = data[caseofinterest]
        return self.extract_echoframe(frame_index, video)
    def extract_echoframe(self, frame_index: int, video: np.ndarray) -> np.ndarray:
        """Black out yellow pixels in one frame, then rescale it to uint8."""
        image = video[frame_index].copy()  # copy so the source video is not mutated
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        lower_yellow = np.array(self.LOWER_YELLOW)  # Lower bound of yellow hue
        upper_yellow = np.array(self.UPPER_YELLOW)  # Upper bound of yellow hue
        mask = cv2.inRange(hsv_image, lower_yellow, upper_yellow)
        image[mask > 0] = [0, 0, 0]
        # Min-max rescale into [0, 255]
        image = np.array(image, dtype=np.float32)
        image -= image.min()
        image /= image.max()
        image *= 255
        return image.astype(np.uint8)
def preprocess_echoseries(
self, video_dict: dict[str, np.ndarray], split: str = "valid"
) -> torch.Tensor:
"""assumes inference mode"""
image = self.grabimage(split, video_dict)
if not isinstance(image, np.ndarray):
raise TypeError("Expected image to be a numpy ndarray")
pil_image = Image.fromarray(image)
        transformed = self.ECHO_TRANSFORMS(pil_image)
        if not isinstance(transformed, torch.Tensor):
            # Compose is untyped; fail loudly instead of returning an un-normalized tensor
            raise TypeError("ECHO_TRANSFORMS did not return a torch.Tensor")
        return transformed
def preprocess_single_echo(self, avi_path: str) -> torch.Tensor:
"""assumes inference mode, opens AVI file and processes first frame
Output: image: torch.Tensor of shape (C, H, W)
"""
cap = cv2.VideoCapture(avi_path)
success, frame = cap.read()
cap.release()
if not success or frame is None:
raise ValueError(f"Could not read frame from AVI file: {avi_path}")
image = self.extract_echoframe(0, np.array([frame])) # process first frame
        image = self.ECHO_TRANSFORMS(Image.fromarray(image))
        if not isinstance(image, torch.Tensor):
            # torch.from_numpy would fail here: the transform output is not an ndarray
            raise TypeError("ECHO_TRANSFORMS did not return a torch.Tensor")
        return image
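# Usage sketch for ECHO_Mixin (assumes a composing class and a local file
# "example.avi"; both are placeholders):
#
#     class EchoPre(ECHO_Mixin):
#         pass
#
#     pre = EchoPre()
#     frame = pre.preprocess_single_echo("example.avi")  # (C, H, W)
#     inputs = {EchoPre.ECHO_KEY: frame.unsqueeze(0)}    # batch dim -> (1, C, H, W)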
# CXR
class CXR_Mixin:
RESIZE: tuple[int, int] = (256, 256)
IMAGE_SIZE: tuple[int, int] = (224, 224)
NORM_MEAN: list[float] = [0.5862785803043838]
NORM_STD: list[float] = [0.27950088968644304]
VISION_KEY: str = "vision"
CXR_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(), # Scaling into [0, 1]
transforms.Resize(RESIZE),
transforms.CenterCrop(IMAGE_SIZE),
transforms.Normalize(
mean=NORM_MEAN,
std=NORM_STD,
),
]
)
    @staticmethod
    def remove_border(pixel_array: np.ndarray) -> np.ndarray:
        # Find where the image is not just background (0s)
        coords = np.column_stack(np.where(pixel_array > 0))
        if coords.size == 0:
            return pixel_array  # all background; nothing to crop
        row_min, col_min = coords.min(axis=0)
        row_max, col_max = coords.max(axis=0)
        # Crop to the foreground bounding box (inclusive of the last row/column)
        return pixel_array[row_min : row_max + 1, col_min : col_max + 1]
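    # e.g. for a frame padded with zero borders (illustrative values):
    #   arr = np.zeros((10, 10), dtype=np.uint8); arr[2:8, 3:7] = 255
    #   CXR_Mixin.remove_border(arr).shape == (6, 4)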
    def preprocess_loaded_cxr(self, img: np.ndarray) -> torch.Tensor:
cxr = self.remove_border(img)
# Convert grayscale image to 3-channel RGB
cxr = np.repeat(cxr[..., np.newaxis], 3, axis=-1)
cxr = Image.fromarray(cxr)
        transformed = self.CXR_TRANSFORMS(cxr)
        if not isinstance(transformed, torch.Tensor):
            # Fail loudly instead of returning an un-cropped, un-normalized tensor
            raise TypeError("CXR_TRANSFORMS did not return a torch.Tensor")
        return transformed
def preprocess_single_cxr(self, image_path: str) -> torch.Tensor:
"""assumes inference mode"""
with open(image_path, "rb") as fopen:
image = Image.open(fopen).convert("RGB")
            image = np.array(image)[:, :, 0]  # keep one channel; CXRs are effectively grayscale
cxr = self.preprocess_loaded_cxr(image)
return cxr
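# Usage sketch for CXR_Mixin (assumes a composing class and a local file
# "example_cxr.png"; both are placeholders):
#
#     class CXRPre(CXR_Mixin):
#         pass
#
#     pre = CXRPre()
#     cxr = pre.preprocess_single_cxr("example_cxr.png")  # (3, 224, 224)
#     inputs = {CXRPre.VISION_KEY: cxr.unsqueeze(0)}      # (1, 3, 224, 224)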
class ECG_Mixin:
LENGTH: int = 1000
FREQUENCY: int = 100 # we assume 100Hz sampling rate
CHANNELS: int = 12
NORM_MEAN: float = 0.02547506
NORM_SCALE: float = 0.16486814
    NORM_VAR: float = 0.0271815  # = NORM_SCALE ** 2
ECG_KEY: str = "ecg"
def manual_standardize(self, x: np.ndarray) -> torch.Tensor:
"""
Apply manual standardization to ECG or other data.
Equivalent to sklearn's StandardScaler with given constants.
Args:
x (np.ndarray): Input array of shape (12, 1000)
Returns:
torch.Tensor: Scaled array of the same shape
"""
return torch.from_numpy((x - self.NORM_MEAN) / self.NORM_SCALE).float()
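    # e.g. an all-zero signal maps to a constant:
    #   (0 - 0.02547506) / 0.16486814 ≈ -0.1545 in every entry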
    def check_ecg(self, ecg: np.ndarray) -> np.ndarray:
        # Reject signals with invalid samples before truncating
        if np.isnan(ecg).any() or np.isinf(ecg).any():
            raise ValueError("ECG contains NaN or Inf values")
        return ecg[:, : self.LENGTH]  # Truncate to the first 1000 samples (10 s at 100 Hz)
def preprocess_single_ecg(self, ecg_path: str) -> torch.Tensor:
"""assumes inference mode"""
        # ecg_path points to a saved .npy array; assumes 12 channels
ecg = np.load(ecg_path)
if ecg.ndim == 2 and ecg.shape[0] != self.CHANNELS:
raise ValueError(f"Expected ECG with {self.CHANNELS} channels, got {ecg.shape[0]}")
ecg = self.check_ecg(ecg)
transformed = self.manual_standardize(ecg)
return transformed
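# Usage sketch for ECG_Mixin (assumes a composing class; the array below is a
# placeholder, not real data):
#
#     class ECGPre(ECG_Mixin):
#         pass
#
#     np.save("example_ecg.npy", np.zeros((12, 1000), dtype=np.float32))
#     pre = ECGPre()
#     ecg = pre.preprocess_single_ecg("example_ecg.npy")  # (12, 1000)
#     inputs = {ECGPre.ECG_KEY: ecg.unsqueeze(0)}         # (1, 12, 1000)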
class Text_Mixin:
MODALITY_LIST: dict[str, str] = {"echo": "echocardiogram", "ecg": "ecg", "vision": "cxr"}
    MAX_LENGTH: int = 120  # longer length to accommodate longer reports
TEXT_LENGTH: int = 100 # 100 words
    def get_first_n_words(self, text: str, n: int = 100) -> str:
        """Return the first n words; the 97.5th percentile of reports is under 35 words."""
        words = text.split()  # Split the text into words
        return " ".join(words[:n])  # Join the first n words back into a string
    def createCaption(self, caption: str, modality: str = "") -> str:
        assert modality in self.MODALITY_LIST or modality == "", (
            f"modality should be one of {list(self.MODALITY_LIST)} or empty"
        )
        return f"text : {caption}, {modality} looks like : "
def createTokenizedCaption(self, caption: str, tokenizer: PreTrainedTokenizer) -> BatchEncoding:
encoding = tokenizer(
caption,
padding="max_length",
truncation=True,
max_length=self.MAX_LENGTH,
return_tensors="pt",
)
return encoding
def construct_caption(
self, caption: str, tokenizer: PreTrainedTokenizer, modality: str = ""
) -> BatchEncoding:
"""given caption string, return tokenized caption dict for model input
Output: dict with keys 'input_ids' and 'attention_mask', each of shape (1, L)
"""
caption_str = self.createCaption(caption, modality)
tokenized = self.createTokenizedCaption(caption_str, tokenizer)
return tokenized
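# Usage sketch for Text_Mixin (assumes a composing class; "your-checkpoint" is a
# placeholder for whichever tokenizer the model was trained with):
#
#     from transformers import AutoTokenizer
#
#     class TextPre(Text_Mixin):
#         pass
#
#     tokenizer = AutoTokenizer.from_pretrained("your-checkpoint")
#     pre = TextPre()
#     enc = pre.construct_caption("mild cardiomegaly", tokenizer, modality="vision")
#     enc["input_ids"].shape == (1, TextPre.MAX_LENGTH)  # padded/truncated to 120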