from abc import ABC, abstractmethod

import torch
|
|
class BaseScaler(ABC):
    """
    Abstract base class for time series scalers.

    Defines the interface for scaling multivariate time series data with
    support for masked values and channel-wise scaling. Implementations
    expect ``history_values`` of shape ``(batch_size, seq_len, num_channels)``
    and an optional ``history_mask`` of shape ``(batch_size, seq_len)`` whose
    nonzero entries mark valid time steps.
    """

    @abstractmethod
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """Compute scaling statistics from historical data."""

    @abstractmethod
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """Apply the scaling transformation to data."""

    @abstractmethod
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """Apply the inverse scaling transformation to recover the original scale."""
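

# Illustrative only (not part of the original module): a minimal concrete
# subclass showing the contract every scaler must satisfy. compute_statistics
# summarizes the history, while scale/inverse_scale must be mutual inverses
# given those statistics.
class _IdentityScaler(BaseScaler):
    """No-op scaler; a reference implementation of the interface."""

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        return {}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        return data

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        return scaled_data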
|
|
class RobustScaler(BaseScaler):
    """
    Robust scaler that normalizes each channel by its median and interquartile
    range (IQR), making it resilient to outliers.
    """

    def __init__(self, epsilon: float = 1e-6, min_scale: float = 1e-3):
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        if min_scale <= 0:
            raise ValueError("min_scale must be positive")
        self.epsilon = epsilon
        self.min_scale = min_scale
|
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel median and IQR statistics from historical data,
        skipping masked and non-finite values for numerical stability.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        medians = torch.zeros(batch_size, 1, num_channels, device=device)
        iqrs = torch.ones(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            for c in range(num_channels):
                channel_data = history_values[b, :, c]

                if history_mask is not None:
                    mask = history_mask[b, :].bool()
                    valid_data = channel_data[mask]
                else:
                    valid_data = channel_data

                if len(valid_data) == 0:
                    continue

                # Drop NaN/Inf entries before computing order statistics.
                valid_data = valid_data[torch.isfinite(valid_data)]
                if len(valid_data) == 0:
                    continue

                medians[b, 0, c] = torch.median(valid_data)

                if len(valid_data) > 1:
                    try:
                        q75 = torch.quantile(valid_data, 0.75)
                        q25 = torch.quantile(valid_data, 0.25)
                        iqrs[b, 0, c] = torch.clamp(q75 - q25, min=self.min_scale)
                    except Exception:
                        # torch.quantile can fail (e.g., on very large inputs);
                        # fall back to the standard deviation as a spread estimate.
                        std_val = torch.std(valid_data)
                        iqrs[b, 0, c] = torch.clamp(std_val, min=self.min_scale)
                else:
                    # A single observation carries no spread information.
                    iqrs[b, 0, c] = self.min_scale

        return {"median": medians, "iqr": iqrs}
|
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply robust scaling: (data - median) / max(iqr + epsilon, min_scale).
        """
        median = statistics["median"]
        iqr = statistics["iqr"]

        denominator = torch.clamp(iqr + self.epsilon, min=self.min_scale)
        scaled_data = (data - median) / denominator
        # Bound extreme outliers; note this makes scale/inverse_scale lossy
        # for values scaled beyond +/-50.
        scaled_data = torch.clamp(scaled_data, -50.0, 50.0)

        return scaled_data
|
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse robust scaling. Supports 3D inputs of shape
        (batch, seq_len, channels) as well as 4D inputs with a trailing
        sample dimension.
        """
        median = statistics["median"]
        iqr = statistics["iqr"]

        denominator = torch.clamp(iqr + self.epsilon, min=self.min_scale)

        if scaled_data.ndim == 4:
            # Broadcast (batch, 1, channels) statistics over the sample axis.
            denominator = denominator.unsqueeze(-1)
            median = median.unsqueeze(-1)

        return scaled_data * denominator + median
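

# Usage sketch (hypothetical, for illustration): a RobustScaler round trip.
# Tensor shapes follow compute_statistics: the history is (batch, seq_len,
# channels) and the optional mask is (batch, seq_len).
def _demo_robust_scaler() -> None:
    scaler = RobustScaler()
    history = torch.randn(2, 24, 3)
    mask = torch.ones(2, 24)
    stats = scaler.compute_statistics(history, mask)
    scaled = scaler.scale(history, stats)
    restored = scaler.inverse_scale(scaled, stats)
    # The round trip is exact except where scale() clamped outliers to [-50, 50].
    assert torch.allclose(restored, history, atol=1e-5)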
|
|
class MinMaxScaler(BaseScaler):
    """
    Min-max scaler that normalizes each channel to the range [-1, 1].
    """

    def __init__(self, epsilon: float = 1e-8):
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        self.epsilon = epsilon
|
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel min and max statistics from historical data.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        mins = torch.zeros(batch_size, 1, num_channels, device=device)
        maxs = torch.ones(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            for c in range(num_channels):
                channel_data = history_values[b, :, c]

                if history_mask is not None:
                    mask = history_mask[b, :].bool()
                    valid_data = channel_data[mask]
                else:
                    valid_data = channel_data

                if len(valid_data) == 0:
                    continue

                # Drop NaN/Inf entries so a single bad value cannot corrupt
                # the channel's range (mirrors the other scalers).
                valid_data = valid_data[torch.isfinite(valid_data)]
                if len(valid_data) == 0:
                    continue

                min_val = torch.min(valid_data)
                max_val = torch.max(valid_data)

                mins[b, 0, c] = min_val
                maxs[b, 0, c] = max_val

                # Guard against constant channels: widen the range so scaling
                # does not divide by a near-zero denominator.
                if torch.abs(max_val - min_val) < self.epsilon:
                    maxs[b, 0, c] = min_val + 1.0

        return {"min": mins, "max": maxs}
|
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply min-max scaling to the range [-1, 1].
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        normalized = (data - min_val) / (max_val - min_val + self.epsilon)
        return normalized * 2.0 - 1.0
|
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse min-max scaling. Supports 3D inputs as well as 4D inputs
        with a trailing sample dimension.
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        if scaled_data.ndim == 4:
            # Broadcast (batch, 1, channels) statistics over the sample axis.
            min_val = min_val.unsqueeze(-1)
            max_val = max_val.unsqueeze(-1)

        # Undo the [-1, 1] mapping with the same epsilon-padded range as scale().
        normalized = (scaled_data + 1.0) / 2.0
        return normalized * (max_val - min_val + self.epsilon) + min_val
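

# Usage sketch (hypothetical, for illustration): min-max scaling maps the
# observed history into [-1, 1]; inverse_scale undoes the mapping exactly
# because both directions use the same epsilon-padded range.
def _demo_min_max_scaler() -> None:
    scaler = MinMaxScaler()
    history = torch.randn(2, 24, 3)
    stats = scaler.compute_statistics(history)
    scaled = scaler.scale(history, stats)
    assert scaled.min() >= -1.0 and scaled.max() <= 1.0
    restored = scaler.inverse_scale(scaled, stats)
    assert torch.allclose(restored, history, atol=1e-5)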
|
|
class MeanScaler(BaseScaler):
    """
    A scaler that centers the data by subtracting the channel-wise mean.

    This scaler only performs centering and does not affect the scale of
    the data.
    """
|
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute the mean of each channel from historical data, skipping
        masked and non-finite values.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        means = torch.zeros(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            for c in range(num_channels):
                channel_data = history_values[b, :, c]

                if history_mask is not None:
                    mask = history_mask[b, :].bool()
                    valid_data = channel_data[mask]
                else:
                    valid_data = channel_data

                if len(valid_data) == 0:
                    continue

                # Drop NaN/Inf entries before averaging.
                valid_data = valid_data[torch.isfinite(valid_data)]
                if len(valid_data) == 0:
                    continue

                means[b, 0, c] = torch.mean(valid_data)

        return {"mean": means}
|
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply mean centering: data - mean.
        """
        return data - statistics["mean"]
|
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse mean centering: scaled_data + mean.

        Handles both 3D tensors (e.g., training input) and 4D tensors
        (e.g., model output samples).
        """
        mean = statistics["mean"]

        if scaled_data.ndim == 4:
            # Broadcast (batch, 1, channels) statistics over the sample axis.
            mean = mean.unsqueeze(-1)

        return scaled_data + mean
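

# Usage sketch (hypothetical, for illustration): the 4D branch of
# inverse_scale broadcasts the (batch, 1, channels) mean over a trailing
# sample dimension, e.g. Monte Carlo forecast samples.
def _demo_mean_scaler_samples() -> None:
    scaler = MeanScaler()
    history = torch.randn(2, 24, 3)
    stats = scaler.compute_statistics(history)
    samples = torch.randn(2, 12, 3, 100)  # (batch, horizon, channels, samples)
    restored = scaler.inverse_scale(samples, stats)
    assert restored.shape == samples.shape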
|
|
class MedianScaler(BaseScaler):
    """
    A scaler that centers the data by subtracting the channel-wise median.

    This scaler only performs centering and does not affect the scale of
    the data. It is more robust to outliers than the MeanScaler.
    """
|
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute the median of each channel from historical data, skipping
        masked and non-finite values.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        medians = torch.zeros(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            for c in range(num_channels):
                channel_data = history_values[b, :, c]

                if history_mask is not None:
                    mask = history_mask[b, :].bool()
                    valid_data = channel_data[mask]
                else:
                    valid_data = channel_data

                if len(valid_data) == 0:
                    continue

                # Drop NaN/Inf entries before taking the median.
                valid_data = valid_data[torch.isfinite(valid_data)]
                if len(valid_data) == 0:
                    continue

                medians[b, 0, c] = torch.median(valid_data)

        return {"median": medians}
|
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply median centering: data - median.
        """
        return data - statistics["median"]
|
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse median centering: scaled_data + median.

        Handles both 3D tensors (e.g., training input) and 4D tensors
        (e.g., model output samples).
        """
        median = statistics["median"]

        if scaled_data.ndim == 4:
            # Broadcast (batch, 1, channels) statistics over the sample axis.
            median = median.unsqueeze(-1)

        return scaled_data + median
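

# Usage sketch (hypothetical, for illustration): unlike the mean, the median
# center is barely moved by a single large outlier in the history.
def _demo_median_scaler_outlier() -> None:
    history = torch.zeros(1, 24, 1)
    history[0, 0, 0] = 1000.0  # one corrupted observation
    mean_center = MeanScaler().compute_statistics(history)["mean"]
    median_center = MedianScaler().compute_statistics(history)["median"]
    assert median_center.abs().max() < mean_center.abs().max()


if __name__ == "__main__":
    # Run the illustrative demos as a lightweight smoke test.
    _demo_robust_scaler()
    _demo_min_max_scaler()
    _demo_mean_scaler_samples()
    _demo_median_scaler_outlier()
    print("All scaler demos passed.")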