|
|
import logging |
|
|
import math |
|
|
from collections import Counter |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from joblib import Parallel, delayed |
|
|
from torch.quasirandom import SobolEngine |
|
|
|
|
|
from src.gift_eval.data import Dataset |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def find_consecutive_nan_lengths(series: np.ndarray) -> list[int]: |
|
|
"""Finds the lengths of all consecutive NaN blocks in a 1D array.""" |
|
|
if series.ndim > 1: |
|
|
|
|
|
series = series.flatten() |
|
|
|
|
|
is_nan = np.isnan(series) |
|
|
padded_is_nan = np.concatenate(([False], is_nan, [False])) |
|
|
diffs = np.diff(padded_is_nan.astype(int)) |
|
|
|
|
|
start_indices = np.where(diffs == 1)[0] |
|
|
end_indices = np.where(diffs == -1)[0] |
|
|
|
|
|
return (end_indices - start_indices).tolist() |
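
# Minimal illustrative sketch (values are arbitrary): the padded-diff trick above
# turns NaN runs into (start, end) index pairs, so a series with runs of length
# 2 and 1 yields [2, 1].
def _example_nan_run_lengths() -> list[int]:
    series = np.array([1.0, np.nan, np.nan, 3.0, np.nan])
    return find_consecutive_nan_lengths(series)  # -> [2, 1]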
|
|
|
|
|
|
|
|
def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict: |
|
|
""" |
|
|
Analyzes all datasets to derive statistics needed for NaN augmentation. |
|
|
This version collects the full distribution of NaN ratios. |
|
|
""" |
|
|
logger.info("--- Starting Dataset Analysis for Augmentation (Full Distribution) ---") |
|
|
path = Path(gift_eval_path_str) |
|
|
if not path.exists(): |
|
|
raise FileNotFoundError( |
|
|
f"Provided raw data path for augmentation analysis does not exist: {gift_eval_path_str}" |
|
|
) |
|
|
|
|
|
dataset_names = [] |
|
|
for dataset_dir in path.iterdir(): |
|
|
if dataset_dir.name.startswith(".") or not dataset_dir.is_dir(): |
|
|
continue |
|
|
freq_dirs = [d for d in dataset_dir.iterdir() if d.is_dir()] |
|
|
if freq_dirs: |
|
|
for freq_dir in freq_dirs: |
|
|
dataset_names.append(f"{dataset_dir.name}/{freq_dir.name}") |
|
|
else: |
|
|
dataset_names.append(dataset_dir.name) |
|
|
|
|
|
total_series_count = 0 |
|
|
series_with_nans_count = 0 |
|
|
nan_ratio_distribution = [] |
|
|
all_consecutive_nan_lengths = Counter() |
|
|
|
|
|
for ds_name in sorted(dataset_names): |
|
|
try: |
|
|
ds = Dataset(name=ds_name, term="short", to_univariate=False) |
|
|
for series_data in ds.training_dataset: |
|
|
total_series_count += 1 |
|
|
target = np.atleast_1d(series_data["target"]) |
|
|
num_nans = np.isnan(target).sum() |
|
|
|
|
|
if num_nans > 0: |
|
|
series_with_nans_count += 1 |
|
|
nan_ratio = num_nans / target.size |
|
|
nan_ratio_distribution.append(float(nan_ratio)) |
|
|
|
|
|
nan_lengths = find_consecutive_nan_lengths(target) |
|
|
all_consecutive_nan_lengths.update(nan_lengths) |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not process {ds_name} for augmentation analysis: {e}") |
|
|
|
|
|
if total_series_count == 0: |
|
|
raise ValueError("No series were found during augmentation analysis. Check dataset path.") |
|
|
|
|
|
    p_series_has_nan = series_with_nans_count / total_series_count
|
|
|
|
|
logger.info("--- Augmentation Analysis Complete ---") |
|
|
|
|
|
logger.info(f"Total series analyzed: {total_series_count}") |
|
|
logger.info(f"Series with NaNs: {series_with_nans_count} ({p_series_has_nan:.4f})") |
|
|
logger.info(f"NaN ratio distribution: {Counter(nan_ratio_distribution)}") |
|
|
logger.info(f"Consecutive NaN lengths distribution: {all_consecutive_nan_lengths}") |
|
|
logger.info("--- End of Dataset Analysis for Augmentation ---") |
|
|
return { |
|
|
"p_series_has_nan": p_series_has_nan, |
|
|
"nan_ratio_distribution": nan_ratio_distribution, |
|
|
"nan_length_distribution": all_consecutive_nan_lengths, |
|
|
} |
|
|
|
|
|
|
|
|
class NanAugmenter: |
|
|
""" |
|
|
Applies realistic NaN augmentation by generating and caching NaN patterns on-demand |
|
|
during the first transform call for a given data shape. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
p_series_has_nan: float, |
|
|
nan_ratio_distribution: list[float], |
|
|
nan_length_distribution: Counter, |
|
|
num_patterns: int = 100000, |
|
|
n_jobs: int = -1, |
|
|
nan_patterns_path: str | None = None, |
|
|
): |
|
|
""" |
|
|
Initializes the augmenter. NaN patterns are not generated at this stage. |
|
|
|
|
|
Args: |
|
|
p_series_has_nan (float): Probability that a series in a batch will be augmented. |
|
|
nan_ratio_distribution (List[float]): A list of NaN ratios observed in the dataset. |
|
|
nan_length_distribution (Counter): A Counter of consecutive NaN block lengths. |
|
|
num_patterns (int): The number of unique NaN patterns to generate per data shape. |
|
|
            n_jobs (int): The number of CPU cores to use for parallel pattern generation (-1 for all cores).
            nan_patterns_path (str | None): Optional path used to locate and persist cached NaN pattern
                files. If None, patterns are cached under the local "data" directory.
        """
|
|
self.p_series_has_nan = p_series_has_nan |
|
|
self.nan_ratio_distribution = nan_ratio_distribution |
|
|
self.num_patterns = num_patterns |
|
|
self.n_jobs = n_jobs |
|
|
self.max_length = 2048 |
|
|
self.nan_patterns_path = nan_patterns_path |
|
|
|
|
|
self.pattern_cache: dict[tuple[int, ...], torch.BoolTensor] = {} |
|
|
|
|
|
if not nan_length_distribution or sum(nan_length_distribution.values()) == 0: |
|
|
self._has_block_distribution = False |
|
|
logger.warning("NaN length distribution is empty. Augmentation disabled.") |
|
|
else: |
|
|
self._has_block_distribution = True |
|
|
total_blocks = sum(nan_length_distribution.values()) |
|
|
self.dist_lengths = [int(i) for i in nan_length_distribution.keys()] |
|
|
self.dist_probs = [count / total_blocks for count in nan_length_distribution.values()] |
|
|
|
|
|
if not self.nan_ratio_distribution: |
|
|
logger.warning("NaN ratio distribution is empty. Augmentation disabled.") |
|
|
|
|
|
|
|
|
self._load_existing_patterns() |
|
|
|
|
|
def _load_existing_patterns(self): |
|
|
"""Load existing NaN patterns from disk if they exist.""" |
|
|
|
|
|
explicit_path: Path | None = ( |
|
|
Path(self.nan_patterns_path).resolve() if self.nan_patterns_path is not None else None |
|
|
) |
|
|
|
|
|
candidate_files: list[Path] = [] |
|
|
if explicit_path is not None: |
|
|
|
|
|
if explicit_path.is_file(): |
|
|
candidate_files.append(explicit_path) |
|
|
|
|
|
explicit_dir = explicit_path.parent |
|
|
explicit_dir.mkdir(exist_ok=True, parents=True) |
|
|
candidate_files.extend(list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))) |
|
|
else: |
|
|
|
|
|
data_dir = Path("data") |
|
|
data_dir.mkdir(exist_ok=True) |
|
|
candidate_files.extend(list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))) |
|
|
|
|
|
|
|
|
seen: set[str] = set() |
|
|
unique_candidates: list[Path] = [] |
|
|
for f in candidate_files: |
|
|
key = str(f.resolve()) |
|
|
if key not in seen: |
|
|
seen.add(key) |
|
|
unique_candidates.append(f) |
|
|
|
|
|
for pattern_file in unique_candidates: |
|
|
try: |
|
|
|
|
|
filename = pattern_file.stem |
|
|
parts = filename.split("_") |
|
|
if len(parts) >= 4: |
|
|
num_channels = int(parts[-1]) |
|
|
|
|
|
|
|
|
patterns = torch.load(pattern_file, map_location="cpu") |
|
|
cache_key = (self.max_length, num_channels) |
|
|
self.pattern_cache[cache_key] = patterns |
|
|
|
|
|
logger.info(f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}") |
|
|
except (ValueError, RuntimeError, FileNotFoundError) as e: |
|
|
logger.warning(f"Failed to load patterns from {pattern_file}: {e}") |
|
|
|
|
|
def _get_pattern_file_path(self, num_channels: int) -> Path: |
|
|
"""Resolve the target file path for storing/loading patterns for a given channel count.""" |
|
|
|
|
|
if self.nan_patterns_path is not None: |
|
|
base_dir = Path(self.nan_patterns_path).resolve().parent |
|
|
base_dir.mkdir(exist_ok=True, parents=True) |
|
|
else: |
|
|
base_dir = Path("data").resolve() |
|
|
base_dir.mkdir(exist_ok=True, parents=True) |
|
|
|
|
|
return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt" |
|
|
|
|
|
def _generate_nan_mask(self, series_shape: tuple[int, ...]) -> np.ndarray: |
|
|
"""Generates a single boolean NaN mask for a given series shape.""" |
|
|
series_size = int(np.prod(series_shape)) |
|
|
sampled_ratio = np.random.choice(self.nan_ratio_distribution) |
|
|
n_nans_to_add = int(round(series_size * sampled_ratio)) |
|
|
|
|
|
if n_nans_to_add == 0: |
|
|
return np.zeros(series_shape, dtype=bool) |
|
|
|
|
|
mask_flat = np.zeros(series_size, dtype=bool) |
|
|
nans_added = 0 |
|
|
max_attempts = n_nans_to_add * 2 |
|
|
attempts = 0 |
|
|
while nans_added < n_nans_to_add and attempts < max_attempts: |
|
|
attempts += 1 |
|
|
block_length = np.random.choice(self.dist_lengths, p=self.dist_probs) |
|
|
|
|
|
if nans_added + block_length > n_nans_to_add: |
|
|
block_length = n_nans_to_add - nans_added |
|
|
if block_length <= 0: |
|
|
break |
|
|
|
|
|
nan_counts_in_window = np.convolve(mask_flat, np.ones(block_length), mode="valid") |
|
|
valid_starts = np.where(nan_counts_in_window == 0)[0] |
|
|
|
|
|
if valid_starts.size == 0: |
|
|
continue |
|
|
|
|
|
start_pos = np.random.choice(valid_starts) |
|
|
mask_flat[start_pos : start_pos + block_length] = True |
|
|
nans_added += block_length |
|
|
|
|
|
return mask_flat.reshape(series_shape) |
|
|
|
|
|
def _pregenerate_patterns(self, series_shape: tuple[int, ...]) -> torch.BoolTensor: |
|
|
"""Uses joblib to parallelize the generation of NaN masks for a given shape.""" |
|
|
if not self._has_block_distribution or not self.nan_ratio_distribution: |
|
|
return torch.empty(0, *series_shape, dtype=torch.bool) |
|
|
|
|
|
logger.info(f"Generating {self.num_patterns} NaN patterns for shape {series_shape}...") |
|
|
|
|
|
with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel: |
|
|
masks_list = parallel(delayed(self._generate_nan_mask)(series_shape) for _ in range(self.num_patterns)) |
|
|
|
|
|
logger.info(f"Pattern generation complete for shape {series_shape}.") |
|
|
return torch.from_numpy(np.stack(masks_list)).bool() |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Applies NaN patterns to a batch, generating them on-demand if the shape is new. |
|
|
""" |
|
|
if self.p_series_has_nan == 0: |
|
|
return time_series_batch |
|
|
|
|
|
history_length, num_channels = time_series_batch.shape[1:] |
|
|
assert history_length <= self.max_length, ( |
|
|
f"History length {history_length} exceeds maximum allowed {self.max_length}." |
|
|
) |
|
|
|
|
|
|
|
|
if ( |
|
|
self.max_length, |
|
|
num_channels, |
|
|
) not in self.pattern_cache: |
|
|
|
|
|
            target_file = self._get_pattern_file_path(num_channels)
            patterns = None
            if target_file.exists():
                try:
                    patterns = torch.load(target_file, map_location="cpu")
                    logger.info(f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}")
                except (RuntimeError, FileNotFoundError):
                    patterns = None
            if patterns is None:
                patterns = self._pregenerate_patterns((self.max_length, num_channels))
                torch.save(patterns, target_file)
                logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}")
            self.pattern_cache[(self.max_length, num_channels)] = patterns
|
|
patterns = self.pattern_cache[(self.max_length, num_channels)][:, :history_length, :] |
|
|
|
|
|
|
|
|
if patterns.numel() == 0: |
|
|
return time_series_batch |
|
|
|
|
|
batch_size = time_series_batch.shape[0] |
|
|
device = time_series_batch.device |
|
|
|
|
|
|
|
|
augment_mask = torch.rand(batch_size, device=device) < self.p_series_has_nan |
|
|
indices_to_augment = torch.where(augment_mask)[0] |
|
|
num_to_augment = indices_to_augment.numel() |
|
|
|
|
|
if num_to_augment == 0: |
|
|
return time_series_batch |
|
|
|
|
|
|
|
|
pattern_indices = torch.randint(0, patterns.shape[0], (num_to_augment,), device=device) |
|
|
|
|
|
selected_patterns = patterns[pattern_indices].to(device) |
|
|
|
|
|
time_series_batch[indices_to_augment] = time_series_batch[indices_to_augment].masked_fill( |
|
|
selected_patterns, float("nan") |
|
|
) |
|
|
|
|
|
return time_series_batch |
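
# Hedged usage sketch (not part of the training pipeline): wires the statistics
# returned by analyze_datasets_for_augmentation into NanAugmenter. The dataset
# path and batch shape below are placeholders, and num_patterns is kept small so
# the sketch stays cheap.
def _example_nan_augmentation(gift_eval_path: str = "data/gift_eval") -> torch.Tensor:
    stats = analyze_datasets_for_augmentation(gift_eval_path)
    augmenter = NanAugmenter(
        p_series_has_nan=stats["p_series_has_nan"],
        nan_ratio_distribution=stats["nan_ratio_distribution"],
        nan_length_distribution=stats["nan_length_distribution"],
        num_patterns=1_000,
        n_jobs=1,
    )
    batch = torch.randn(8, 512, 1)  # (batch, history_length, num_channels)
    return augmenter.transform(batch)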
|
|
|
|
|
|
|
|
class CensorAugmenter: |
|
|
""" |
|
|
Applies censor augmentation by clipping values from above, below, or both. |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
"""Initializes the CensorAugmenter.""" |
|
|
pass |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Applies a vectorized censor augmentation to a batch of time series. |
|
|
""" |
|
|
batch_size, seq_len, num_channels = time_series_batch.shape |
|
|
assert num_channels == 1 |
|
|
time_series_batch = time_series_batch.squeeze(-1) |
|
|
        with torch.no_grad():
            device = time_series_batch.device
|
|
|
|
|
|
|
|
op_mode = torch.randint(0, 3, (batch_size, 1), device=device) |
|
|
|
|
|
|
|
|
q1 = torch.rand(batch_size, device=device) |
|
|
q2 = torch.rand(batch_size, device=device) |
|
|
q_low = torch.minimum(q1, q2) |
|
|
q_high = torch.maximum(q1, q2) |
|
|
|
|
|
sorted_series = torch.sort(time_series_batch, dim=1).values |
|
|
indices_low = (q_low * (seq_len - 1)).long() |
|
|
indices_high = (q_high * (seq_len - 1)).long() |
|
|
|
|
|
c_low = torch.gather(sorted_series, 1, indices_low.unsqueeze(1)) |
|
|
c_high = torch.gather(sorted_series, 1, indices_high.unsqueeze(1)) |
|
|
|
|
|
|
|
|
clip_above = torch.minimum(time_series_batch, c_high) |
|
|
clip_below = torch.maximum(time_series_batch, c_low) |
|
|
|
|
|
|
|
|
            # op_mode: 0 = leave unchanged, 1 = clip from above, 2 = clip from below
            augmented_batch = torch.where(
                op_mode == 1,
                clip_above,
                torch.where(op_mode == 2, clip_below, time_series_batch),
            )
|
|
|
|
|
return augmented_batch.unsqueeze(-1) |
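
# Brief sketch of the censoring behavior (shapes are placeholders): each series is
# either left unchanged, clipped from above at a random empirical quantile, or
# clipped from below. CensorAugmenter currently supports only single-channel input.
def _example_censor_augmentation() -> torch.Tensor:
    batch = torch.randn(16, 256, 1)  # (batch, seq_len, 1)
    return CensorAugmenter().transform(batch)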
|
|
|
|
|
|
|
|
class QuantizationAugmenter: |
|
|
""" |
|
|
Applies non-equidistant quantization using a Sobol sequence to generate |
|
|
uniformly distributed levels. This implementation is fully vectorized. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
p_quantize: float, |
|
|
level_range: tuple[int, int], |
|
|
seed: int | None = None, |
|
|
): |
|
|
""" |
|
|
Initializes the augmenter. |
|
|
|
|
|
Args: |
|
|
p_quantize (float): Probability of applying quantization to a series. |
|
|
level_range (Tuple[int, int]): Inclusive range [min, max] to sample the |
|
|
number of quantization levels from. |
|
|
seed (Optional[int]): Seed for the Sobol sequence generator for reproducibility. |
|
|
""" |
|
|
assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1." |
|
|
assert level_range[0] >= 2, "Minimum number of levels must be at least 2." |
|
|
assert level_range[0] <= level_range[1], "Min levels cannot be greater than max." |
|
|
|
|
|
self.p_quantize = p_quantize |
|
|
self.level_range = level_range |
|
|
|
|
|
|
|
|
|
|
|
max_intermediate_levels = self.level_range[1] - 2 |
|
|
if max_intermediate_levels > 0: |
|
|
|
|
|
self.sobol_engine = SobolEngine(dimension=max_intermediate_levels, scramble=True, seed=seed) |
|
|
else: |
|
|
self.sobol_engine = None |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Applies augmentation in a fully vectorized way on the batch's device. |
|
|
Handles input shape (batch, length, 1). |
|
|
""" |
|
|
|
|
|
if time_series_batch.dim() == 3 and time_series_batch.shape[2] == 1: |
|
|
is_3d = True |
|
|
time_series_squeezed = time_series_batch.squeeze(-1) |
|
|
else: |
|
|
is_3d = False |
|
|
time_series_squeezed = time_series_batch |
|
|
|
|
|
if self.p_quantize == 0 or self.sobol_engine is None: |
|
|
return time_series_batch |
|
|
|
|
|
n_series, _ = time_series_squeezed.shape |
|
|
device = time_series_squeezed.device |
|
|
|
|
|
|
|
|
augment_mask = torch.rand(n_series, device=device) < self.p_quantize |
|
|
        n_augment = int(augment_mask.sum().item())
        if n_augment == 0:
            return time_series_batch
|
|
|
|
|
series_to_augment = time_series_squeezed[augment_mask] |
|
|
|
|
|
|
|
|
min_l, max_l = self.level_range |
|
|
n_levels_per_series = torch.randint(min_l, max_l + 1, size=(n_augment,), device=device) |
|
|
max_levels_in_batch = n_levels_per_series.max().item() |
|
|
|
|
|
|
|
|
min_vals = torch.amin(series_to_augment, dim=1, keepdim=True) |
|
|
max_vals = torch.amax(series_to_augment, dim=1, keepdim=True) |
|
|
value_range = max_vals - min_vals |
|
|
is_flat = value_range == 0 |
|
|
|
|
|
|
|
|
num_intermediate_levels = max_levels_in_batch - 2 |
|
|
if num_intermediate_levels > 0: |
|
|
|
|
|
sobol_points = self.sobol_engine.draw(n_augment).to(device) |
|
|
|
|
|
quasi_rand_points = sobol_points[:, :num_intermediate_levels] |
|
|
else: |
|
|
|
|
|
quasi_rand_points = torch.empty(n_augment, 0, device=device) |
|
|
|
|
|
scaled_quasi_rand_levels = min_vals + value_range * quasi_rand_points |
|
|
level_values = torch.cat([min_vals, max_vals, scaled_quasi_rand_levels], dim=1) |
|
|
level_values, _ = torch.sort(level_values, dim=1) |
|
|
|
|
|
|
|
|
series_expanded = series_to_augment.unsqueeze(2) |
|
|
levels_expanded = level_values.unsqueeze(1) |
|
|
diff = torch.abs(series_expanded - levels_expanded) |
|
|
|
|
|
arange_mask = torch.arange(max_levels_in_batch, device=device).unsqueeze(0) |
|
|
valid_levels_mask = arange_mask < n_levels_per_series.unsqueeze(1) |
|
|
masked_diff = torch.where(valid_levels_mask.unsqueeze(1), diff, float("inf")) |
|
|
closest_level_indices = torch.argmin(masked_diff, dim=2) |
|
|
|
|
|
|
|
|
quantized_subset = torch.gather(level_values, 1, closest_level_indices) |
|
|
|
|
|
|
|
|
final_subset = torch.where(is_flat, series_to_augment, quantized_subset) |
|
|
|
|
|
|
|
|
augmented_batch_squeezed = time_series_squeezed.clone() |
|
|
augmented_batch_squeezed[augment_mask] = final_subset |
|
|
|
|
|
|
|
|
if is_3d: |
|
|
return augmented_batch_squeezed.unsqueeze(-1) |
|
|
else: |
|
|
return augmented_batch_squeezed |
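
# Hedged sketch: snaps a fraction of series onto a small set of Sobol-sampled
# levels between each series' min and max. The probability and level range below
# are illustrative, not tuned values.
def _example_quantization_augmentation() -> torch.Tensor:
    augmenter = QuantizationAugmenter(p_quantize=0.5, level_range=(3, 10), seed=0)
    batch = torch.randn(16, 256, 1)  # (batch, seq_len, 1)
    return augmenter.transform(batch)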
|
|
|
|
|
|
|
|
class MixUpAugmenter: |
|
|
""" |
|
|
Applies mixup augmentation by creating a weighted average of multiple time series. |
|
|
|
|
|
This version includes an option for time-dependent mixup using Simplex Path |
|
|
Interpolation, creating a smooth transition between different mixing weights. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
max_n_series_to_combine: int = 10, |
|
|
p_combine: float = 0.4, |
|
|
p_time_dependent: float = 0.5, |
|
|
randomize_k_per_series: bool = True, |
|
|
dirichlet_alpha_range: tuple[float, float] = (0.1, 5.0), |
|
|
): |
|
|
""" |
|
|
Initializes the augmenter. |
|
|
|
|
|
Args: |
|
|
max_n_series_to_combine (int): The maximum number of series to combine. |
|
|
The actual number k will be sampled from [2, max]. |
|
|
p_combine (float): The probability of replacing a series with a combination. |
|
|
p_time_dependent (float): The probability of using the time-dependent |
|
|
simplex path method for a given mixup operation. Defaults to 0.5. |
|
|
randomize_k_per_series (bool): If True, each augmented series will be a |
|
|
combination of a different number of series (k). |
|
|
If False, one k is chosen for the whole batch. |
|
|
dirichlet_alpha_range (Tuple[float, float]): The [min, max] range to sample the |
|
|
Dirichlet 'alpha' from. A smaller alpha (e.g., 0.2) creates mixes |
|
|
dominated by one series. A larger alpha (e.g., 5.0) creates |
|
|
more uniform weights. |
|
|
""" |
|
|
assert max_n_series_to_combine >= 2, "Must combine at least 2 series." |
|
|
assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1." |
|
|
assert 0.0 <= p_time_dependent <= 1.0, "p_time_dependent must be between 0 and 1." |
|
|
assert dirichlet_alpha_range[0] > 0 and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1] |
|
|
self.max_k = max_n_series_to_combine |
|
|
self.p_combine = p_combine |
|
|
self.p_time_dependent = p_time_dependent |
|
|
self.randomize_k = randomize_k_per_series |
|
|
self.alpha_range = dirichlet_alpha_range |
|
|
|
|
|
def _sample_alpha(self) -> float: |
|
|
log_alpha_min = math.log10(self.alpha_range[0]) |
|
|
log_alpha_max = math.log10(self.alpha_range[1]) |
|
|
log_alpha = log_alpha_min + np.random.rand() * (log_alpha_max - log_alpha_min) |
|
|
return float(10**log_alpha) |
|
|
|
|
|
def _sample_k(self) -> int: |
|
|
return int(torch.randint(2, self.max_k + 1, (1,)).item()) |
|
|
|
|
|
def _static_mix( |
|
|
self, |
|
|
source_series: torch.Tensor, |
|
|
alpha: float, |
|
|
return_weights: bool = False, |
|
|
): |
|
|
"""Mixes k source series using a single, static set of Dirichlet weights.""" |
|
|
k = int(source_series.shape[0]) |
|
|
device = source_series.device |
|
|
concentration = torch.full((k,), float(alpha), device=device) |
|
|
weights = torch.distributions.Dirichlet(concentration).sample() |
|
|
weights_view = weights.view(k, 1, 1) |
|
|
mixed_series = (source_series * weights_view).sum(dim=0, keepdim=True) |
|
|
if return_weights: |
|
|
return mixed_series, weights |
|
|
return mixed_series |
|
|
|
|
|
def _simplex_path_mix( |
|
|
self, |
|
|
source_series: torch.Tensor, |
|
|
alpha: float, |
|
|
return_weights: bool = False, |
|
|
): |
|
|
"""Mixes k series using time-varying weights interpolated along a simplex path.""" |
|
|
k, length, _ = source_series.shape |
|
|
device = source_series.device |
|
|
|
|
|
|
|
|
concentration = torch.full((k,), float(alpha), device=device) |
|
|
dirichlet_dist = torch.distributions.Dirichlet(concentration) |
|
|
w_start = dirichlet_dist.sample() |
|
|
w_end = dirichlet_dist.sample() |
|
|
|
|
|
|
|
|
alpha_ramp = torch.linspace(0, 1, length, device=device) |
|
|
|
|
|
|
|
|
|
|
|
        time_varying_weights = (
            w_start.unsqueeze(1) * (1 - alpha_ramp.unsqueeze(0))
            + w_end.unsqueeze(1) * alpha_ramp.unsqueeze(0)
        )
|
|
|
|
|
|
|
|
|
|
|
weights_view = time_varying_weights.unsqueeze(-1) |
|
|
mixed_series = (source_series * weights_view).sum(dim=0, keepdim=True) |
|
|
|
|
|
if return_weights: |
|
|
return mixed_series, time_varying_weights |
|
|
return mixed_series |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor, return_debug_info: bool = False): |
|
|
""" |
|
|
Applies the mixup augmentation, randomly choosing between static and |
|
|
time-dependent mixing methods. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
if self.p_combine == 0: |
|
|
return (time_series_batch, {}) if return_debug_info else time_series_batch |
|
|
|
|
|
batch_size, _, _ = time_series_batch.shape |
|
|
device = time_series_batch.device |
|
|
|
|
|
if batch_size <= self.max_k: |
|
|
return (time_series_batch, {}) if return_debug_info else time_series_batch |
|
|
|
|
|
|
|
|
augment_mask = torch.rand(batch_size, device=device) < self.p_combine |
|
|
indices_to_replace = torch.where(augment_mask)[0] |
|
|
n_augment = indices_to_replace.numel() |
|
|
|
|
|
if n_augment == 0: |
|
|
return (time_series_batch, {}) if return_debug_info else time_series_batch |
|
|
|
|
|
|
|
|
if self.randomize_k: |
|
|
k_values = torch.randint(2, self.max_k + 1, (n_augment,), device=device) |
|
|
else: |
|
|
k = self._sample_k() |
|
|
k_values = torch.full((n_augment,), k, device=device) |
|
|
|
|
|
|
|
|
new_series_list = [] |
|
|
all_batch_indices = torch.arange(batch_size, device=device) |
|
|
debug_info = {} |
|
|
|
|
|
for i, target_idx in enumerate(indices_to_replace): |
|
|
current_k = k_values[i].item() |
|
|
|
|
|
|
|
|
candidate_mask = all_batch_indices != target_idx |
|
|
candidates = all_batch_indices[candidate_mask] |
|
|
perm = torch.randperm(candidates.shape[0], device=device) |
|
|
source_indices = candidates[perm[:current_k]] |
|
|
source_series = time_series_batch[source_indices] |
|
|
|
|
|
alpha = self._sample_alpha() |
|
|
mix_type = "static" |
|
|
|
|
|
|
|
|
if torch.rand(1).item() < self.p_time_dependent: |
|
|
mixed_series, weights = self._simplex_path_mix(source_series, alpha=alpha, return_weights=True) |
|
|
mix_type = "simplex" |
|
|
else: |
|
|
mixed_series, weights = self._static_mix(source_series, alpha=alpha, return_weights=True) |
|
|
|
|
|
new_series_list.append(mixed_series) |
|
|
|
|
|
if return_debug_info: |
|
|
debug_info[target_idx.item()] = { |
|
|
"source_indices": source_indices.cpu().numpy(), |
|
|
"weights": weights.cpu().numpy(), |
|
|
"alpha": alpha, |
|
|
"k": current_k, |
|
|
"mix_type": mix_type, |
|
|
} |
|
|
|
|
|
|
|
|
augmented_batch = time_series_batch.clone() |
|
|
if new_series_list: |
|
|
new_series_tensor = torch.cat(new_series_list, dim=0) |
|
|
augmented_batch[indices_to_replace] = new_series_tensor |
|
|
|
|
|
if return_debug_info: |
|
|
return augmented_batch.detach(), debug_info |
|
|
return augmented_batch.detach() |
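
# Hedged sketch: replaces roughly p_combine of the series in a batch with a
# Dirichlet-weighted mix of 2..max_k other series, occasionally with weights that
# drift over time (simplex-path mixing). Arguments are illustrative; the batch
# must be larger than max_n_series_to_combine for mixing to occur.
def _example_mixup_augmentation() -> torch.Tensor:
    augmenter = MixUpAugmenter(max_n_series_to_combine=4, p_combine=0.3)
    batch = torch.randn(32, 256, 1)
    return augmenter.transform(batch)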
|
|
|
|
|
|
|
|
class TimeFlipAugmenter: |
|
|
""" |
|
|
Applies time-reversal augmentation to a random subset of time series in a batch. |
|
|
""" |
|
|
|
|
|
def __init__(self, p_flip: float = 0.5): |
|
|
""" |
|
|
Initializes the TimeFlipAugmenter. |
|
|
|
|
|
Args: |
|
|
p_flip (float): The probability of flipping a single time series in the batch. |
|
|
Defaults to 0.5. |
|
|
""" |
|
|
assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1." |
|
|
self.p_flip = p_flip |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Applies time-reversal augmentation to a batch of time series. |
|
|
|
|
|
Args: |
|
|
time_series_batch (torch.Tensor): The input batch of time series with |
|
|
shape (batch_size, seq_len, num_channels). |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The batch with some series potentially flipped. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
if self.p_flip == 0: |
|
|
return time_series_batch |
|
|
|
|
|
batch_size = time_series_batch.shape[0] |
|
|
device = time_series_batch.device |
|
|
|
|
|
|
|
|
flip_mask = torch.rand(batch_size, device=device) < self.p_flip |
|
|
indices_to_flip = torch.where(flip_mask)[0] |
|
|
|
|
|
if indices_to_flip.numel() == 0: |
|
|
return time_series_batch |
|
|
|
|
|
|
|
|
series_to_flip = time_series_batch[indices_to_flip] |
|
|
|
|
|
|
|
|
flipped_series = torch.flip(series_to_flip, dims=[1]) |
|
|
|
|
|
|
|
|
augmented_batch = time_series_batch.clone() |
|
|
augmented_batch[indices_to_flip] = flipped_series |
|
|
|
|
|
return augmented_batch |
|
|
|
|
|
|
|
|
class YFlipAugmenter: |
|
|
""" |
|
|
Applies y-reversal augmentation to a random subset of time series in a batch. |
|
|
""" |
|
|
|
|
|
def __init__(self, p_flip: float = 0.5): |
|
|
""" |
|
|
        Initializes the YFlipAugmenter.
|
|
|
|
|
Args: |
|
|
p_flip (float): The probability of flipping a single time series in the batch. |
|
|
Defaults to 0.5. |
|
|
""" |
|
|
assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1." |
|
|
self.p_flip = p_flip |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
        Applies sign-flip (y-axis reversal) augmentation to a batch of time series.
|
|
|
|
|
Args: |
|
|
time_series_batch (torch.Tensor): The input batch of time series with |
|
|
shape (batch_size, seq_len, num_channels). |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The batch with some series potentially flipped. |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
if self.p_flip == 0: |
|
|
return time_series_batch |
|
|
|
|
|
batch_size = time_series_batch.shape[0] |
|
|
device = time_series_batch.device |
|
|
|
|
|
|
|
|
flip_mask = torch.rand(batch_size, device=device) < self.p_flip |
|
|
indices_to_flip = torch.where(flip_mask)[0] |
|
|
|
|
|
if indices_to_flip.numel() == 0: |
|
|
return time_series_batch |
|
|
|
|
|
|
|
|
series_to_flip = time_series_batch[indices_to_flip] |
|
|
|
|
|
|
|
|
flipped_series = -series_to_flip |
|
|
|
|
|
|
|
|
augmented_batch = time_series_batch.clone() |
|
|
augmented_batch[indices_to_flip] = flipped_series |
|
|
|
|
|
return augmented_batch |
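
# Hedged sketch: the two flip augmenters compose naturally; TimeFlipAugmenter
# reverses a series along the time axis, while YFlipAugmenter negates its values.
# The probabilities below are placeholders.
def _example_flip_augmentations() -> torch.Tensor:
    batch = torch.randn(16, 256, 1)
    batch = TimeFlipAugmenter(p_flip=0.5).transform(batch)
    return YFlipAugmenter(p_flip=0.5).transform(batch)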
|
|
|
|
|
|
|
|
class DifferentialAugmenter: |
|
|
""" |
|
|
Applies calculus-inspired augmentations. This version includes up to the |
|
|
fourth derivative and uses nn.Conv1d with built-in 'reflect' padding for |
|
|
cleaner and more efficient convolutions. |
|
|
|
|
|
The Gaussian kernel size and sigma for the initial smoothing are randomly |
|
|
sampled at every transform() call from user-defined ranges. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
p_transform: float, |
|
|
gaussian_kernel_size_range: tuple[int, int] = (5, 51), |
|
|
gaussian_sigma_range: tuple[float, float] = (2.0, 20.0), |
|
|
): |
|
|
""" |
|
|
Initializes the augmenter. |
|
|
|
|
|
Args: |
|
|
p_transform (float): The probability of applying an augmentation to any given |
|
|
time series in a batch. |
|
|
gaussian_kernel_size_range (Tuple[int, int]): The [min, max] inclusive range |
|
|
for the Gaussian kernel size. |
|
|
Sizes will be forced to be odd. |
|
|
gaussian_sigma_range (Tuple[float, float]): The [min, max] inclusive range |
|
|
for the Gaussian sigma. |
|
|
""" |
|
|
self.p_transform = p_transform |
|
|
self.kernel_size_range = gaussian_kernel_size_range |
|
|
self.sigma_range = gaussian_sigma_range |
|
|
|
|
|
|
|
|
if not (self.kernel_size_range[0] <= self.kernel_size_range[1] and self.kernel_size_range[0] >= 3): |
|
|
raise ValueError("Invalid kernel size range. Ensure min <= max and min >= 3.") |
|
|
if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0): |
|
|
raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.") |
|
|
|
|
|
|
|
|
self.conv_cache: dict[tuple[int, torch.device], dict[str, nn.Module]] = {} |
|
|
|
|
|
def _create_fixed_kernel_layers(self, num_channels: int, device: torch.device) -> dict: |
|
|
""" |
|
|
Creates and configures nn.Conv1d layers for fixed-kernel derivative operations. |
|
|
These layers are cached to improve performance. |
|
|
""" |
|
|
sobel_conv = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=3, |
|
|
padding="same", |
|
|
padding_mode="reflect", |
|
|
groups=num_channels, |
|
|
bias=False, |
|
|
device=device, |
|
|
) |
|
|
laplace_conv = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=3, |
|
|
padding="same", |
|
|
padding_mode="reflect", |
|
|
groups=num_channels, |
|
|
bias=False, |
|
|
device=device, |
|
|
) |
|
|
d3_conv = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=5, |
|
|
padding="same", |
|
|
padding_mode="reflect", |
|
|
groups=num_channels, |
|
|
bias=False, |
|
|
device=device, |
|
|
) |
|
|
d4_conv = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=5, |
|
|
padding="same", |
|
|
padding_mode="reflect", |
|
|
groups=num_channels, |
|
|
bias=False, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
sobel_kernel = ( |
|
|
torch.tensor([-1, 0, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1) |
|
|
) |
|
|
laplace_kernel = ( |
|
|
torch.tensor([1, -2, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1) |
|
|
) |
|
|
d3_kernel = ( |
|
|
torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32) |
|
|
.view(1, 1, -1) |
|
|
.repeat(num_channels, 1, 1) |
|
|
) |
|
|
d4_kernel = ( |
|
|
torch.tensor([1, -4, 6, -4, 1], device=device, dtype=torch.float32) |
|
|
.view(1, 1, -1) |
|
|
.repeat(num_channels, 1, 1) |
|
|
) |
|
|
|
|
|
sobel_conv.weight.data = sobel_kernel |
|
|
laplace_conv.weight.data = laplace_kernel |
|
|
d3_conv.weight.data = d3_kernel |
|
|
d4_conv.weight.data = d4_kernel |
|
|
|
|
|
for layer in [sobel_conv, laplace_conv, d3_conv, d4_conv]: |
|
|
layer.weight.requires_grad = False |
|
|
|
|
|
return { |
|
|
"sobel": sobel_conv, |
|
|
"laplace": laplace_conv, |
|
|
"d3": d3_conv, |
|
|
"d4": d4_conv, |
|
|
} |
|
|
|
|
|
def _create_gaussian_layer( |
|
|
self, kernel_size: int, sigma: float, num_channels: int, device: torch.device |
|
|
) -> nn.Module: |
|
|
"""Creates a single Gaussian convolution layer with the given dynamic parameters.""" |
|
|
gauss_conv = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=kernel_size, |
|
|
padding="same", |
|
|
padding_mode="reflect", |
|
|
groups=num_channels, |
|
|
bias=False, |
|
|
device=device, |
|
|
) |
|
|
ax = torch.arange( |
|
|
-(kernel_size // 2), |
|
|
kernel_size // 2 + 1, |
|
|
device=device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
gauss_kernel = torch.exp(-0.5 * (ax / sigma) ** 2) |
|
|
gauss_kernel /= gauss_kernel.sum() |
|
|
gauss_kernel = gauss_kernel.view(1, 1, -1).repeat(num_channels, 1, 1) |
|
|
gauss_conv.weight.data = gauss_kernel |
|
|
gauss_conv.weight.requires_grad = False |
|
|
return gauss_conv |
|
|
|
|
|
def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor: |
|
|
"""Rescales the processed signal to match the min/max range of the original.""" |
|
|
original_min = torch.amin(original_signal, dim=2, keepdim=True) |
|
|
original_max = torch.amax(original_signal, dim=2, keepdim=True) |
|
|
processed_min = torch.amin(processed_signal, dim=2, keepdim=True) |
|
|
processed_max = torch.amax(processed_signal, dim=2, keepdim=True) |
|
|
|
|
|
original_range = original_max - original_min |
|
|
processed_range = processed_max - processed_min |
|
|
epsilon = 1e-8 |
|
|
rescaled_signal = ( |
|
|
(processed_signal - processed_min) / (processed_range + epsilon) |
|
|
) * original_range + original_min |
|
|
return torch.where(original_range < epsilon, original_signal, rescaled_signal) |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
"""Applies a random augmentation to a subset of the batch.""" |
|
|
with torch.no_grad(): |
|
|
if self.p_transform == 0: |
|
|
return time_series_batch |
|
|
|
|
|
batch_size, seq_len, num_channels = time_series_batch.shape |
|
|
device = time_series_batch.device |
|
|
|
|
|
augment_mask = torch.rand(batch_size, device=device) < self.p_transform |
|
|
indices_to_augment = torch.where(augment_mask)[0] |
|
|
num_to_augment = indices_to_augment.numel() |
|
|
|
|
|
if num_to_augment == 0: |
|
|
return time_series_batch |
|
|
|
|
|
|
|
|
min_k, max_k = self.kernel_size_range |
|
|
            kernel_size = torch.randint(min_k, max_k + 1, (1,)).item()
            kernel_size = kernel_size // 2 * 2 + 1  # force an odd kernel size (rounding up)
|
|
|
|
|
min_s, max_s = self.sigma_range |
|
|
sigma = (min_s + (max_s - min_s) * torch.rand(1)).item() |
|
|
|
|
|
|
|
|
gauss_conv = self._create_gaussian_layer(kernel_size, sigma, num_channels, device) |
|
|
|
|
|
cache_key = (num_channels, device) |
|
|
if cache_key not in self.conv_cache: |
|
|
self.conv_cache[cache_key] = self._create_fixed_kernel_layers(num_channels, device) |
|
|
fixed_layers = self.conv_cache[cache_key] |
|
|
|
|
|
|
|
|
subset_to_augment = time_series_batch[indices_to_augment] |
|
|
subset_permuted = subset_to_augment.permute(0, 2, 1) |
|
|
|
|
|
            # op choice per series: 0 smooth, 1 first derivative, 2 second derivative,
            # 3 cumulative integral, 4 third derivative, 5 fourth derivative
            op_choices = torch.randint(0, 6, (num_to_augment,), device=device)
|
|
|
|
|
smoothed_subset = gauss_conv(subset_permuted) |
|
|
sobel_on_smoothed = fixed_layers["sobel"](smoothed_subset) |
|
|
laplace_on_smoothed = fixed_layers["laplace"](smoothed_subset) |
|
|
d3_on_smoothed = fixed_layers["d3"](smoothed_subset) |
|
|
d4_on_smoothed = fixed_layers["d4"](smoothed_subset) |
|
|
|
|
|
gauss_result = self._rescale_signal(smoothed_subset, subset_permuted) |
|
|
sobel_result = self._rescale_signal(sobel_on_smoothed, subset_permuted) |
|
|
laplace_result = self._rescale_signal(laplace_on_smoothed, subset_permuted) |
|
|
d3_result = self._rescale_signal(d3_on_smoothed, subset_permuted) |
|
|
d4_result = self._rescale_signal(d4_on_smoothed, subset_permuted) |
|
|
|
|
|
use_right_integral = torch.rand(num_to_augment, 1, 1, device=device) > 0.5 |
|
|
flipped_subset = torch.flip(subset_permuted, dims=[2]) |
|
|
right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2]) |
|
|
left_integral = torch.cumsum(subset_permuted, dim=2) |
|
|
integral_result = torch.where(use_right_integral, right_integral, left_integral) |
|
|
integral_result_normalized = self._rescale_signal(integral_result, subset_permuted) |
|
|
|
|
|
|
|
|
op_choices_view = op_choices.view(-1, 1, 1) |
|
|
augmented_subset = torch.where(op_choices_view == 0, gauss_result, subset_permuted) |
|
|
augmented_subset = torch.where(op_choices_view == 1, sobel_result, augmented_subset) |
|
|
augmented_subset = torch.where(op_choices_view == 2, laplace_result, augmented_subset) |
|
|
augmented_subset = torch.where(op_choices_view == 3, integral_result_normalized, augmented_subset) |
|
|
augmented_subset = torch.where(op_choices_view == 4, d3_result, augmented_subset) |
|
|
augmented_subset = torch.where(op_choices_view == 5, d4_result, augmented_subset) |
|
|
|
|
|
augmented_subset_final = augmented_subset.permute(0, 2, 1) |
|
|
augmented_batch = time_series_batch.clone() |
|
|
augmented_batch[indices_to_augment] = augmented_subset_final |
|
|
|
|
|
return augmented_batch |
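
# Hedged sketch: each selected series is Gaussian-smoothed and then, per a random
# op choice, replaced by the smoothed signal, one of its first-to-fourth
# finite-difference derivatives, or a cumulative integral, rescaled back to the
# original min/max range. The ranges below are illustrative.
def _example_differential_augmentation() -> torch.Tensor:
    augmenter = DifferentialAugmenter(
        p_transform=0.5,
        gaussian_kernel_size_range=(5, 21),
        gaussian_sigma_range=(2.0, 10.0),
    )
    batch = torch.randn(16, 256, 1)
    return augmenter.transform(batch)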
|
|
|
|
|
|
|
|
class RandomConvAugmenter: |
|
|
""" |
|
|
Applies a stack of 1-to-N random 1D convolutions to a time series batch. |
|
|
|
|
|
This augmenter is inspired by the principles of ROCKET and RandConv, |
|
|
randomizing nearly every aspect of the convolution process to create a |
|
|
highly diverse set of transformations. This version includes multiple |
|
|
kernel generation strategies, random padding modes, and optional non-linearities. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
p_transform: float = 0.5, |
|
|
kernel_size_range: tuple[int, int] = (3, 31), |
|
|
dilation_range: tuple[int, int] = (1, 8), |
|
|
layer_range: tuple[int, int] = (1, 3), |
|
|
sigma_range: tuple[float, float] = (0.5, 5.0), |
|
|
bias_range: tuple[float, float] = (-0.5, 0.5), |
|
|
): |
|
|
""" |
|
|
Initializes the augmenter. |
|
|
|
|
|
Args: |
|
|
p_transform (float): Probability of applying the augmentation to a series. |
|
|
kernel_size_range (Tuple[int, int]): [min, max] range for kernel sizes. |
|
|
Must be odd numbers. |
|
|
dilation_range (Tuple[int, int]): [min, max] range for dilation factors. |
|
|
layer_range (Tuple[int, int]): [min, max] range for the number of |
|
|
stacked convolution layers. |
|
|
sigma_range (Tuple[float, float]): [min, max] range for the sigma of |
|
|
Gaussian kernels. |
|
|
bias_range (Tuple[float, float]): [min, max] range for the bias term. |
|
|
""" |
|
|
assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, "Kernel sizes must be odd." |
|
|
|
|
|
self.p_transform = p_transform |
|
|
self.kernel_size_range = kernel_size_range |
|
|
self.dilation_range = dilation_range |
|
|
self.layer_range = layer_range |
|
|
self.sigma_range = sigma_range |
|
|
self.bias_range = bias_range |
|
|
self.padding_modes = ["reflect", "replicate", "circular"] |
|
|
|
|
|
def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor: |
|
|
"""Rescales the processed signal to match the min/max range of the original.""" |
|
|
original_min = torch.amin(original_signal, dim=-1, keepdim=True) |
|
|
original_max = torch.amax(original_signal, dim=-1, keepdim=True) |
|
|
processed_min = torch.amin(processed_signal, dim=-1, keepdim=True) |
|
|
processed_max = torch.amax(processed_signal, dim=-1, keepdim=True) |
|
|
|
|
|
original_range = original_max - original_min |
|
|
processed_range = processed_max - processed_min |
|
|
epsilon = 1e-8 |
|
|
|
|
|
is_flat = processed_range < epsilon |
|
|
|
|
|
rescaled_signal = ( |
|
|
(processed_signal - processed_min) / (processed_range + epsilon) |
|
|
) * original_range + original_min |
|
|
|
|
|
original_mean = torch.mean(original_signal, dim=-1, keepdim=True) |
|
|
flat_rescaled = original_mean.expand_as(original_signal) |
|
|
|
|
|
return torch.where(is_flat, flat_rescaled, rescaled_signal) |
|
|
|
|
|
def _apply_random_conv_stack(self, series: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Applies a randomly configured stack of convolutions to a single time series. |
|
|
|
|
|
Args: |
|
|
series (torch.Tensor): A single time series of shape (1, num_channels, seq_len). |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: The augmented time series. |
|
|
""" |
|
|
num_channels = series.shape[1] |
|
|
device = series.device |
|
|
|
|
|
num_layers = torch.randint(self.layer_range[0], self.layer_range[1] + 1, (1,)).item() |
|
|
|
|
|
processed_series = series |
|
|
for i in range(num_layers): |
|
|
|
|
|
k_min, k_max = self.kernel_size_range |
|
|
kernel_size = torch.randint(k_min // 2, k_max // 2 + 1, (1,)).item() * 2 + 1 |
|
|
|
|
|
|
|
|
d_min, d_max = self.dilation_range |
|
|
dilation = torch.randint(d_min, d_max + 1, (1,)).item() |
|
|
|
|
|
|
|
|
b_min, b_max = self.bias_range |
|
|
bias_val = (b_min + (b_max - b_min) * torch.rand(1)).item() |
|
|
|
|
|
|
|
|
padding_mode = np.random.choice(self.padding_modes) |
|
|
|
|
|
conv_layer = nn.Conv1d( |
|
|
in_channels=num_channels, |
|
|
out_channels=num_channels, |
|
|
kernel_size=kernel_size, |
|
|
dilation=dilation, |
|
|
padding="same", |
|
|
padding_mode=padding_mode, |
|
|
groups=num_channels, |
|
|
bias=True, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
|
|
|
            # kernel family: 0 Gaussian, 1 white noise, 2 random quadratic, 3 noisy first-difference
            weight_type = torch.randint(0, 4, (1,)).item()
|
|
if weight_type == 0: |
|
|
s_min, s_max = self.sigma_range |
|
|
sigma = (s_min + (s_max - s_min) * torch.rand(1)).item() |
|
|
ax = torch.arange( |
|
|
-(kernel_size // 2), |
|
|
kernel_size // 2 + 1, |
|
|
device=device, |
|
|
dtype=torch.float32, |
|
|
) |
|
|
kernel = torch.exp(-0.5 * (ax / sigma) ** 2) |
|
|
elif weight_type == 1: |
|
|
kernel = torch.randn(kernel_size, device=device) |
|
|
elif weight_type == 2: |
|
|
coeffs = torch.randn(3, device=device) |
|
|
x_vals = torch.linspace(-1, 1, kernel_size, device=device) |
|
|
kernel = coeffs[0] * x_vals**2 + coeffs[1] * x_vals + coeffs[2] |
|
|
else: |
|
|
|
|
|
actual_kernel_size = 3 if kernel_size < 3 else kernel_size |
|
|
sobel_base = torch.tensor([-1, 0, 1], dtype=torch.float32, device=device) |
|
|
noise = torch.randn(3, device=device) * 0.1 |
|
|
noisy_sobel = sobel_base + noise |
|
|
|
|
|
pad_total = actual_kernel_size - 3 |
|
|
pad_left = pad_total // 2 |
|
|
pad_right = pad_total - pad_left |
|
|
kernel = F.pad(noisy_sobel, (pad_left, pad_right), "constant", 0) |
|
|
|
|
|
|
|
|
if torch.rand(1).item() < 0.8: |
|
|
kernel /= torch.sum(torch.abs(kernel)) + 1e-8 |
|
|
|
|
|
kernel = kernel.view(1, 1, -1).repeat(num_channels, 1, 1) |
|
|
|
|
|
conv_layer.weight.data = kernel |
|
|
conv_layer.bias.data.fill_(bias_val) |
|
|
conv_layer.weight.requires_grad = False |
|
|
conv_layer.bias.requires_grad = False |
|
|
|
|
|
|
|
|
processed_series = conv_layer(processed_series) |
|
|
|
|
|
|
|
|
if i < num_layers - 1: |
|
|
activation_type = torch.randint(0, 3, (1,)).item() |
|
|
if activation_type == 1: |
|
|
processed_series = F.relu(processed_series) |
|
|
elif activation_type == 2: |
|
|
processed_series = torch.tanh(processed_series) |
|
|
|
|
|
|
|
|
return processed_series |
|
|
|
|
|
def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor: |
|
|
"""Applies a random augmentation to a subset of the batch.""" |
|
|
with torch.no_grad(): |
|
|
if self.p_transform == 0: |
|
|
return time_series_batch |
|
|
|
|
|
batch_size, seq_len, num_channels = time_series_batch.shape |
|
|
device = time_series_batch.device |
|
|
|
|
|
augment_mask = torch.rand(batch_size, device=device) < self.p_transform |
|
|
indices_to_augment = torch.where(augment_mask)[0] |
|
|
num_to_augment = indices_to_augment.numel() |
|
|
|
|
|
if num_to_augment == 0: |
|
|
return time_series_batch |
|
|
|
|
|
subset_to_augment = time_series_batch[indices_to_augment] |
|
|
|
|
|
subset_permuted = subset_to_augment.permute(0, 2, 1) |
|
|
|
|
|
augmented_subset_list = [] |
|
|
for i in range(num_to_augment): |
|
|
original_series = subset_permuted[i : i + 1] |
|
|
augmented_series = self._apply_random_conv_stack(original_series) |
|
|
|
|
|
rescaled_series = self._rescale_signal(augmented_series.squeeze(0), original_series.squeeze(0)) |
|
|
augmented_subset_list.append(rescaled_series.unsqueeze(0)) |
|
|
|
|
|
if augmented_subset_list: |
|
|
augmented_subset = torch.cat(augmented_subset_list, dim=0) |
|
|
augmented_subset_final = augmented_subset.permute(0, 2, 1) |
|
|
|
|
|
augmented_batch = time_series_batch.clone() |
|
|
augmented_batch[indices_to_augment] = augmented_subset_final |
|
|
return augmented_batch |
|
|
else: |
|
|
return time_series_batch |
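
# Hedged sketch: each selected series is passed through a freshly randomized stack
# of depthwise 1D convolutions (random kernel family, dilation, padding mode, bias,
# and optional non-linearity) and rescaled to its original range. Arguments are
# illustrative.
def _example_random_conv_augmentation() -> torch.Tensor:
    augmenter = RandomConvAugmenter(p_transform=0.5, kernel_size_range=(3, 15), layer_range=(1, 2))
    batch = torch.randn(16, 256, 1)
    return augmenter.transform(batch)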
|
|
|