# tempoPFN/src/data/scalers.py
from abc import ABC, abstractmethod
import torch
class BaseScaler(ABC):
"""
Abstract base class for time series scalers.
Defines the interface for scaling multivariate time series data with support
for masked values and channel-wise scaling.
"""
@abstractmethod
def compute_statistics(
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
"""
Compute scaling statistics from historical data.
"""
pass
@abstractmethod
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply scaling transformation to data.
"""
pass
@abstractmethod
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply inverse scaling transformation to recover original scale.
"""
pass
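

# Illustrative usage sketch (not part of the original module): the call pattern shared by the
# concrete scalers below. Shape conventions are assumed from those implementations:
# history_values is (batch, seq_len, num_channels) and history_mask is (batch, seq_len).
def _demo_scaler_roundtrip(
    scaler: BaseScaler,
    history_values: torch.Tensor,
    history_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    # 1) Estimate per-channel statistics from the (optionally masked) history.
    statistics = scaler.compute_statistics(history_values, history_mask)
    # 2) Normalize the data with those statistics.
    scaled = scaler.scale(history_values, statistics)
    # 3) Invert the transform; the result should match history_values up to float error.
    return scaler.inverse_scale(scaled, statistics)

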
class RobustScaler(BaseScaler):
"""
Robust scaler using median and IQR for normalization.
"""
    def __init__(self, epsilon: float = 1e-6, min_scale: float = 1e-3):
        """
        Args:
            epsilon: Small constant added to the IQR to avoid division by zero.
            min_scale: Lower bound on the scale denominator to keep it well-conditioned.
        """
if epsilon <= 0:
raise ValueError("epsilon must be positive")
if min_scale <= 0:
raise ValueError("min_scale must be positive")
self.epsilon = epsilon
self.min_scale = min_scale
def compute_statistics(
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
"""
        Compute per-channel median and IQR statistics from historical data, with safeguards for numerical stability.
"""
batch_size, seq_len, num_channels = history_values.shape
device = history_values.device
medians = torch.zeros(batch_size, 1, num_channels, device=device)
iqrs = torch.ones(batch_size, 1, num_channels, device=device)
for b in range(batch_size):
for c in range(num_channels):
channel_data = history_values[b, :, c]
if history_mask is not None:
mask = history_mask[b, :].bool()
valid_data = channel_data[mask]
else:
valid_data = channel_data
if len(valid_data) == 0:
continue
valid_data = valid_data[torch.isfinite(valid_data)]
if len(valid_data) == 0:
continue
median_val = torch.median(valid_data)
medians[b, 0, c] = median_val
if len(valid_data) > 1:
try:
q75 = torch.quantile(valid_data, 0.75)
q25 = torch.quantile(valid_data, 0.25)
iqr_val = q75 - q25
iqr_val = torch.max(iqr_val, torch.tensor(self.min_scale, device=device))
iqrs[b, 0, c] = iqr_val
                    except Exception:
                        # Fall back to the standard deviation if quantile computation fails.
                        std_val = torch.std(valid_data)
                        iqrs[b, 0, c] = torch.max(std_val, torch.tensor(self.min_scale, device=device))
else:
iqrs[b, 0, c] = self.min_scale
return {"median": medians, "iqr": iqrs}
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
        Apply robust scaling: (data - median) / max(iqr + epsilon, min_scale), clamped to [-50, 50].
"""
median = statistics["median"]
iqr = statistics["iqr"]
denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
scaled_data = (data - median) / denominator
scaled_data = torch.clamp(scaled_data, -50.0, 50.0)
return scaled_data
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
        Apply inverse robust scaling to recover the original scale.
        Handles both 3D (e.g., training input) and 4D (e.g., model output samples) tensors.
"""
median = statistics["median"]
iqr = statistics["iqr"]
denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
if scaled_data.ndim == 4:
denominator = denominator.unsqueeze(-1)
median = median.unsqueeze(-1)
return scaled_data * denominator + median
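

# Illustrative sketch (not part of the original module): why median/IQR statistics help.
# A single extreme outlier barely moves the median and IQR, so ordinary points scale to small
# values; the outlier itself is clamped to 50 by scale() and is therefore not inverted exactly.
def _demo_robust_scaler() -> None:
    torch.manual_seed(0)
    history = torch.randn(1, 64, 1)
    history[0, 10, 0] = 1_000.0  # inject one extreme outlier
    mask = torch.ones(1, 64)  # (batch, seq_len), all time steps observed
    scaler = RobustScaler()
    stats = scaler.compute_statistics(history, mask)
    scaled = scaler.scale(history, stats)
    restored = scaler.inverse_scale(scaled, stats)
    keep = torch.ones(64, dtype=torch.bool)
    keep[10] = False  # exclude the clamped outlier from the round-trip check
    assert torch.allclose(restored[0, keep, 0], history[0, keep, 0], atol=1e-4)

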
class MinMaxScaler(BaseScaler):
"""
Min-Max scaler that normalizes data to the range [-1, 1].
"""
def __init__(self, epsilon: float = 1e-8):
if epsilon <= 0:
raise ValueError("epsilon must be positive")
self.epsilon = epsilon
def compute_statistics(
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
"""
Compute min and max statistics from historical data.
"""
batch_size, seq_len, num_channels = history_values.shape
device = history_values.device
mins = torch.zeros(batch_size, 1, num_channels, device=device)
maxs = torch.ones(batch_size, 1, num_channels, device=device)
for b in range(batch_size):
for c in range(num_channels):
channel_data = history_values[b, :, c]
if history_mask is not None:
mask = history_mask[b, :].bool()
valid_data = channel_data[mask]
else:
valid_data = channel_data
                if len(valid_data) == 0:
                    continue
                # Filter out non-finite values like NaN or Inf, as in the other scalers.
                valid_data = valid_data[torch.isfinite(valid_data)]
                if len(valid_data) == 0:
                    continue
                min_val = torch.min(valid_data)
                max_val = torch.max(valid_data)
                mins[b, 0, c] = min_val
                maxs[b, 0, c] = max_val
                # Guard against (near-)constant channels by enforcing a unit range.
                if torch.abs(max_val - min_val) < self.epsilon:
                    maxs[b, 0, c] = min_val + 1.0
return {"min": mins, "max": maxs}
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply min-max scaling to range [-1, 1].
"""
min_val = statistics["min"]
max_val = statistics["max"]
normalized = (data - min_val) / (max_val - min_val + self.epsilon)
return normalized * 2.0 - 1.0
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
        Apply inverse min-max scaling to recover the original scale.
        Handles both 3D (e.g., training input) and 4D (e.g., model output samples) tensors.
"""
min_val = statistics["min"]
max_val = statistics["max"]
if scaled_data.ndim == 4:
min_val = min_val.unsqueeze(-1)
max_val = max_val.unsqueeze(-1)
normalized = (scaled_data + 1.0) / 2.0
return normalized * (max_val - min_val + self.epsilon) + min_val
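

# Illustrative sketch (not part of the original module): MinMaxScaler maps the observed history
# into [-1, 1] per channel, and inverse_scale recovers the original values up to float error.
def _demo_minmax_scaler() -> None:
    torch.manual_seed(0)
    history = torch.rand(2, 32, 3) * 100.0  # (batch, seq_len, channels) with an arbitrary scale
    scaler = MinMaxScaler()
    stats = scaler.compute_statistics(history)
    scaled = scaler.scale(history, stats)
    assert scaled.min() >= -1.0 and scaled.max() <= 1.0
    restored = scaler.inverse_scale(scaled, stats)
    assert torch.allclose(restored, history, atol=1e-3)

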
class MeanScaler(BaseScaler):
"""
A scaler that centers the data by subtracting the channel-wise mean.
This scaler only performs centering and does not affect the scale of the data.
"""
def compute_statistics(
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
"""
Compute the mean for each channel from historical data.
"""
batch_size, seq_len, num_channels = history_values.shape
device = history_values.device
# Initialize a tensor to store the mean for each channel in each batch item
means = torch.zeros(batch_size, 1, num_channels, device=device)
for b in range(batch_size):
for c in range(num_channels):
channel_data = history_values[b, :, c]
# Use the mask to select only valid (observed) data points
if history_mask is not None:
mask = history_mask[b, :].bool()
valid_data = channel_data[mask]
else:
valid_data = channel_data
# Skip if there's no valid data for this channel
if len(valid_data) == 0:
continue
# Filter out non-finite values like NaN or Inf before computing
valid_data = valid_data[torch.isfinite(valid_data)]
if len(valid_data) == 0:
continue
# Compute the mean and store it
means[b, 0, c] = torch.mean(valid_data)
return {"mean": means}
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply mean centering: data - mean.
"""
mean = statistics["mean"]
return data - mean
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply inverse mean centering: scaled_data + mean.
Handles both 3D (e.g., training input) and 4D (e.g., model output samples) tensors.
"""
mean = statistics["mean"]
# Adjust shape for 4D tensors (batch, seq_len, channels, samples)
if scaled_data.ndim == 4:
mean = mean.unsqueeze(-1)
return scaled_data + mean
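

# Illustrative sketch (not part of the original module): MeanScaler only centers the data, and its
# inverse broadcasts over a trailing samples dimension for 4D tensors
# of shape (batch, seq_len, channels, samples).
def _demo_mean_scaler() -> None:
    torch.manual_seed(0)
    history = torch.randn(2, 16, 3) + 5.0  # shifted data, (batch, seq_len, channels)
    scaler = MeanScaler()
    stats = scaler.compute_statistics(history)
    centered = scaler.scale(history, stats)
    # Per-channel means of the centered history are (numerically) zero.
    assert torch.allclose(centered.mean(dim=1), torch.zeros(2, 3), atol=1e-4)
    # A 4D tensor of prediction samples is shifted back by the same per-channel means.
    samples = torch.randn(2, 16, 3, 100)
    shifted = scaler.inverse_scale(samples, stats)
    assert shifted.shape == (2, 16, 3, 100)

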
class MedianScaler(BaseScaler):
"""
A scaler that centers the data by subtracting the channel-wise median.
This scaler only performs centering and does not affect the scale of the data.
It is more robust to outliers than the MeanScaler.
"""
def compute_statistics(
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
) -> dict[str, torch.Tensor]:
"""
Compute the median for each channel from historical data.
"""
batch_size, seq_len, num_channels = history_values.shape
device = history_values.device
# Initialize a tensor to store the median for each channel in each batch item
medians = torch.zeros(batch_size, 1, num_channels, device=device)
for b in range(batch_size):
for c in range(num_channels):
channel_data = history_values[b, :, c]
# Use the mask to select only valid (observed) data points
if history_mask is not None:
mask = history_mask[b, :].bool()
valid_data = channel_data[mask]
else:
valid_data = channel_data
# Skip if there's no valid data for this channel
if len(valid_data) == 0:
continue
# Filter out non-finite values like NaN or Inf before computing
valid_data = valid_data[torch.isfinite(valid_data)]
if len(valid_data) == 0:
continue
# Compute the median and store it
medians[b, 0, c] = torch.median(valid_data)
return {"median": medians}
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply median centering: data - median.
"""
median = statistics["median"]
return data - median
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Apply inverse median centering: scaled_data + median.
Handles both 3D (e.g., training input) and 4D (e.g., model output samples) tensors.
"""
median = statistics["median"]
# Adjust shape for 4D tensors (batch, seq_len, channels, samples)
if scaled_data.ndim == 4:
median = median.unsqueeze(-1)
return scaled_data + median
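

# Illustrative sketch (not part of the original module): MedianScaler behaves like MeanScaler but
# centers on the per-channel median, which is far less affected by a few extreme values.
def _demo_median_scaler() -> None:
    history = torch.tensor([[[1.0], [2.0], [3.0], [4.0], [1_000.0]]])  # (1, 5, 1) with an outlier
    mean_stats = MeanScaler().compute_statistics(history)
    median_stats = MedianScaler().compute_statistics(history)
    # The mean (202.0) is dragged toward the outlier, while the median stays at 3.0.
    assert torch.isclose(mean_stats["mean"], torch.tensor(202.0)).all()
    assert torch.isclose(median_stats["median"], torch.tensor(3.0)).all()


if __name__ == "__main__":
    # Run the illustrative sketches as a quick sanity check.
    _demo_scaler_roundtrip(RobustScaler(), torch.randn(1, 8, 2))
    _demo_robust_scaler()
    _demo_minmax_scaler()
    _demo_mean_scaler()
    _demo_median_scaler()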