TempoPFN / src /synthetic_generation /abstract_classes.py
Vladyslav Moroshan
Apply ruff formatting
0a58567
import zlib
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import torch

from src.data.containers import TimeSeriesContainer
from src.data.frequency import (
    select_safe_random_frequency,
    select_safe_start_date,
)
from src.synthetic_generation.generator_params import GeneratorParams
class AbstractTimeSeriesGenerator(ABC):
"""
Abstract base class for synthetic time series generators.
"""
@abstractmethod
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
"""
Generate synthetic time series data.
Parameters
----------
random_seed : int, optional
Random seed for reproducibility.
Returns
-------
np.ndarray
Time series values of shape (length,) for univariate or
(length, num_channels) for multivariate time series.
"""
pass
class GeneratorWrapper:
    """
    Common base class for generator wrappers configured through a GeneratorParams instance.

    On construction it seeds a private NumPy Generator (self.rng) as well as the
    global NumPy and PyTorch RNGs, and it offers per-batch sampling of frequency
    and start values. Subclasses are expected to implement generate_batch().
    """

    def __init__(self, params: GeneratorParams):
        """
        Store the configuration and seed all random number generators.

        Parameters
        ----------
        params : GeneratorParams
            Configuration object. This class reads ``params.global_seed`` here
            and ``params.length`` in ``_sample_parameters``.
        """
        self.params = params
        self._set_random_seeds(self.params.global_seed)

    def _set_random_seeds(self, seed: int) -> None:
        """
        Seed the wrapper's private RNG and the global NumPy / PyTorch RNGs.

        Parameters
        ----------
        seed : int
            Base seed. It is passed unchanged to ``np.random.seed`` and
            ``torch.manual_seed``, and offset by a class-name-derived value
            before seeding ``self.rng``.
        """
        # Offset the seed per wrapper class so that different generators draw
        # different parameter sequences from the same base seed. The offset has to
        # be a stable function of the class name: the built-in hash() of a str is
        # salted per interpreter process (PYTHONHASHSEED), which would make
        # self.rng differ from run to run even with a fixed seed. CRC32 is
        # deterministic across processes and platforms.
        class_offset = zlib.crc32(type(self).__name__.encode("utf-8")) % 2**31
        self.rng = np.random.default_rng(seed + class_offset)

        # Global seeds give deterministic behavior in underlying generators that
        # rely on the module-level NumPy / PyTorch random state.
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Sample a frequency and a start value for every series in a batch.

        Both are chosen by the helpers from ``src.data.frequency`` using
        ``self.params.length`` and ``self.rng``.

        Parameters
        ----------
        batch_size : int
            Number of (frequency, start) pairs to sample.

        Returns
        -------
        dict[str, Any]
            ``{"frequency": [...], "start": [...]}`` where both lists have
            ``batch_size`` entries and ``start[i]`` was sampled for ``frequency[i]``.
        """
        length = self.params.length

        # All frequencies are drawn before any start value; keeping this order
        # keeps the sequence of draws from self.rng unchanged.
        frequency = [select_safe_random_frequency(length, self.rng) for _ in range(batch_size)]
        start = [select_safe_start_date(length, freq, self.rng) for freq in frequency]

        return {
            "frequency": frequency,
            "start": start,
        }

    @abstractmethod
    def generate_batch(self, batch_size: int, seed: int | None = None, **kwargs) -> TimeSeriesContainer:
        """
        Generate a batch of time series; subclasses must override this.

        Parameters
        ----------
        batch_size : int
            Requested batch size.
        seed : int, optional
            Optional seed; how it is used is left to the subclass.
        **kwargs
            Subclass-specific options.

        Returns
        -------
        TimeSeriesContainer
            The generated batch (produced by subclass implementations).

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        # NOTE: GeneratorWrapper does not derive from ABC, so @abstractmethod is
        # not enforced when the class is instantiated; this raise is the
        # effective guard against a missing override.
        raise NotImplementedError("Subclasses must implement generate_batch()")