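"""Sample synthetic time series from Gaussian process priors.

A composite kernel is assembled from a weighted kernel bank via random binary
operations, a GP prior is sampled with gpytorch, and optional peak spikes are
overlaid on the resulting series.
"""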
import functools

import gpytorch
import numpy as np
import torch

from src.data.frequency import FREQUENCY_MAPPING
from src.synthetic_generation.generator_params import GPGeneratorParams
from src.synthetic_generation.gp_prior.constants import (
    KERNEL_BANK,
    KERNEL_PERIODS_BY_FREQ,
)
from src.synthetic_generation.gp_prior.utils import (
    create_kernel,
    extract_periodicities,
    random_binary_map,
)
from src.synthetic_generation.utils import generate_peak_spikes


class GPModel(gpytorch.models.ExactGP):
    """Exact GP with configurable mean and covariance modules.

    Constructed with train_y=None, the model stays in prior mode and can be
    used to draw samples from the GP prior.
    """

    def __init__(self, train_x, train_y, likelihood, mean_module, kernel):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = mean_module
        self.covar_module = kernel

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class GPGenerator:
    """Generates synthetic time series by sampling from a GP prior."""

    def __init__(
        self,
        params: GPGeneratorParams,
        length: int = 1024,
        random_seed: int | None = None,
    ):
        self.params = params
        self.length = length
        self.rng = np.random.default_rng(random_seed)
        self.frequency = params.frequency
        self.max_kernels = params.max_kernels
        self.likelihood_noise_level = params.likelihood_noise_level
        self.noise_level = params.noise_level
        self.use_original_gp = params.use_original_gp
        self.gaussians_periodic = params.gaussians_periodic
        self.peak_spike_ratio = params.peak_spike_ratio
        self.subfreq_ratio = params.subfreq_ratio
        self.periods_per_freq = params.periods_per_freq
        self.gaussian_sampling_ratio = params.gaussian_sampling_ratio
        self.kernel_periods = params.kernel_periods
        self.max_period_ratio = params.max_period_ratio
        self.kernel_bank = params.kernel_bank

    def generate_time_series(
        self,
        random_seed: int | None = None,
    ) -> np.ndarray:
        """Sample one series of length ``self.length`` from the GP prior."""
        with torch.inference_mode():
            if random_seed is not None:
                self.rng = np.random.default_rng(random_seed)
                torch.manual_seed(random_seed)
            # Determine kernel_bank and gaussians_periodic
            if self.use_original_gp:
                kernel_bank = KERNEL_BANK
                gaussians_periodic = False
            else:
                # Convert kernel_bank from {str: float} format to {int: (str, float)} format
                kernel_bank = {
                    i: (kernel_name, weight)
                    for i, (kernel_name, weight) in enumerate(self.kernel_bank.items())
                }
                gaussians_periodic = self.gaussians_periodic

            # Map frequency to freq and subfreq
            freq, subfreq, timescale = FREQUENCY_MAPPING.get(self.frequency, ("D", "", 0))

            # Decide if using exact frequencies
            exact_freqs = self.rng.random() < self.periods_per_freq
            if exact_freqs and freq in KERNEL_PERIODS_BY_FREQ:
                kernel_periods = KERNEL_PERIODS_BY_FREQ[freq]
                if subfreq:
                    # Rescale periods to the sub-frequency grid, dropping periods shorter than one step
                    subfreq_int = int(subfreq)
                    kernel_periods = [p // subfreq_int for p in kernel_periods if p >= subfreq_int]
            else:
                kernel_periods = self.kernel_periods

            # Sample number of kernels
            num_kernels = self.rng.integers(1, self.max_kernels + 1)
            # Always expect kernel_bank as dict {int: (str, float)}
            kernel_weights = np.array([v[1] for v in kernel_bank.values()])
            kernel_ids = self.rng.choice(
                list(kernel_bank.keys()),
                size=num_kernels,
                p=kernel_weights / kernel_weights.sum(),
            )
            kernel_names = [kernel_bank[i][0] for i in kernel_ids]
            # Create composite kernel
            composite_kernel = functools.reduce(
                lambda a, b: random_binary_map(a, b, rng=self.rng),
                [
                    create_kernel(
                        k,
                        self.length,
                        int(self.max_period_ratio * self.length),
                        gaussians_periodic,
                        kernel_periods,
                        rng=self.rng,
                    )
                    for k in kernel_names
                ],
            )
            # Set up GP model (train_y=None keeps ExactGP in prior mode)
            train_x = torch.linspace(0, 1, self.length)
            trend = self.rng.choice([True, False])
            mean_module = gpytorch.means.LinearMean(input_size=1) if trend else gpytorch.means.ConstantMean()
            # GaussianLikelihood does not accept a precomputed noise-covariance tensor;
            # FixedNoiseGaussianLikelihood takes the per-point noise variances directly.
            likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
                noise=torch.full_like(train_x, self.likelihood_noise_level**2)
            )
            model = GPModel(train_x, None, likelihood, mean_module, composite_kernel)

            # Determine noise level; sample one if no explicit level is configured.
            # (dict.get would evaluate the rng default eagerly, advancing the rng
            # even when self.noise_level is a known key.)
            noise_levels = {"high": 1e-1, "moderate": 1e-2, "low": 1e-3}
            if self.noise_level in noise_levels:
                noise = noise_levels[self.noise_level]
            else:
                noise = self.rng.choice([1e-1, 1e-2, 1e-3], p=[0.1, 0.2, 0.7])
            # Sample from GP prior with robust error handling
            model.eval()
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    with (
                        torch.no_grad(),
                        gpytorch.settings.fast_pred_var(),
                        gpytorch.settings.cholesky_jitter(
                            max(noise * (10**attempt), 1e-4)
                        ),  # Increase jitter on retries, with a minimum floor
                        gpytorch.settings.max_cholesky_size(self.length),  # Limit decomposition size
                    ):
                        # y_sample shape: (self.length,), i.e. 1D
                        y_sample = model(train_x).sample().numpy()
                    break
                except (RuntimeError, IndexError) as e:
                    if attempt == max_retries - 1:
                        # If all attempts fail, generate a simple fallback
                        print(f"GP sampling failed after {max_retries} attempts: {e}")
                        print("Generating fallback sample with simpler kernel")
                        # Create a simple RBF kernel as fallback
                        simple_kernel = gpytorch.kernels.RBFKernel()
                        simple_model = GPModel(train_x, None, likelihood, mean_module, simple_kernel)
                        simple_model.eval()
                        with torch.no_grad():
                            y_sample = simple_model(train_x).sample().numpy()
                        break
                    else:
                        print(f"GP sampling attempt {attempt + 1} failed: {e}. Retrying with higher jitter...")
            # Optionally add peak spikes aligned with the dominant periodicity
            if self.rng.random() < self.peak_spike_ratio:
                periodicities = extract_periodicities(composite_kernel, self.length)
                if len(periodicities) > 0:
                    p = int(np.round(max(periodicities)))
                    spikes_type = self.rng.choice(["regular", "patchy"], p=[0.3, 0.7])
                    spikes = generate_peak_spikes(self.length, p, spikes_type=spikes_type)
                    # Shift the spike train so a spike coincides with the peak of the
                    # first period (y_sample is 1D, so index with y_sample[:p].argmax())
                    spikes_shift = p - y_sample[:p].argmax() if 0 < p <= len(y_sample) else 0
                    spikes = np.roll(spikes, -spikes_shift)
                    if spikes.max() < 0:
                        # Purely negative spikes are added (with a +1 offset);
                        # otherwise spikes scale the series multiplicatively
                        y_sample = y_sample + spikes + 1
                    else:
                        y_sample = y_sample * spikes

            return y_sample
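

if __name__ == "__main__":
    # Minimal usage sketch. Assumes GPGeneratorParams can be constructed with
    # default values; the actual required fields are defined in
    # src/synthetic_generation/generator_params.py and may differ.
    params = GPGeneratorParams()  # hypothetical default construction
    generator = GPGenerator(params, length=256, random_seed=42)
    series = generator.generate_time_series()
    print(series.shape, float(series.min()), float(series.max()))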