File size: 2,721 Bytes
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
 
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import Any

import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import KernelGeneratorParams
from src.synthetic_generation.kernel_synth.kernel_synth import KernelSynthGenerator


class KernelGeneratorWrapper(GeneratorWrapper):
    """
    Wrapper around KernelSynthGenerator that produces batches of synthetic time series.

    Each element of a batch is produced by a separate
    ``KernelSynthGenerator.generate_time_series`` call, and the per-series
    outputs are stacked with ``np.array``. Configuration is supplied through a
    KernelGeneratorParams dataclass.
    """

    def __init__(self, params: KernelGeneratorParams):
        super().__init__(params)
        # Re-bind with the concrete annotation so type checkers know that
        # ``self.params`` exposes the kernel-specific fields used below.
        self.params: KernelGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Sample parameter values for batch generation with KernelSynthGenerator.

        Starts from the base-class parameters and adds the kernel-specific
        ``length`` and ``max_kernels`` settings read from ``self.params``.

        Parameters
        ----------
        batch_size : int
            Number of series in the batch; forwarded to the base-class sampler.

        Returns
        -------
        dict[str, Any]
            The base-class parameter dict extended with ``"length"`` and
            ``"max_kernels"``.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {
                "length": self.params.length,
                "max_kernels": self.params.max_kernels,
            }
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Generate a batch of synthetic time series using KernelSynthGenerator.

        Parameters
        ----------
        batch_size : int
            Number of time series to generate.
        seed : int, optional
            Random seed for this batch (default: None). When given, it is passed
            to ``_set_random_seeds``, used as the generator's ``random_seed``,
            and series ``i`` is generated with ``random_seed=seed + i``. When
            None, every series is generated with ``random_seed=None``.
        params : dict[str, Any], optional
            Pre-sampled parameters to use. If None, parameters are obtained from
            ``_sample_parameters(batch_size)``. The dict must contain the keys
            ``"length"``, ``"max_kernels"``, ``"start"`` and ``"frequency"``.

        Returns
        -------
        TimeSeriesContainer
            A container whose ``values`` is ``np.array`` of the per-series
            outputs, with ``start`` and ``frequency`` taken from ``params``.

        Raises
        ------
        KeyError
            If ``params`` lacks any of the keys listed above.
        """
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            # NOTE(review): "start" and "frequency" are presumably supplied by
            # GeneratorWrapper._sample_parameters; this file only adds
            # "length" and "max_kernels" — confirm against the base class.
            params = self._sample_parameters(batch_size)

        generator = KernelSynthGenerator(
            length=params["length"],
            max_kernels=params["max_kernels"],
            random_seed=seed,
        )

        # Per-series seeds are offset by the series index so a seeded batch is
        # reproducible. NOTE(review): batches with seeds s and s+1 reuse the
        # per-series seeds s+1, s+2, ...; whether that yields duplicate series
        # depends on KernelSynthGenerator — confirm this is acceptable.
        batch_values = [
            generator.generate_time_series(
                random_seed=None if seed is None else seed + i
            )
            for i in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )