import functools

import gpytorch
import numpy as np
import torch
from src.data.frequency import FREQUENCY_MAPPING
from src.synthetic_generation.generator_params import GPGeneratorParams
from src.synthetic_generation.gp_prior.constants import (
    KERNEL_BANK,
    KERNEL_PERIODS_BY_FREQ,
)
from src.synthetic_generation.gp_prior.utils import (
    create_kernel,
    extract_periodicities,
    random_binary_map,
)
from src.synthetic_generation.utils import generate_peak_spikes


class GPModel(gpytorch.models.ExactGP):
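    """Thin ExactGP wrapper pairing a configurable mean module with a
    configurable covariance kernel; used here only to draw prior samples."""
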
    def __init__(self, train_x, train_y, likelihood, mean_module, kernel):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = kernel

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class GPGenerator:
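    """Generates synthetic time series by sampling from a GP whose kernel is
    randomly composed from a weighted kernel bank, with optional peak spikes."""
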
    def __init__(
        self,
        params: GPGeneratorParams,
        length: int = 1024,
        random_seed: int | None = None,
    ):
        self.params = params
        self.length = length
        self.rng = np.random.default_rng(random_seed)
        self.frequency = params.frequency
        self.max_kernels = params.max_kernels
        self.likelihood_noise_level = params.likelihood_noise_level
        self.noise_level = params.noise_level
        self.use_original_gp = params.use_original_gp
        self.gaussians_periodic = params.gaussians_periodic
        self.peak_spike_ratio = params.peak_spike_ratio
        self.subfreq_ratio = params.subfreq_ratio
        self.periods_per_freq = params.periods_per_freq
        self.gaussian_sampling_ratio = params.gaussian_sampling_ratio
        self.kernel_periods = params.kernel_periods
        self.max_period_ratio = params.max_period_ratio
        self.kernel_bank = params.kernel_bank

    def generate_time_series(
        self,
        random_seed: int | None = None,
    ) -> np.ndarray:
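        """Draw one synthetic series from the GP prior.

        Returns a 1D array of shape (self.length,); with probability
        peak_spike_ratio, periodic peak spikes may be mixed into the sample.
        """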
        with torch.inference_mode():
            if random_seed is not None:
                self.rng = np.random.default_rng(random_seed)
                torch.manual_seed(random_seed)

            # Determine kernel_bank and gaussians_periodic
            if self.use_original_gp:
                kernel_bank = KERNEL_BANK
                gaussians_periodic = False
            else:
                # Convert kernel_bank from {str: float} format to {int: (str, float)} format
                kernel_bank = {
                    i: (kernel_name, weight) for i, (kernel_name, weight) in enumerate(self.kernel_bank.items())
                }
                gaussians_periodic = self.gaussians_periodic

            # Map frequency to (freq, subfreq, timescale)
            freq, subfreq, timescale = FREQUENCY_MAPPING.get(self.frequency, ("D", "", 0))

            # Decide if using exact frequencies
            exact_freqs = self.rng.random() < self.periods_per_freq
            if exact_freqs and freq in KERNEL_PERIODS_BY_FREQ:
                kernel_periods = KERNEL_PERIODS_BY_FREQ[freq]
                if subfreq:
                    subfreq_int = int(subfreq)
                    kernel_periods = [p // subfreq_int for p in kernel_periods if p >= subfreq_int]
            else:
                kernel_periods = self.kernel_periods

            # Sample number of kernels
            num_kernels = self.rng.integers(1, self.max_kernels + 1)
            # Always expect kernel_bank as dict {int: (str, float)}
            kernel_weights = np.array([v[1] for v in kernel_bank.values()])
            kernel_ids = self.rng.choice(
                list(kernel_bank.keys()),
                size=num_kernels,
                p=kernel_weights / kernel_weights.sum(),
            )
            kernel_names = [kernel_bank[i][0] for i in kernel_ids]

            # Fold the sampled kernels into a single composite kernel;
            # random_binary_map combines each pair with a randomly chosen
            # binary operation (e.g. sum or product)
            composite_kernel = functools.reduce(
                lambda a, b: random_binary_map(a, b, rng=self.rng),
                [
                    create_kernel(
                        k,
                        self.length,
                        int(self.max_period_ratio * self.length),
                        gaussians_periodic,
                        kernel_periods,
                        rng=self.rng,
                    )
                    for k in kernel_names
                ],
            )

            # Set up GP model
            train_x = torch.linspace(0, 1, self.length)
            trend = self.rng.choice([True, False])
            mean_module = gpytorch.means.LinearMean(input_size=1) if trend else gpytorch.means.ConstantMean()
            likelihood = gpytorch.likelihoods.GaussianLikelihood(
                noise_covar=torch.diag(torch.full_like(train_x, self.likelihood_noise_level**2))
            )
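            # train_y=None leaves the model without training targets, so
            # calling it in eval() returns the GP prior; the likelihood noise
            # is therefore not applied to the sample drawn below.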
            model = GPModel(train_x, None, likelihood, mean_module, composite_kernel)

            # Determine noise level
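            # Note: dict.get evaluates its fallback eagerly, so one rng draw
            # is consumed even when noise_level matches a key.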
            noise = {"high": 1e-1, "moderate": 1e-2, "low": 1e-3}.get(
                self.noise_level,
                self.rng.choice([1e-1, 1e-2, 1e-3], p=[0.1, 0.2, 0.7]),
            )

            # Sample from GP prior with robust error handling
            model.eval()
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    with (
                        torch.no_grad(),
                        gpytorch.settings.fast_pred_var(),
                        gpytorch.settings.cholesky_jitter(
                            max(noise * (10**attempt), 1e-4)
                        ),  # Increase jitter on retries, with a minimum floor
                        gpytorch.settings.max_cholesky_size(self.length),  # Limit decomposition size
                    ):
                        y_sample = model(train_x).sample().numpy()
                        # y_sample has shape (self.length,), i.e. a 1D array
                    break
                except (RuntimeError, IndexError) as e:
                    if attempt == max_retries - 1:
                        # If all attempts fail, generate a simple fallback
                        print(f"GP sampling failed after {max_retries} attempts: {e}")
                        print("Generating fallback sample with simpler kernel")
                        # Create a simple RBF kernel as fallback
                        simple_kernel = gpytorch.kernels.RBFKernel()
                        simple_model = GPModel(train_x, None, likelihood, mean_module, simple_kernel)
                        simple_model.eval()
                        with torch.no_grad():
                            y_sample = simple_model(train_x).sample().numpy()
                        break
                    else:
                        print(f"GP sampling attempt {attempt + 1} failed: {e}. Retrying with higher jitter...")

            # Optionally add peak spikes
            if self.rng.random() < self.peak_spike_ratio:
                periodicities = extract_periodicities(composite_kernel, self.length)
                if len(periodicities) > 0:
                    p = int(np.round(max(periodicities)))
                    spikes_type = self.rng.choice(["regular", "patchy"], p=[0.3, 0.7])
                    spikes = generate_peak_spikes(self.length, p, spikes_type=spikes_type)
                    # Shift the spike train so a spike coincides with the
                    # series' peak within the first period
                    spikes_shift = p - y_sample[:p].argmax() if 0 < p <= len(y_sample) else 0
                    spikes = np.roll(spikes, -spikes_shift)
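                    # All-negative spike trains are applied additively (with a
                    # +1 offset); otherwise the spikes scale the series
                    # multiplicatively.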
                    if spikes.max() < 0:
                        y_sample = y_sample + spikes + 1
                    else:
                        y_sample = y_sample * spikes

            return y_sample
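

# Minimal usage sketch, assuming GPGeneratorParams can be constructed with
# default values (not verified here; adapt to the actual dataclass signature
# in generator_params.py):
if __name__ == "__main__":
    params = GPGeneratorParams()  # assumption: all fields have defaults
    generator = GPGenerator(params, length=512, random_seed=42)
    sample = generator.generate_time_series()
    print(sample.shape, float(sample.mean()), float(sample.std()))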