|
|
import zlib
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import torch

from src.data.containers import TimeSeriesContainer
from src.data.frequency import (
    select_safe_random_frequency,
    select_safe_start_date,
)
from src.synthetic_generation.generator_params import GeneratorParams
|
|
|
|
|
|
|
|
class AbstractTimeSeriesGenerator(ABC): |
|
|
""" |
|
|
Abstract base class for synthetic time series generators. |
|
|
""" |
|
|
|
|
|
@abstractmethod |
|
|
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray: |
|
|
""" |
|
|
Generate synthetic time series data. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
random_seed : int, optional |
|
|
Random seed for reproducibility. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
np.ndarray |
|
|
Time series values of shape (length,) for univariate or |
|
|
(length, num_channels) for multivariate time series. |
|
|
""" |
|
|
pass |
|
|
|
|
|
|
|
|
class GeneratorWrapper:
    """
    Unified base class for all generator wrappers, using a GeneratorParams dataclass
    for configuration. Provides parameter sampling, validation, and batch formatting utilities.

    NOTE(review): this class does not derive from ``ABC``, so the
    ``@abstractmethod`` marker on ``generate_batch`` is not enforced when the
    class is instantiated; the ``NotImplementedError`` raised there is the
    effective guard.
    """

    def __init__(self, params: GeneratorParams):
        """
        Initialize the GeneratorWrapper with a GeneratorParams dataclass.

        Parameters
        ----------
        params : GeneratorParams
            Dataclass instance containing all generator configuration parameters.
            ``params.global_seed`` is used to seed the random number generators.
        """
        self.params = params
        self._set_random_seeds(self.params.global_seed)

    def _set_random_seeds(self, seed: int) -> None:
        """
        Seed the wrapper's private RNG and the global NumPy / PyTorch RNGs.

        ``self.rng`` is seeded with ``seed`` plus an offset derived from the
        concrete class name, so different wrapper classes sharing one global
        seed draw different parameter streams. The global NumPy and PyTorch
        generators are seeded with ``seed`` itself.

        Parameters
        ----------
        seed : int
            Base seed.
        """
        # The built-in hash() of a str is salted per interpreter process
        # (PYTHONHASHSEED), which would make the offset -- and therefore
        # self.rng -- differ between runs. CRC32 of the class name is stable.
        class_name = type(self).__name__
        class_offset = zlib.crc32(class_name.encode("utf-8")) % 2**31
        self.rng = np.random.default_rng(seed + class_offset)

        np.random.seed(seed)
        torch.manual_seed(seed)

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Sample a frequency and a start date for each series in a batch.

        All frequencies are drawn first, then one start date is drawn per
        frequency, both from ``self.rng`` and using ``self.params.length``.

        Parameters
        ----------
        batch_size : int
            Number of series to sample parameters for.

        Returns
        -------
        dict[str, Any]
            Dictionary with two lists of ``batch_size`` entries each:
            ``"frequency"`` and ``"start"``, where ``start[i]`` was selected
            for ``frequency[i]``.
        """
        length = self.params.length

        frequency = [select_safe_random_frequency(length, self.rng) for _ in range(batch_size)]
        start = [select_safe_start_date(length, freq, self.rng) for freq in frequency]

        return {
            "frequency": frequency,
            "start": start,
        }

    @abstractmethod
    def generate_batch(self, batch_size: int, seed: int | None = None, **kwargs) -> TimeSeriesContainer:
        """
        Generate a batch of time series; must be overridden by subclasses.

        Parameters
        ----------
        batch_size : int
            Number of series in the batch.
        seed : int, optional
            Seed for this batch; how it is used is up to the subclass.
        **kwargs
            Additional subclass-specific options.

        Returns
        -------
        TimeSeriesContainer
            The generated batch.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement generate_batch()")
|
|
|