vlad-moroshan commited on
Commit
1c8d125
·
1 Parent(s): c02fe3a

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +167 -0
  2. README.md +128 -2
  3. configs/example.yaml +118 -0
  4. data/nan_stats.json +0 -0
  5. examples/generate_synthetic_data.py +324 -0
  6. examples/quick_start_tempo_pfn.ipynb +286 -0
  7. examples/quick_start_tempo_pfn.py +95 -0
  8. examples/utils.py +115 -0
  9. gift_eval/submission/all_results.csv +98 -0
  10. gift_eval/submission/config.json +6 -0
  11. pyproject.toml +57 -0
  12. src/__init__.py +0 -0
  13. src/data/__init__.py +0 -0
  14. src/data/augmentations.py +1318 -0
  15. src/data/batch_composer.py +705 -0
  16. src/data/constants.py +25 -0
  17. src/data/containers.py +272 -0
  18. src/data/datasets.py +267 -0
  19. src/data/filter.py +73 -0
  20. src/data/frequency.py +538 -0
  21. src/data/loaders.py +661 -0
  22. src/data/scalers.py +360 -0
  23. src/data/time_features.py +564 -0
  24. src/data/utils.py +75 -0
  25. src/gift_eval/__init__.py +0 -0
  26. src/gift_eval/aggregate_results.py +160 -0
  27. src/gift_eval/constants.py +83 -0
  28. src/gift_eval/data.py +234 -0
  29. src/gift_eval/dataset_properties.json +152 -0
  30. src/gift_eval/evaluate.py +529 -0
  31. src/gift_eval/model_wrapper.py +349 -0
  32. src/models/__init__.py +0 -0
  33. src/models/blocks.py +58 -0
  34. src/models/model.py +427 -0
  35. src/optim/lr_scheduler.py +360 -0
  36. src/plotting/__init__.py +0 -0
  37. src/plotting/gift_eval_utils.py +215 -0
  38. src/plotting/plot_timeseries.py +292 -0
  39. src/synthetic_generation/__init__.py +0 -0
  40. src/synthetic_generation/abstract_classes.py +97 -0
  41. src/synthetic_generation/anomalies/anomaly_generator.py +293 -0
  42. src/synthetic_generation/anomalies/anomaly_generator_wrapper.py +64 -0
  43. src/synthetic_generation/audio_generators/financial_volatility_generator.py +103 -0
  44. src/synthetic_generation/audio_generators/financial_volatility_wrapper.py +91 -0
  45. src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py +75 -0
  46. src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py +77 -0
  47. src/synthetic_generation/audio_generators/network_topology_generator.py +113 -0
  48. src/synthetic_generation/audio_generators/network_topology_wrapper.py +93 -0
  49. src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py +86 -0
  50. src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py +81 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ logs/
2
+ *.png
3
+ *.pth
4
+ # *.sh
5
+ *.slurm
6
+ *.pkl
7
+
8
+ wandb/
9
+ AutogluonModels/
10
+ .vscode/
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # PyBuilder
66
+ .pybuilder/
67
+ target/
68
+
69
+ # Jupyter Notebook
70
+ .ipynb_checkpoints
71
+
72
+ # IPython
73
+ profile_default/
74
+ ipython_config.py
75
+
76
+ # pyenv
77
+ # For a library or package, you might want to ignore these files since the code is
78
+ # intended to run in multiple environments; otherwise, check them in:
79
+ # .python-version
80
+
81
+ # pipenv
82
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
83
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
84
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
85
+ # install all needed dependencies.
86
+ #Pipfile.lock
87
+
88
+ # UV
89
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
90
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
91
+ # commonly ignored for libraries.
92
+ #uv.lock
93
+
94
+ # poetry
95
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
96
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
97
+ # commonly ignored for libraries.
98
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
99
+ #poetry.lock
100
+
101
+ # pdm
102
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
103
+ #pdm.lock
104
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
105
+ # in version control.
106
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
107
+ .pdm.toml
108
+ .pdm-python
109
+ .pdm-build/
110
+
111
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
112
+ __pypackages__/
113
+
114
+
115
+ # SageMath parsed files
116
+ *.sage.py
117
+
118
+ # Environments
119
+ .env
120
+ .venv
121
+ env/
122
+ venv/
123
+ ENV/
124
+ env.bak/
125
+ venv.bak/
126
+
127
+ # Spyder project settings
128
+ .spyderproject
129
+ .spyproject
130
+
131
+ # Rope project settings
132
+ .ropeproject
133
+
134
+ # mkdocs documentation
135
+ /site
136
+
137
+ # mypy
138
+ .mypy_cache/
139
+ .dmypy.json
140
+ dmypy.json
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+
145
+ # pytype static type analyzer
146
+ .pytype/
147
+
148
+ # Cython debug symbols
149
+ cython_debug/
150
+
151
+ .idea/
152
+
153
+ # Ruff stuff:
154
+ .ruff_cache/
155
+
156
+ # PyPI configuration file
157
+ .pypirc
158
+
159
+ # Datasets, logs, plots, etc.
160
+ outputs/
161
+ models/*
162
+
163
+ *.arrow
164
+ *.png
165
+ *.pt
166
+ *.pdf
167
+ *.gif
README.md CHANGED
@@ -1,2 +1,128 @@
1
- # TempoPFN
2
- Official code release for the paper "TempoPFN: Synthetic Pre-training of Linear RNNs for Zero-shot Time Series Forecasting"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TempoPFN: Synthetic Pre-Training of Linear RNNs for Zero-Shot Time Series Forecasting
2
+
3
+ [![arXiv](https://img.shields.io/badge/arXiv-2510.25502-b31b1b.svg)](https://arxiv.org/abs/2510.25502)
4
+ [![License](https://img.shields.io/badge/License-Apache_2.0-green.svg)](https://github.com/automl/TempoPFN/blob/main/LICENSE)
5
+
6
+ ---
7
+
8
+ **TempoPFN**, introduced in [TempoPFN: Synthetic Pre-Training of Linear RNNs for Zero-Shot Time Series Forecasting](https://arxiv.org/abs/2510.25502), is a univariate time series foundation model pretrained **entirely on synthetic data**. It delivers top-tier zero-shot forecasting accuracy while remaining fully reproducible and free from real-data leakage.
9
+
10
+ Built on a **Linear RNN (GatedDeltaProduct)** backbone, TempoPFN performs end-to-end forecasting without patching or windowing. Its design enables fully parallelizable training and inference while maintaining stable temporal state-tracking across long sequences.
11
+
12
+ This repository includes the [**pretrained 35M-parameter model**](https://www.dropbox.com/scl/fi/5vmjr7nx9wj9w1vl2giuv/checkpoint.pth?rlkey=qmk08ojp7wj0l6kpm8hzgbzju&st=dyr07d00&dl=0), all training and inference code, and the **complete synthetic data generation pipeline** used for pretraining.
13
+
14
+ ## ✨ Why TempoPFN?
15
+
16
+ * **High Performance, No Real Data:** Achieves top-tier results on **GIFT-Eval, outperforming all existing synthetic-only approaches** and **surpassing the vast majority of models trained on real-world data**. Pre-training purely on synthetic data ensures full reproducibility and eliminates benchmark leakage.
17
+ * **Parallel and Efficient:** The linear recurrence design enables full-sequence parallelization. This gives us the best of both worlds: the linear efficiency of an RNN, but with the training parallelism of a Transformer.
18
+ * **Open and Reproducible:** Includes the full synthetic data pipeline, configurations, and scripts to reproduce training from scratch.
19
+ * **State-Tracking Stability:** The GatedDeltaProduct recurrence and *state-weaving* mechanism preserve temporal continuity and information flow across long horizons, improving robustness without non-linear recurrence.
20
+
21
+
22
+ ![TempoPFN Overview](https://iili.io/KlUjfcP.png)
23
+
24
+ ## ⚙️ Installation
25
+
26
+ ```bash
27
+ git clone https://github.com/automl/TempoPFN.git
28
+ cd TempoPFN
29
+ python -m venv venv && source venv/bin/activate
30
+
31
+ # 1. Install PyTorch first (see PyTorch website for your specific CUDA version)
32
+ # Example for CUDA 12.6:
33
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
34
+
35
+ # 2. Install TempoPFN and all other dependencies
36
+ pip install .
37
+ export PYTHONPATH=$PWD
38
+ ```
39
+
40
+ ## 🚀 Quick Start: Run the Demo
41
+
42
+
43
+
44
+ **Prerequisites:**
45
+ * You must have a **CUDA-capable GPU** with a matching PyTorch version installed.
46
+ * You have run `export PYTHONPATH=$PWD` from the repo's root directory (see Installation).
47
+
48
+ ### 1. Run the Quick Start Script
49
+
50
+ Run a demo forecast on a synthetic sine wave:
51
+ ```bash
52
+ python examples/quick_start_tempo_pfn.py
53
+ ```
54
+
55
+ ### 2. Run with a Local Checkpoint
56
+
57
+ If you have already downloaded the model (e.g., to `models/checkpoint.pth`), you can point the script to it:
58
+ ```bash
59
+ python examples/quick_start_tempo_pfn.py --checkpoint models/checkpoint.pth
60
+ ```
61
+
62
+ ### 3. Run the Notebook version
63
+
64
+ ```bash
65
+ jupyter notebook examples/quick_start_tempo_pfn.ipynb
66
+ ```
67
+
68
+ ### Hardware & Performance Tips
69
+
70
+ **GPU Required:** Inference requires a CUDA-capable GPU. Tested on NVIDIA A100/H100.
71
+
72
+ **Triton Caches:** To prevent slowdowns from writing caches to a network filesystem, route caches to a local directory (like `/tmp`) before running:
73
+ ```bash
74
+ LOCAL_CACHE_BASE="${TMPDIR:-/tmp}/tsf-$(date +%s)"
75
+ mkdir -p "${LOCAL_CACHE_BASE}/triton" "${LOCAL_CACHE_BASE}/torchinductor"
76
+ export TRITON_CACHE_DIR="${LOCAL_CACHE_BASE}/triton"
77
+ export TORCHINDUCTOR_CACHE_DIR="${LOCAL_CACHE_BASE}/torchinductor"
78
+
79
+ python examples/quick_start_tempo_pfn.py
80
+ ```
81
+
82
+ ## 🚂 Training
83
+
84
+
85
+ ### Single-GPU Training (for debugging)
86
+ ```bash
87
+ torchrun --standalone --nproc_per_node=1 src/training/trainer_dist.py --config ./configs/train.yaml
88
+ ```
89
+
90
+ ### Multi-GPU Training (Single-Node)
91
+
92
+ This example uses 8 GPUs. The training script uses PyTorch DistributedDataParallel (DDP).
93
+ ```bash
94
+ torchrun --standalone --nproc_per_node=8 src/training/trainer_dist.py --config ./configs/train.yaml
95
+ ```
96
+
97
+ ### Configuration
98
+
99
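+ All training and model parameters are controlled via YAML files in `configs/` (architecture, optimizers, paths). `configs/example.yaml` is the provided template: copy it to the path you pass via `--config` (e.g. `configs/train.yaml`) and set `train_data_path` before launching training.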
100
+
101
+ ## 💾 Synthetic Data Generation
102
+
103
+ A core contribution of this work is our open-source synthetic data pipeline, located in `src/synthetic_generation/`. It combines diverse generators with a powerful augmentation cascade.
104
+
105
+ **Generators Used:**
106
+
107
+ * **Adapted Priors:** ForecastPFN, KernelSynth, GaussianProcess (GP), and CauKer (Structural Causal Models).
108
+ * **Novel Priors:** SDE (a flexible regime-switching Ornstein-Uhlenbeck process), Sawtooth, StepFunction, Anomaly, Spikes, SineWave, and Audio-Inspired generators (Stochastic Rhythms, Financial Volatility, Network Topology, Multi-Scale Fractals).
109
+
110
+ You can easily generate your own data by instantiating a generator wrapper. See `examples/generate_synthetic_data.py` for a minimal script, or inspect the generator code in `src/synthetic_generation/`.
111
+
112
+ ## 🤝 License
113
+
114
+ This project is licensed under the Apache 2.0 License. See the LICENSE file for details. This permissive license allows for both academic and commercial use.
115
+
116
+ ## 📚 Citation
117
+
118
+ If you find TempoPFN useful in your research, please consider citing our paper:
119
+ ```bibtex
120
+ @misc{moroshan2025tempopfn,
121
+ title={TempoPFN: Synthetic Pre-Training of Linear RNNs for Zero-Shot Time Series Forecasting},
122
+ author={Vladyslav Moroshan and Julien Siems and Arber Zela and Timur Carstensen and Frank Hutter},
123
+ year={2025},
124
+ eprint={2510.25502},
125
+ archivePrefix={arXiv},
126
+ primaryClass={cs.LG}
127
+ }
128
+ ```
configs/example.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_data_path: null # Replace with the path to root of the training data directory with subdirectories for each generator (e.g. gp, kernel, etc.)
2
+ model_path: ./models # Path where the model will be saved
3
+ model_name: TempoPFN
4
+ continue_training: false
5
+ checkpoint_path: null # Replace with the path to the checkpoint file
6
+ seed: 2025
7
+ wandb: true # whether to log to wandb
8
+ wandb_project_name: TempoPFNTraining
9
+ wandb_entity: university-of-freiburg-2024
10
+ wandb_plots: false
11
+
12
+ batch_size: 40
13
+ num_training_iterations: 1000000 # 1M
14
+ validation_batch_size: 64
15
+ num_validation_batches: 1
16
+ num_workers: 4
17
+ gradient_accumulation_enabled: true
18
+ accumulation_steps: 5 # Number of batches to accumulate before updating (effective batch size = batch_size * accumulation_steps)
19
+ log_interval: 2048
20
+ save_every: 100000
21
+
22
+ generator_proportions:
23
+ forecast_pfn: 1.0
24
+ gp: 1.0
25
+ kernel: 1.0
26
+ sawtooth: 1.0
27
+ sinewave: 1.0
28
+ step: 1.0
29
+ anomaly: 1.0
30
+ spike: 1.0
31
+ cauker_univariate: 1.0
32
+ ou_process: 3.0
33
+ audio_financial_volatility: 0.1
34
+ audio_multi_scale_fractal: 0.1
35
+ audio_network_topology: 0.5
36
+ audio_stochastic_rhythm: 0.5
37
+ augmented_per_sample_2048: 2.0
38
+ augmented_temp_batch_2048: 2.0
39
+
40
+ # Learning Rate Scheduler Configuration
41
+ lr_scheduler: cosine # Options: "warmup_stable_decay", "cosine_with_warmup", "cosine_with_restarts", "cosine"
42
+
43
+ # Learning Rate Parameters
44
+ peak_lr: 0.0002 # 2e-4 - Peak learning rate
45
+ min_lr_ratio: 0.01 # Minimum LR as fraction of peak LR
46
+
47
+ # WSD Scheduler Specific Parameters
48
+ warmup_ratio: 0.003 # 0.3% of total steps for warmup
49
+ stable_ratio: 0.90 # 90% of total steps at stable learning rate
50
+ decay_type: cosine # Type of decay: "cosine" or "linear"
51
+
52
+ # Alternative Scheduler Parameters (if using different schedulers)
53
+ num_cycles: 0.5 # For cosine_with_warmup: 0.5 = half cosine wave
54
+ num_restart_cycles: 4 # For cosine_with_restarts: number of restart cycles
55
+
56
+ # Optimizer Configuration
57
+ weight_decay: 0.01 # Weight decay for AdamW
58
+ beta1: 0.9 # Adam beta1 parameter
59
+ beta2: 0.98 # Adam beta2 parameter (optimized for transformers)
60
+ optimizer_eps: 1e-6 # Adam epsilon
61
+
62
+ # Training Stability
63
+ gradient_clip_val: 100.0
64
+ scaler: custom_robust
65
+
66
+ gift_eval:
67
+ evaluate_on_gift_eval: false
68
+ max_context_length: 3072
69
+ create_plots: false
70
+ max_plots: 5
71
+ dataset_storage_path: null # Replace with the path to the dataset storage path
72
+
73
+ data_augmentation:
74
+ nan_augmentation: true
75
+ scaler_augmentation: false
76
+ length_shortening: true
77
+ nan_stats_path: ./data/nan_stats.json
78
+
79
+ augmentation_probabilities:
80
+ scaler_augmentation: 0.5
81
+
82
+ TimeSeriesModel:
83
+ # Core architecture
84
+ embed_size: 512
85
+ num_encoder_layers: 10
86
+
87
+ # Scaling and preprocessing
88
+ scaler: custom_robust
89
+ epsilon: 0.00001
90
+ scaler_clamp_value: null
91
+ handle_constants: false
92
+
93
+ # Time features
94
+ K_max: 25
95
+ time_feature_config:
96
+ use_enhanced_features: true
97
+ use_holiday_features: false
98
+ use_index_features: true
99
+ include_seasonality_info: true
100
+
101
+ drop_enc_allow: false
102
+ encoding_dropout: 0.0
103
+
104
+ # Encoder configuration
105
+ encoder_config:
106
+ attn_mode: chunk
107
+ num_heads: 4
108
+ expand_v: 1.0
109
+ use_gate: false
110
+ use_short_conv: true
111
+ conv_size: 16
112
+ allow_neg_eigval: true
113
+ use_forget_gate: true
114
+ num_householder: 4
115
+ weaving: true
116
+
117
+ loss_type: 'quantile'
118
+ quantiles: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
data/nan_stats.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/generate_synthetic_data.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from src.data.containers import BatchTimeSeriesContainer
8
+ from src.data.utils import sample_future_length
9
+ from src.plotting.plot_multivariate_timeseries import plot_from_container
10
+ from src.synthetic_generation.anomalies.anomaly_generator_wrapper import (
11
+ AnomalyGeneratorWrapper,
12
+ )
13
+ from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
14
+ FinancialVolatilityAudioWrapper,
15
+ )
16
+ from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import (
17
+ MultiScaleFractalAudioWrapper,
18
+ )
19
+ from src.synthetic_generation.audio_generators.network_topology_wrapper import (
20
+ NetworkTopologyAudioWrapper,
21
+ )
22
+ from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import (
23
+ StochasticRhythmAudioWrapper,
24
+ )
25
+ from src.synthetic_generation.cauker.cauker_generator_wrapper import (
26
+ CauKerGeneratorWrapper,
27
+ )
28
+ from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import (
29
+ ForecastPFNGeneratorWrapper,
30
+ )
31
+ from src.synthetic_generation.generator_params import (
32
+ AnomalyGeneratorParams,
33
+ CauKerGeneratorParams,
34
+ FinancialVolatilityAudioParams,
35
+ ForecastPFNGeneratorParams,
36
+ GPGeneratorParams,
37
+ KernelGeneratorParams,
38
+ MultiScaleFractalAudioParams,
39
+ NetworkTopologyAudioParams,
40
+ OrnsteinUhlenbeckProcessGeneratorParams,
41
+ SawToothGeneratorParams,
42
+ SineWaveGeneratorParams,
43
+ SpikesGeneratorParams,
44
+ StepGeneratorParams,
45
+ StochasticRhythmAudioParams,
46
+ )
47
+ from src.synthetic_generation.gp_prior.gp_generator_wrapper import (
48
+ GPGeneratorWrapper,
49
+ )
50
+ from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
51
+ KernelGeneratorWrapper,
52
+ )
53
+ from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
54
+ OrnsteinUhlenbeckProcessGeneratorWrapper,
55
+ )
56
+ from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import (
57
+ SawToothGeneratorWrapper,
58
+ )
59
+ from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
60
+ SineWaveGeneratorWrapper,
61
+ )
62
+ from src.synthetic_generation.spikes.spikes_generator_wrapper import (
63
+ SpikesGeneratorWrapper,
64
+ )
65
+ from src.synthetic_generation.steps.step_generator_wrapper import (
66
+ StepGeneratorWrapper,
67
+ )
68
+
69
+ # Configure logging
70
+ logging.basicConfig(
71
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
72
+ )
73
+ logger = logging.getLogger(__name__)
74
+
75
+
76
def visualize_batch_sample(
    generator,
    batch_size: int = 8,
    output_dir: str = "outputs/plots",
    sample_idx: Optional[int] = None,
    prefix: str = "",
    seed: Optional[int] = None,
) -> None:
    """
    Generate one synthetic batch and save a plot for one or all of its samples.

    The generated series is split along the time axis into a history part and
    a future part (the last ``future_length`` steps), wrapped in a
    ``BatchTimeSeriesContainer`` and handed to ``plot_from_container``.

    Args:
        generator: Any generator wrapper exposing
            ``generate_batch(batch_size=..., seed=...)`` (LMC, Kernel, GP, etc.)
        batch_size: Number of samples to generate in the batch
        output_dir: Directory to save plots; created if it does not exist
        sample_idx: Index of the sample to visualize. When ``None``, every
            sample of the batch is plotted.
        prefix: Label prepended to the output file name and, capitalized, to
            the plot title
        seed: Seed for the generator

    Raises:
        IndexError: If ``sample_idx`` is outside ``[0, batch_size)``.
        ValueError: If the sampled future length leaves no history (or no
            future) to plot.
    """
    if sample_idx is not None and not 0 <= sample_idx < batch_size:
        raise IndexError(
            f"sample_idx {sample_idx} is out of range for batch_size {batch_size}"
        )

    os.makedirs(output_dir, exist_ok=True)

    generator_name = generator.__class__.__name__
    logger.info("[%s] Generating batch of size %s", generator_name, batch_size)

    generated = generator.generate_batch(batch_size=batch_size, seed=seed)
    values = torch.from_numpy(generated.values)
    if values.ndim == 2:
        values = values.unsqueeze(-1)  # Add channel dimension: [batch_size, seq_len, 1]

    future_length = sample_future_length(range="gift_eval")
    seq_len = values.shape[1]
    # A future length of 0 or >= seq_len would make one of the two slices
    # below empty (note that ``[:-0]`` is an empty slice), so fail loudly.
    if not 0 < future_length < seq_len:
        raise ValueError(
            f"future_length {future_length} must be in (0, {seq_len}) "
            "to split the series into history and future"
        )

    # Slice along the time dimension (dimension 1)
    batch = BatchTimeSeriesContainer(
        history_values=values[:, :-future_length, :],
        future_values=values[:, -future_length:, :],
        start=generated.start,
        frequency=generated.frequency,
    )

    logger.info(
        "[%s] Batch history values shape: %s", generator_name, batch.history_values.shape
    )
    logger.info(
        "[%s] Batch future values shape: %s", generator_name, batch.future_values.shape
    )
    logger.info("[%s] Batch start: %s", generator_name, batch.start)
    logger.info("[%s] Batch frequency: %s", generator_name, batch.frequency)

    # Honour an explicit ``sample_idx``; previously a non-None index resulted
    # in no plot being produced at all.
    indices = range(batch_size) if sample_idx is None else (sample_idx,)
    file_stem = generator_name.lower().replace("generatorwrapper", "")
    title_stem = generator_name.replace("GeneratorWrapper", "")

    for idx in indices:
        filename = f"{prefix}_{file_stem}_sample_{idx}.png"
        output_file = os.path.join(output_dir, filename)
        title = f"{prefix.capitalize()} {title_stem} Synthetic Time Series (Sample {idx})"

        # NOTE(review): plot_from_container is imported from
        # src.plotting.plot_multivariate_timeseries; confirm that module
        # exists in the repository.
        plot_from_container(
            batch=batch,
            sample_idx=idx,
            output_file=output_file,
            show=False,
            title=title,
        )
        logger.info(
            "[%s] Saved plot for sample %s to %s", generator_name, idx, output_file
        )
    logger.info("--------------------------------")
+
144
+
145
if __name__ == "__main__":
    # Configuration shared by every generator below.
    batch_size = 2
    total_length = 2048
    output_dir = "outputs/plots"
    global_seed = 2025

    logger.info("Saving plots to %s", output_dir)

    # One entry per generator:
    #   (params class, wrapper class, extra params kwargs, extra visualize kwargs)
    # The step generator previously used its own hard-coded length/seed (and a
    # stray ``params`` alias); it now follows the shared configuration like
    # every other generator.
    generator_specs = [
        (KernelGeneratorParams, KernelGeneratorWrapper, {}, {}),
        (GPGeneratorParams, GPGeneratorWrapper, {}, {}),
        (ForecastPFNGeneratorParams, ForecastPFNGeneratorWrapper, {}, {}),
        (SineWaveGeneratorParams, SineWaveGeneratorWrapper, {}, {}),
        (SawToothGeneratorParams, SawToothGeneratorWrapper, {}, {}),
        (StepGeneratorParams, StepGeneratorWrapper, {}, {}),
        (AnomalyGeneratorParams, AnomalyGeneratorWrapper, {}, {}),
        (SpikesGeneratorParams, SpikesGeneratorWrapper, {}, {}),
        (
            CauKerGeneratorParams,
            CauKerGeneratorWrapper,
            {"num_channels": 5},
            {"prefix": "multivariate"},
        ),
        (
            OrnsteinUhlenbeckProcessGeneratorParams,
            OrnsteinUhlenbeckProcessGeneratorWrapper,
            {},
            {"seed": global_seed},
        ),
        (StochasticRhythmAudioParams, StochasticRhythmAudioWrapper, {}, {}),
        (FinancialVolatilityAudioParams, FinancialVolatilityAudioWrapper, {}, {}),
        (MultiScaleFractalAudioParams, MultiScaleFractalAudioWrapper, {}, {}),
        (NetworkTopologyAudioParams, NetworkTopologyAudioWrapper, {}, {}),
    ]

    # Instantiate every generator first, then visualize, so that construction
    # and plotting happen in the same order as before.
    generators = [
        (
            wrapper_cls(
                params_cls(
                    global_seed=global_seed,
                    length=total_length,
                    **params_kwargs,
                )
            ),
            visualize_kwargs,
        )
        for params_cls, wrapper_cls, params_kwargs, visualize_kwargs in generator_specs
    ]

    # Visualize samples from all generators
    for generator, visualize_kwargs in generators:
        visualize_batch_sample(
            generator,
            batch_size=batch_size,
            output_dir=output_dir,
            **visualize_kwargs,
        )
examples/quick_start_tempo_pfn.ipynb ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "231c6227",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Quick Start: Univariate Quantile Forecasting (CUDA, bfloat16)\n",
9
+ "\n",
10
+ "This notebook demonstrates how to:\n",
11
+ "- Generate synthetic sine wave time series data\n",
12
+ "- Pack data into `BatchTimeSeriesContainer`\n",
13
+ "- Load a pretrained model (from Dropbox)\n",
14
+ "- Run inference with bfloat16 on CUDA\n",
15
+ "- Visualize predictions\n"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "id": "bb6c5424-1c63-4cb0-a818-45d4199914e5",
21
+ "metadata": {},
22
+ "source": [
23
+ "## 1) Setup"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "612a78e8",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "import os\n",
34
+ "import urllib.request\n",
35
+ "import torch\n",
36
+ "import numpy as np\n",
37
+ "from pathlib import Path\n",
38
+ "\n",
39
+ "# Ensure CUDA is available\n",
40
+ "if not torch.cuda.is_available():\n",
41
+ " raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
42
+ "\n",
43
+ "device = torch.device(\"cuda:0\")\n",
44
+ "\n",
45
+ "# Resolve repository root to be robust to running from subdirectories (e.g., examples/)\n",
46
+ "repo_root = Path.cwd()\n",
47
+ "if not (repo_root / \"configs\").exists():\n",
48
+ " repo_root = repo_root.parent\n",
49
+ "\n",
50
+ "# Inline plotting\n",
51
+ "%matplotlib inline\n"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "markdown",
56
+ "id": "3facf37d-0a77-4222-8464-6e42182547f8",
57
+ "metadata": {},
58
+ "source": [
59
+ "## 2) Download checkpoint from Dropbox"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "16dcb883",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "DROPBOX_URL = \"https://www.dropbox.com/scl/fi/5vmjr7nx9wj9w1vl2giuv/checkpoint.pth?rlkey=qmk08ojp7wj0l6kpm8hzgbzju&st=dyr07d00&dl=1\"\n",
70
+ "CHECKPOINT_DIR = repo_root / \"models\"\n",
71
+ "CHECKPOINT_PATH = CHECKPOINT_DIR / \"checkpoint.pth\"\n",
72
+ "\n",
73
+ "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
74
+ "if not CHECKPOINT_PATH.exists():\n",
75
+ " print(f\"Downloading checkpoint to {CHECKPOINT_PATH} ...\")\n",
76
+ " urllib.request.urlretrieve(DROPBOX_URL, str(CHECKPOINT_PATH))\n",
77
+ " print(\"Done.\")\n",
78
+ "else:\n",
79
+ " print(f\"Using existing checkpoint at {CHECKPOINT_PATH}\")\n"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "id": "9be77e34-0c7a-4056-822f-ed2e3e090c40",
85
+ "metadata": {},
86
+ "source": [
87
+ "## 3) Generate synthetic sine wave data"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "1127526c",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "from src.synthetic_generation.generator_params import SineWaveGeneratorParams\n",
98
+ "from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (\n",
99
+ " SineWaveGeneratorWrapper,\n",
100
+ ")\n",
101
+ "\n",
102
+ "batch_size = 3\n",
103
+ "total_length = 1024\n",
104
+ "seed = 2025\n",
105
+ "\n",
106
+ "sine_params = SineWaveGeneratorParams(global_seed=seed, length=total_length)\n",
107
+ "wrapper = SineWaveGeneratorWrapper(sine_params)\n",
108
+ "\n",
109
+ "batch = wrapper.generate_batch(batch_size=batch_size, seed=seed)\n",
110
+ "values = torch.from_numpy(batch.values).to(torch.float32)\n",
111
+ "if values.ndim == 2:\n",
112
+ " values = values.unsqueeze(-1) # [B, S, 1]\n",
113
+ "\n",
114
+ "future_length = 256\n",
115
+ "history_values = values[:, :-future_length, :]\n",
116
+ "future_values = values[:, -future_length:, :]\n",
117
+ "\n",
118
+ "print(\"History:\", history_values.shape, \"Future:\", future_values.shape)\n"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "id": "a8844488-e51c-4805-baa9-491bfc67e8ca",
124
+ "metadata": {},
125
+ "source": [
126
+ "## 4) Build BatchTimeSeriesContainer"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "f3b4d361",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "from src.data.containers import BatchTimeSeriesContainer\n",
137
+ "\n",
138
+ "container = BatchTimeSeriesContainer(\n",
139
+ " history_values=history_values.to(device),\n",
140
+ " future_values=future_values.to(device),\n",
141
+ " start=batch.start,\n",
142
+ " frequency=batch.frequency,\n",
143
+ ")\n",
144
+ "\n",
145
+ "container.batch_size, container.history_length, container.future_length\n"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "id": "b5e7e790-a9aa-49c2-9d45-2dc823036883",
151
+ "metadata": {},
152
+ "source": [
153
+ "## 5) Load model and run inference"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "id": "1dd4e0e4",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "import yaml\n",
164
+ "from src.models.model import TimeSeriesModel\n",
165
+ "\n",
166
+ "with open(repo_root / \"configs/example.yaml\", \"r\") as f:\n",
167
+ " config = yaml.safe_load(f)\n",
168
+ "\n",
169
+ "model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
170
+ "ckpt = torch.load(CHECKPOINT_PATH, map_location=device)\n",
171
+ "model.load_state_dict(ckpt[\"model_state_dict\"])\n",
172
+ "model.eval()\n",
173
+ "\n",
174
+ "# bfloat16 autocast on CUDA\n",
175
+ "with (\n",
176
+ " torch.no_grad(),\n",
177
+ " torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16, enabled=True),\n",
178
+ "):\n",
179
+ " output = model(container)\n",
180
+ "\n",
181
+ "preds = output[\"result\"].to(torch.float32)\n",
182
+ "if hasattr(model, \"scaler\") and \"scale_statistics\" in output:\n",
183
+ " preds = model.scaler.inverse_scale(preds, output[\"scale_statistics\"])\n",
184
+ "\n",
185
+ "preds.shape\n"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "markdown",
190
+ "id": "ba16120f-27c8-4462-91cb-c9b3e0630a9d",
191
+ "metadata": {},
192
+ "source": [
193
+ "## 6) Plot predictions"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "id": "9bf02a0b",
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "import matplotlib.pyplot as plt\n",
204
+ "\n",
205
+ "plt.set_loglevel('error') \n",
206
+ "\n",
207
+ "# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
208
+ "preds_np = preds.cpu().numpy()\n",
209
+ "\n",
210
+ "batch_size = preds_np.shape[0]\n",
211
+ "prediction_length = preds_np.shape[1]\n",
212
+ "num_quantiles = preds_np.shape[-1]\n",
213
+ "\n",
214
+ "for i in range(batch_size):\n",
215
+ " fig, ax = plt.subplots(figsize=(12, 4))\n",
216
+ "\n",
217
+ " history = container.history_values[i, :, 0].detach().cpu().numpy()\n",
218
+ " future = container.future_values[i, :, 0].detach().cpu().numpy()\n",
219
+ "\n",
220
+ " # Time axes\n",
221
+ " hist_t = np.arange(len(history))\n",
222
+ " fut_t = np.arange(len(history), len(history) + len(future))\n",
223
+ "\n",
224
+ " # Plot history and ground truth future\n",
225
+ " ax.plot(hist_t, history, label=\"History\", color=\"black\")\n",
226
+ " ax.plot(fut_t, future, label=\"Ground Truth\", color=\"blue\")\n",
227
+ "\n",
228
+ " # Plot quantiles\n",
229
+ " median_idx = num_quantiles // 2\n",
230
+ " ax.plot(\n",
231
+ " fut_t,\n",
232
+ " preds_np[i, :, 0, median_idx],\n",
233
+ " label=\"Prediction (Median)\",\n",
234
+ " color=\"orange\",\n",
235
+ " linestyle=\"--\",\n",
236
+ " )\n",
237
+ " if num_quantiles >= 3:\n",
238
+ " ax.fill_between(\n",
239
+ " fut_t,\n",
240
+ " preds_np[i, :, 0, 0],\n",
241
+ " preds_np[i, :, 0, -1],\n",
242
+ " color=\"orange\",\n",
243
+ " alpha=0.2,\n",
244
+ " label=\"Prediction Interval\",\n",
245
+ " )\n",
246
+ "\n",
247
+ " ax.axvline(x=len(history), color=\"k\", linestyle=\":\", alpha=0.7)\n",
248
+ " ax.set_xlabel(\"Time Steps\")\n",
249
+ " ax.set_ylabel(\"Value\")\n",
250
+ " ax.set_title(f\"Sample {i + 1}\")\n",
251
+ " ax.legend()\n",
252
+ " ax.grid(True, alpha=0.3)\n",
253
+ " plt.show()\n"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "id": "d88bb77b-b6be-4b00-a881-a4b556cce86f",
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": []
263
+ }
264
+ ],
265
+ "metadata": {
266
+ "kernelspec": {
267
+ "display_name": "Python 3 (ipykernel)",
268
+ "language": "python",
269
+ "name": "python3"
270
+ },
271
+ "language_info": {
272
+ "codemirror_mode": {
273
+ "name": "ipython",
274
+ "version": 3
275
+ },
276
+ "file_extension": ".py",
277
+ "mimetype": "text/x-python",
278
+ "name": "python",
279
+ "nbconvert_exporter": "python",
280
+ "pygments_lexer": "ipython3",
281
+ "version": "3.12.9"
282
+ }
283
+ },
284
+ "nbformat": 4,
285
+ "nbformat_minor": 5
286
+ }
examples/quick_start_tempo_pfn.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+
4
+ import torch
5
+
6
+ from examples.utils import (
7
+ download_checkpoint_if_needed,
8
+ load_model,
9
+ run_inference_and_plot,
10
+ )
11
+ from src.data.containers import BatchTimeSeriesContainer
12
+ from src.synthetic_generation.generator_params import SineWaveGeneratorParams
13
+ from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
14
+ SineWaveGeneratorWrapper,
15
+ )
16
+
17
# Emit timestamped records at INFO level and above for the demo run.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
22
+
23
+
24
def main():
    """Run the quick-start demo: synthesize sine waves, forecast them, and plot.

    Command-line options select the model config, an optional local
    checkpoint (when omitted, one is downloaded from Dropbox), the batch
    size, the total series length, the forecast horizon, the random seed and
    the output directory handed to the plotting helper.

    Raises:
        RuntimeError: If no CUDA device is available.
        SystemExit: If the command-line arguments are invalid (raised by
            ``argparse``).
    """
    # CLI
    parser = argparse.ArgumentParser(description="Quick start demo for TimeSeriesModel")
    parser.add_argument(
        "--config",
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--checkpoint",
        default=None,
        help="Path to model checkpoint. If omitted, downloads from Dropbox.",
    )
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--total_length", type=int, default=2048)
    parser.add_argument(
        "--future_length",
        type=int,
        default=256,
        help="Trailing time steps held out as the forecast target (default: 256)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", default="outputs")
    args = parser.parse_args()

    # Reject sizes that would leave an empty batch, history or future window;
    # slicing would otherwise silently yield zero-length tensors.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if not 0 < args.future_length < args.total_length:
        parser.error(
            "--future_length must be positive and smaller than --total_length"
        )

    # This demo is CUDA-only. Fail fast, before downloading the checkpoint.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is required to run this demo. No CUDA device detected."
        )
    device = torch.device("cuda:0")

    # Configuration
    batch_size = args.batch_size
    total_length = args.total_length
    future_length = args.future_length
    output_dir = args.output_dir
    seed = args.seed

    config_path = args.config
    if args.checkpoint:
        model_path = args.checkpoint
    else:
        dropbox_url = "https://www.dropbox.com/scl/fi/5vmjr7nx9wj9w1vl2giuv/checkpoint.pth?rlkey=qmk08ojp7wj0l6kpm8hzgbzju&st=dyr07d00&dl=0"
        model_path = download_checkpoint_if_needed(dropbox_url, target_dir="models")

    logger.info("=== Time Series Model Demo (Univariate Quantile) ===")

    # 1) Generate synthetic sine wave data
    sine_params = SineWaveGeneratorParams(global_seed=seed, length=total_length)
    sine_generator = SineWaveGeneratorWrapper(sine_params)
    batch = sine_generator.generate_batch(batch_size=batch_size, seed=seed)
    values = torch.from_numpy(batch.values).to(torch.float32)
    if values.ndim == 2:
        values = values.unsqueeze(-1)  # Ensure [B, S, 1] for univariate
    # The last `future_length` steps are the target; everything before is history.
    history_values = values[:, :-future_length, :]
    future_values = values[:, -future_length:, :]

    # 2) Load the pretrained model onto the CUDA device
    model = load_model(config_path=config_path, model_path=model_path, device=device)

    # 3) Pack tensors into the model's input container
    container = BatchTimeSeriesContainer(
        history_values=history_values.to(device),
        future_values=future_values.to(device),
        start=batch.start,
        frequency=batch.frequency,
    )

    # 4) Run inference (bfloat16 on CUDA) and plot results
    run_inference_and_plot(
        model=model, container=container, output_dir=output_dir, use_bfloat16=True
    )

    logger.info("=== Demo completed successfully! ===")


if __name__ == "__main__":
    main()
examples/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import urllib.request
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+
10
+ from src.data.containers import BatchTimeSeriesContainer
11
+ from src.models.model import TimeSeriesModel
12
+ from src.plotting.plot_timeseries import plot_from_container
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def load_model(
    config_path: str, model_path: str, device: torch.device
) -> TimeSeriesModel:
    """Build a TimeSeriesModel from a YAML config and load checkpoint weights.

    Args:
        config_path: Path to a YAML file whose top-level ``TimeSeriesModel``
            mapping holds the keyword arguments for the model constructor.
        model_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``model_state_dict`` entry.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        KeyError: If the config has no ``TimeSeriesModel`` section or the
            checkpoint has no ``model_state_dict`` entry.
    """
    # Read the config as UTF-8 explicitly so the result does not depend on
    # the platform's default text encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    if not isinstance(config, dict) or "TimeSeriesModel" not in config:
        raise KeyError(
            f"Config {config_path} has no top-level 'TimeSeriesModel' section"
        )

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # weights_only=True if the checkpoint is known to hold plain tensors.
    checkpoint = torch.load(model_path, map_location=device)
    if "model_state_dict" not in checkpoint:
        raise KeyError(f"Checkpoint {model_path} has no 'model_state_dict' entry")

    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    logger.info(
        "Successfully loaded TimeSeriesModel from %s on %s", model_path, device
    )
    return model
30
+
31
+
32
def download_checkpoint_if_needed(url: str, target_dir: str = "models") -> str:
    """Download checkpoint from URL into target_dir if not present and return its path.

    Dropbox share links are turned into direct downloads by rewriting a
    ``dl=0`` query parameter to ``dl=1``. Only the real query parameter of a
    Dropbox host is touched, never a matching substring elsewhere in the URL.

    The download is written to a temporary ``.part`` file and moved into
    place only after it completes, so an interrupted transfer never leaves a
    truncated ``checkpoint.pth`` that later runs would reuse.

    Args:
        url: Location of the checkpoint file.
        target_dir: Directory that receives ``checkpoint.pth``; created if
            it does not exist.

    Returns:
        The path ``<target_dir>/checkpoint.pth``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the download
        fails; no partial file is left behind in that case.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import urlsplit, urlunsplit

    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, "checkpoint.pth")

    # Normalize Dropbox URL to force direct download
    parts = urlsplit(url)
    host = parts.hostname or ""
    if host == "dropbox.com" or host.endswith(".dropbox.com"):
        query = "&".join(
            "dl=1" if param == "dl=0" else param for param in parts.query.split("&")
        )
        url = urlunsplit(parts._replace(query=query))

    if os.path.exists(target_path):
        logger.info("Using existing checkpoint at %s", target_path)
        return target_path

    partial_path = target_path + ".part"
    logger.info("Downloading checkpoint from %s to %s...", url, target_path)
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Atomic move: target_path only ever holds a complete download.
        os.replace(partial_path, target_path)
    finally:
        # After a successful move the partial file is gone; after a failure
        # this removes whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    logger.info("Checkpoint downloaded successfully.")

    return target_path
52
+
53
+
54
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: List[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
):
    """Plot each sample of ``container`` together with the given predictions.

    ``output_dir`` is always created. ``plot_from_container`` is called once
    per sample index in ``range(container.batch_size)``. When ``save_plots``
    is true, each call is handed a PNG path inside ``output_dir``; otherwise
    ``output_file`` is ``None``. ``show_plots`` is forwarded as ``show``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for i in range(container.batch_size):
        # Only build a file path when the figure should be written to disk.
        output_file = None
        if save_plots:
            output_file = os.path.join(
                output_dir, f"sine_wave_prediction_sample_{i + 1}.png"
            )

        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {i + 1}",
            output_file=output_file,
            show=show_plots,
        )
79
+
80
+
81
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
) -> None:
    """Run the model on ``container`` without gradients and plot the result.

    Autocast to bfloat16 is enabled only when ``use_bfloat16`` is set and the
    history tensor lives on a CUDA device. The ``"result"`` entry of the
    model output is cast to float32. It is passed through
    ``model.scaler.inverse_scale`` when the model has a ``scaler`` and the
    output carries ``"scale_statistics"``. The predictions are then handed to
    ``plot_with_library`` with showing and saving both enabled.
    """
    if container.history_values.device.type == "cuda":
        device_type = "cuda"
    else:
        device_type = "cpu"
    autocast_enabled = use_bfloat16 and device_type == "cuda"

    with torch.no_grad():
        with torch.autocast(
            device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
        ):
            model_output = model(container)

    preds_full = model_output["result"].to(torch.float32)
    has_scaler = hasattr(model, "scaler")
    if has_scaler and "scale_statistics" in model_output:
        statistics = model_output["scale_statistics"]
        preds_full = model.scaler.inverse_scale(preds_full, statistics)

    preds_np = preds_full.detach().cpu().numpy()

    # Quantile levels are only forwarded for quantile-loss models.
    model_quantiles = None
    if getattr(model, "loss_type", None) == "quantile":
        model_quantiles = model.quantiles

    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=True,
        save_plots=True,
    )
gift_eval/submission/all_results.csv ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset,model,eval_metrics/MSE[mean],eval_metrics/MSE[0.5],eval_metrics/MAE[0.5],eval_metrics/MASE[0.5],eval_metrics/MAPE[0.5],eval_metrics/sMAPE[0.5],eval_metrics/MSIS,eval_metrics/RMSE[mean],eval_metrics/NRMSE[mean],eval_metrics/ND[0.5],eval_metrics/mean_weighted_sum_quantile_loss,domain,num_variates
2
+ bitbrains_fast_storage/5T/long,TempoPFN,3457994.418690934,3457994.418690934,327.16564021921977,0.9615742426354904,4.317236176373819,0.8336409912081381,19.490867974425026,1859.5683420328853,4.914283864546929,0.8646010960833215,0.6536844929905335,Web/CloudOps,2
3
+ bitbrains_fast_storage/5T/medium,TempoPFN,2800287.942594441,2800287.942594441,268.74003348343393,1.0930763358543272,3.2742632396336617,0.824830859725203,26.62443937663418,1673.406090162947,5.084269857453093,0.8165064414207498,0.6178037781792698,Web/CloudOps,2
4
+ bitbrains_fast_storage/5T/short,TempoPFN,1690352.428591824,1690352.428591824,162.1872711788372,0.8136291545842834,1.776193812092158,0.7668176303900445,20.32296781949572,1300.13554239234,4.082104033944751,0.5092279168950156,0.4037249136258121,Web/CloudOps,2
5
+ bitbrains_fast_storage/H/short,TempoPFN,2588949.8055826887,2588949.8055826887,300.81926572581165,1.1822543303101407,3.726231648106852,0.593958917999151,29.76405600888225,1609.0213813317362,4.586233656611166,0.8574326339199428,0.6458078536940699,Web/CloudOps,2
6
+ bitbrains_rnd/5T/long,TempoPFN,2990910.5081971884,2990910.5081971884,210.74577292406275,3.5186544815564527,2.529824679515041,0.7193304407946397,123.74784003341853,1729.4249067817857,6.625037355191929,0.8073195965870753,0.7073195661600815,Web/CloudOps,2
7
+ bitbrains_rnd/5T/medium,TempoPFN,2289257.9838845544,2289257.9838845544,152.20154125488185,4.54810611155951,0.5130282516077612,0.7426683109827628,167.72825817863563,1513.0294061532825,6.366085897951346,0.6403894606996178,0.6193602612637874,Web/CloudOps,2
8
+ bitbrains_rnd/5T/short,TempoPFN,2014077.5501842,2014077.5501842,135.98742661578987,1.8517397044034425,0.696232985412152,0.6817326560343352,63.2175943612252,1419.182000373525,5.806883078969316,0.5564212950558665,0.5257236486996486,Web/CloudOps,2
9
+ bitbrains_rnd/H/short,TempoPFN,2035943.7507352629,2035943.7507352629,166.20220717027726,6.030992589065415,1.762054722799612,0.5701899273183543,206.5459296074726,1426.865007888014,6.244659394585995,0.7273821760777641,0.6502124112431411,Web/CloudOps,2
10
+ bizitobs_application/10S/long,TempoPFN,5740816.497777778,5740816.497777778,1423.993611111111,3.481128281924532,0.0486669582790798,0.0496452288049222,25.03746966635237,2396.0001038768296,0.0924041469204081,0.0549177417154226,0.0475831149960796,Web/CloudOps,2
11
+ bizitobs_application/10S/medium,TempoPFN,2146376.1066666665,2146376.1066666665,832.4885416666667,2.154386004052589,0.0297427558898925,0.0299385500168018,15.690585782887144,1465.051571333469,0.0575282633018517,0.0326893748710731,0.0238959442034563,Web/CloudOps,2
12
+ bizitobs_application/10S/short,TempoPFN,376800.9244444444,376800.9244444444,333.5807638888889,1.045550170879607,0.0145058769649929,0.0146063554228645,8.24955441258388,613.8411231291403,0.0236733985259999,0.0128648767027747,0.0102197624925825,Web/CloudOps,2
13
+ bizitobs_l2c/5T/long,TempoPFN,332.849126984127,332.849126984127,12.822042410714284,1.22095704272019,0.9288389476299502,1.153716245249532,10.349816564139786,18.24415322738019,1.0243819286593263,0.7199385123737686,0.573918947104382,Web/CloudOps,7
14
+ bizitobs_l2c/5T/medium,TempoPFN,146.4688988095238,146.4688988095238,8.26147693452381,0.7818771866617815,0.5654391202343231,0.8614376093159631,7.573703014323006,12.10243359037858,0.636765980847373,0.4346751770356495,0.3544980423033946,Web/CloudOps,7
15
+ bizitobs_l2c/5T/short,TempoPFN,20.892621721540177,20.892621721540177,2.659001813616072,0.274018680480048,0.1498092349119431,0.2086212109555511,2.4052935303554963,4.570844749227453,0.1569813038038852,0.0913208814604556,0.0723605673125024,Web/CloudOps,7
16
+ bizitobs_l2c/H/long,TempoPFN,108.49439484126984,108.49439484126984,6.4422584170386905,0.6815875823888798,0.673632736325078,0.8544856513482197,5.284687159594512,10.416064268295862,0.6362432681202789,0.3935117376155234,0.3131677794751992,Web/CloudOps,7
17
+ bizitobs_l2c/H/medium,TempoPFN,85.11841517857142,85.11841517857142,5.809039015997024,0.6083751129472635,0.6200869245381406,0.8834569194800722,3.7736867069280655,9.225964186933062,0.55864614495367,0.3517461358421966,0.2719450016284953,Web/CloudOps,7
18
+ bizitobs_l2c/H/short,TempoPFN,58.60240730406746,58.60240730406746,4.6713135734437,0.4630005137848145,0.3948283755807854,0.6419900636889204,2.951036267389131,7.6552209180445905,0.4126360128422088,0.251795765049438,0.1985720701597805,Web/CloudOps,7
19
+ bizitobs_service/10S/long,TempoPFN,137172.65608465608,137172.65608465608,82.30403439153439,1.662345455331064,0.1011724369109623,0.0943146465267703,24.580062468505552,370.36827089352033,0.2743806554539279,0.0609734598710921,0.0557345053106378,Web/CloudOps,2
20
+ bizitobs_service/10S/medium,TempoPFN,7740.426666666666,7740.426666666666,30.32351686507937,1.4210776938224114,0.103039076063368,0.0716119877124027,22.67250742405877,87.9796946270369,0.0662155914274661,0.0228221933696611,0.0187131317044051,Web/CloudOps,2
21
+ bizitobs_service/10S/short,TempoPFN,2474.4792460317462,2474.4792460317462,19.176820798197756,1.1598653331648996,0.0666156292466259,0.0541983465988417,24.57783862628469,49.7441378056927,0.0368520491464996,0.0142068025239332,0.0115698518326863,Web/CloudOps,2
22
+ car_parts/M/short,TempoPFN,1.3818172451286372,1.3818172451286372,0.4672690957004653,0.8407515421652917,0.8293590893689794,1.893282740388816,25.816739632958576,1.175507228871281,2.8187457507850047,1.1204633588204531,1.016162704205562,Sales,1
23
+ covid_deaths/D/short,TempoPFN,290323.0616541353,290323.0616541353,95.88298186873432,36.82875379020991,0.0973305795059633,0.3802577731558227,838.2070203544016,538.8163524375772,0.2027322652277888,0.0360764368473681,0.0296493204600527,Healthcare,1
24
+ electricity/15T/long,TempoPFN,679367.7835053491,679367.7835053491,90.37124592952328,1.3619071064235215,0.2042962727803589,0.1999499070070441,10.487232652692931,824.2376984252473,1.3003359982860876,0.1425717174992594,0.1137251456887457,Energy,1
25
+ electricity/15T/medium,TempoPFN,410301.82361416105,410301.82361416105,74.23413740058417,1.1823607410007584,0.1802587787943053,0.1831106137567377,9.124865260310653,640.5480650303778,1.1037288021425813,0.1279128921376865,0.1022478853910184,Energy,1
26
+ electricity/15T/short,TempoPFN,197999.487064066,197999.487064066,69.58886888658678,1.2890658846498282,0.1882592086631687,0.2060175097324122,12.440409392978095,444.97133285647294,0.8771262285479021,0.1371733808638283,0.1124443744530128,Energy,1
27
+ electricity/D/short,TempoPFN,2460923936.675748,2460923936.675748,5265.180427927928,1.564314345775582,0.5004102473254829,0.1020489007479963,12.14133452277077,49607.70037681396,0.8152193034530106,0.0865244042438215,0.0664914998664846,Energy,1
28
+ electricity/H/long,TempoPFN,9660838.018114114,9660838.018114114,315.4014256756757,1.3657739992140687,0.3528330126922677,0.1644909658307117,11.137350730901831,3108.1888646145867,1.225868818973913,0.1243942340819344,0.1000007258954307,Energy,1
29
+ electricity/H/medium,TempoPFN,7117041.892843468,7117041.892843468,278.22569995777025,1.2130144034395085,0.2467986882167905,0.1500124404017645,9.945132926951258,2667.778456477125,1.0413218399707818,0.1086006587630108,0.0872520518606711,Energy,1
30
+ electricity/H/short,TempoPFN,2001828.2401337277,2001828.2401337277,199.38613109691724,1.0786172149769229,0.2218964961368299,0.1440280778205061,11.294129947761776,1414.8597952213242,0.6679148696019,0.0941244936083527,0.0771797324638773,Energy,1
31
+ electricity/W/short,TempoPFN,76183092615.61081,76183092615.61081,34727.13657094594,1.570365713856254,0.1694176463278896,0.0993488831601272,11.91788945218548,276012.8486422521,0.6297702574229203,0.0792358683497086,0.0606708906232695,Energy,1
32
+ ett1/15T/long,TempoPFN,11.819432043650794,11.819432043650794,1.9247003658234128,1.2507084318317785,0.7671729308733772,0.4743360534288272,8.737143903915339,3.437940087268944,0.6552747798713814,0.3668497927592424,0.288230653138245,Energy,7
33
+ ett1/15T/medium,TempoPFN,11.408509540860615,11.408509540860615,1.8437986488947795,1.1818547993600037,0.7281965758824728,0.4556773517772796,8.31158410608761,3.377648522398477,0.6437831670148321,0.3514298559046792,0.277877086121562,Energy,7
34
+ ett1/15T/short,TempoPFN,6.0998437790643605,6.0998437790643605,1.2401007697695778,0.7754559371204721,0.4745998734729215,0.2709062529464219,5.124313668126548,2.469786180839216,0.466892406241743,0.2344306713155297,0.1816328117134165,Energy,7
35
+ ett1/D/short,TempoPFN,40213.09841269841,40213.09841269841,133.03087797619048,1.6324476425289194,1.3032025049603175,0.4616568808156475,10.165232559778676,200.532038369679,0.5267040504936318,0.3494100136837688,0.2686181176947919,Energy,7
36
+ ett1/H/long,TempoPFN,164.89584986772488,164.89584986772488,7.56153480489418,1.3859274542018474,6981196483236.346,0.4785638067874847,10.00708086784301,12.841177900322265,0.6149361975860038,0.362105524659008,0.2816536526299331,Energy,7
37
+ ett1/H/medium,TempoPFN,158.99996279761905,158.99996279761905,7.330310639880952,1.3479604957613165,4318128927434.1104,0.442596366665559,9.590349301595682,12.609518737748044,0.6039660567793922,0.3511045032102148,0.2758386706555789,Energy,7
38
+ ett1/H/short,TempoPFN,104.70541410900296,104.70541410900296,5.155340358189174,0.8547778511929055,0.5003422782534644,0.2702135929272071,5.451980617457977,10.23256635009043,0.4775957088062226,0.2406208128310768,0.1863945925056479,Energy,7
39
+ ett1/W/short,TempoPFN,1183244.5,1183244.5,833.8333565848214,1.5044807548071886,0.5491453238895961,0.5423834221847527,8.346203452067822,1087.770426146988,0.4328498151953185,0.3318021943103104,0.2579454864080913,Energy,7
40
+ ett2/15T/long,TempoPFN,13.920968501984127,13.920968501984127,2.4372794208829367,1.0266229456075762,0.162421735656956,0.1912512110170139,7.240519836720884,3.731081411867627,0.1988360821855705,0.1298870321345517,0.1019580288018046,Energy,7
41
+ ett2/15T/medium,TempoPFN,12.987088216145834,12.987088216145834,2.3273520817832343,0.9702315616517476,0.1522595897167077,0.1778896720653481,6.742950364573618,3.6037602883857063,0.1920508963553819,0.1240287970544381,0.096317793490544,Energy,7
42
+ ett2/15T/short,TempoPFN,8.553860618954612,8.553860618954612,1.8131163824172247,0.7804718358045571,0.1129981226167575,0.1330301091075509,5.350648546766824,2.924698380851368,0.1388458481191596,0.0860750918808183,0.0685052637873467,Energy,7
43
+ ett2/D/short,TempoPFN,170239.08571428573,170239.08571428573,246.5058531746032,2.290326347664065,0.5867415170820932,0.181941641118808,19.77345328878349,412.6003947093189,0.2361420919863594,0.1410818036093808,0.1199711796314696,Energy,7
44
+ ett2/H/long,TempoPFN,218.90458002645505,218.90458002645505,8.95805431547619,0.9923323101716812,0.1756859628750919,0.1666816776540073,8.879245079301075,14.795424293559648,0.202611948300768,0.1226736592226554,0.1036227735159358,Energy,7
45
+ ett2/H/medium,TempoPFN,243.5093005952381,243.5093005952381,9.681060500372023,1.0446103612297537,0.1729210287876511,0.1670725712727777,7.986726475131007,15.60478454177558,0.2102845067618488,0.1304585158995543,0.1055031195534691,Energy,7
46
+ ett2/H/short,TempoPFN,108.53523995535714,108.53523995535714,6.435656592959449,0.7470950747417008,0.1237944578658914,0.1119298434719161,5.231412537431336,10.418024762658089,0.132129541137475,0.0816220321911712,0.0649374365945919,Energy,7
47
+ ett2/W/short,TempoPFN,3383120.0,3383120.0,1295.83203125,0.8951468014830303,0.1438319001879011,0.1607407323784273,9.470720058706704,1839.3259634985852,0.1542912497518785,0.1087004411060343,0.0918377144034967,Energy,7
48
+ hierarchical_sales/D/short,TempoPFN,29.123292892853453,29.123292892853453,2.330811869009017,0.7515128310580963,0.6410289944380209,1.0602027814719133,7.300803117167325,5.396600123490108,1.655617154145591,0.7150672692276139,0.5842758796416178,Sales,1
49
+ hierarchical_sales/W/short,TempoPFN,437.0786163847325,437.0786163847325,8.912233902236162,0.7289483130473005,0.5386128782220985,0.4609518352185516,6.64654723456629,20.906425241650773,0.9598242089373876,0.4091650177495076,0.3427086959502182,Sales,1
50
+ hospital/M/short,TempoPFN,3146.256471371143,3146.256471371143,18.09600954747936,0.768169214106615,0.1945593078361289,0.1746747839340947,5.195230182484194,56.09150088356652,0.2036251734486402,0.0656927168071784,0.0524846811483977,Healthcare,1
51
+ jena_weather/10T/long,TempoPFN,1802.1015893683864,1802.1015893683864,11.05077453290344,0.6480390269860509,0.8421329927032432,0.6608396658167077,6.431036388936636,42.45116711432544,0.2597659641659084,0.0676215825489287,0.0568675921221853,Nature,21
52
+ jena_weather/10T/medium,TempoPFN,1753.326057449495,1753.326057449495,10.45327409136003,0.6205262365569862,0.800809384246282,0.6752596791783848,5.983639047813428,41.8727364456814,0.2569722819618993,0.0641515679471918,0.0526090633894538,Nature,21
53
+ jena_weather/10T/short,TempoPFN,967.4834108746242,967.4834108746242,6.240461833893307,0.29785829799349,0.3495092146199383,0.5688907620872866,2.7136900415583747,31.104395362627194,0.1931972955145112,0.0387610926048839,0.0316042761177553,Nature,21
54
+ jena_weather/D/short,TempoPFN,407.6343998015873,407.6343998015873,10.096466548859128,1.2762104692374598,0.7647620393136633,0.4792207233136718,8.470860420630778,20.18995789499293,0.1215765454839294,0.0607972305335396,0.0502315688594873,Nature,21
55
+ jena_weather/H/long,TempoPFN,1450.035582010582,1450.035582010582,12.834169353505292,1.14925431188062,3.6454664207292495,0.6876010515835866,8.49444712038228,38.07933274114164,0.2292797721719729,0.0772759188138019,0.0615505784712624,Nature,21
56
+ jena_weather/H/medium,TempoPFN,1601.863988095238,1601.863988095238,11.43108181423611,0.830294830457016,2.907100446024271,0.6853773994144349,6.671981430906583,40.02329306910213,0.2449096285447175,0.069948816959045,0.0562073146926605,Nature,21
57
+ jena_weather/H/short,TempoPFN,1161.0002492396357,1161.0002492396357,8.45698316633054,0.5343309068765919,1.2899701263024146,0.6088007276136064,4.785975199754376,34.073453732189165,0.2088680861946653,0.0518407644501047,0.0420641895205845,Nature,21
58
+ kdd_cup_2018/D/short,TempoPFN,2957.558623620254,2957.558623620254,21.449108262527844,1.2022252741099626,0.5317497971257803,0.4686441782744058,9.261473936192472,54.38344071149098,1.2182468123931351,0.4804827982128498,0.3781022537751671,Nature,1
59
+ kdd_cup_2018/H/long,TempoPFN,4099.393834111588,4099.393834111588,24.37712764418711,1.0361222519400666,1.067596580941062,0.6143228628291789,7.771131914708889,64.02650883900814,1.502320679809611,0.5719859420463782,0.4457075391120335,Nature,1
60
+ kdd_cup_2018/H/medium,TempoPFN,5187.082299587964,5187.082299587964,25.7997948872936,1.076371675935021,0.996384455388961,0.5780413204727044,8.857767467515062,72.02140167747338,1.5075383869258192,0.5400364372465801,0.4275271421385803,Nature,1
61
+ kdd_cup_2018/H/short,TempoPFN,4423.398133663605,4423.398133663605,22.99868414981208,0.975075924949626,0.940964873703854,0.5088272138094274,7.590540688579701,66.50863202369753,1.3921461175037204,0.4814041106039762,0.3836356598458785,Nature,1
62
+ loop_seattle/5T/long,TempoPFN,157.56740910947713,157.56740910947713,7.634268714704449,1.191248365481325,0.294509001597258,0.1619458107627965,11.066948051479228,12.552585753918478,0.2219322959299876,0.1349754398667968,0.1102272021373501,Transport,1
63
+ loop_seattle/5T/medium,TempoPFN,150.1369470378612,150.1369470378612,7.426822445518092,1.1582245306263297,0.2787350445364528,0.1604399422057661,10.01895938551111,12.253038277825675,0.2180067277247357,0.1321384314672467,0.1067471651057607,Transport,1
64
+ loop_seattle/5T/short,TempoPFN,56.20837674450948,56.20837674450948,4.151334465349668,0.6528856304300188,0.1189866291357144,0.0868257500989403,5.418227383070815,7.497224602778649,0.1285608499805451,0.0711862209945109,0.0566841136858997,Transport,1
65
+ loop_seattle/D/short,TempoPFN,19.32995043617776,19.32995043617776,2.972921541920633,0.8969007376581118,0.0550185047066986,0.0549326847443657,7.26916754640511,4.396583950771071,0.0785728296906184,0.0531300802196517,0.043017058136015,Transport,1
66
+ loop_seattle/H/long,TempoPFN,89.40608606381149,89.40608606381149,5.467254720555985,1.12318926581471,0.184156507997936,0.123748771373072,7.909106132499233,9.455479155696526,0.1689820440045555,0.0977071455143036,0.0757696938381596,Transport,1
67
+ loop_seattle/H/medium,TempoPFN,82.05160402154283,82.05160402154283,5.249757043545214,1.067239394992798,0.1469916750156965,0.1198292414396702,7.71532453928058,9.058234045416514,0.1601032741074667,0.0927888688596669,0.0730646137379731,Transport,1
68
+ loop_seattle/H/short,TempoPFN,79.93021155830753,79.93021155830753,4.978952884492776,1.0019999078644888,0.1463370876655166,0.1132828703040474,7.663887880517498,8.94036976630763,0.1582536394917618,0.0881325309159338,0.0697681163822325,Transport,1
69
+ m4_daily/D/short,TempoPFN,1107302.1502247455,1107302.1502247455,200.0984529723884,4.4308070498735,0.0375582349397479,0.0325003448180839,73.85257371713408,1052.284253528839,0.162548835061418,0.0309096808387832,0.0255332842361311,Econ/Fin,1
70
+ m4_hourly/H/short,TempoPFN,1900790.397793584,1900790.397793584,269.87303079307173,0.8295739315523254,0.1211194871128469,0.1069586238356769,5.763527621155884,1378.691552811427,0.1882218572798091,0.036843631182058,0.028529335870705,Econ/Fin,1
71
+ m4_monthly/M/short,TempoPFN,1888038.791054398,1888038.791054398,550.6457703179253,0.9239334080470228,0.157533313108815,0.1295610419100214,7.634234953833278,1374.059238553563,0.2855766798961609,0.1144430942087997,0.0913401022550916,Econ/Fin,1
72
+ m4_quarterly/Q/short,TempoPFN,1796236.6248541668,1796236.6248541668,559.6329735514323,1.1747941926308771,0.1171090269982814,0.1022490358727294,8.829971053307407,1340.2375255357413,0.2243321614232385,0.0936727051500237,0.0748259755560803,Econ/Fin,1
73
+ m4_weekly/W/short,TempoPFN,583423.8894364688,583423.8894364688,281.9014021319905,2.5068974606709102,0.0621606022892198,0.0619608453649078,34.89581976015151,763.8218964107201,0.1391565014546586,0.0513580627371238,0.0419254044427488,Econ/Fin,1
74
+ m4_yearly/A/short,TempoPFN,3788840.647688692,3788840.647688692,918.094150872091,3.421401346947248,0.1715710473134636,0.1539546917894726,30.445192708697025,1946.4944509781403,0.3121294070682495,0.1472206523889813,0.1196218198038522,Econ/Fin,1
75
+ m_dense/D/short,TempoPFN,9482.795555555556,9482.795555555556,45.650052083333335,0.6899203610078747,0.1267966234392374,0.0999539874368826,7.043961677013339,97.3796465158688,0.1685785724555744,0.0790269926834956,0.0660205335795425,Transport,1
76
+ m_dense/H/long,TempoPFN,75216.58469135803,75216.58469135803,131.98908950617283,1.1099402431673535,0.4539421329801232,0.3155182786007836,8.274772567214445,274.2564214222851,0.4747784079532156,0.2284926254705489,0.1835932985711755,Transport,1
77
+ m_dense/H/medium,TempoPFN,66764.08444444444,66764.08444444444,124.88760416666666,1.0367512901513452,0.4164061887564455,0.3037215745855758,7.849365647315136,258.3874695964269,0.4495615533094931,0.2172886533776566,0.1733313995334337,Transport,1
78
+ m_dense/H/short,TempoPFN,66010.75111111111,66010.75111111111,120.28962565104166,1.0160056830633992,0.4109345557059414,0.2858404757066923,8.963124852594527,256.9255750428733,0.4557220010023848,0.213363846290401,0.1721874459654968,Transport,1
79
+ restaurant/D/short,TempoPFN,141.10532587289566,141.10532587289566,7.168732057134789,0.6880883048170244,0.6529455322995288,0.3943859221952507,4.677912027140814,11.878776278425976,0.5466888071812649,0.3299216590546327,0.2581917858443854,Sales,1
80
+ saugeen/D/short,TempoPFN,1375.7590625,1375.7590625,14.092150065104166,3.1309744145587413,0.3216047922770182,0.3610168146588098,30.500399370336144,37.09122621995665,1.20139365482093,0.4564480982807553,0.3755489541578553,Nature,1
81
+ saugeen/M/short,TempoPFN,450.8665829613095,450.8665829613095,13.476651146298362,0.7895196884776998,0.3773554620288667,0.3776235164111268,4.73649049060767,21.23361916775634,0.6374969476769328,0.4046094969866529,0.3102559834358562,Nature,1
82
+ saugeen/W/short,TempoPFN,1229.458984375,1229.458984375,16.6321533203125,1.349179934912743,0.4265083312988281,0.4078161463980466,9.738107739351932,35.0636419154514,1.058425187524238,0.5020553780303745,0.4004677568480909,Nature,1
83
+ solar/10T/long,TempoPFN,38.03510999594485,38.03510999594485,2.95217954414411,1.289682582185996,1.980939740009238,1.613347773805571,10.516784658746062,6.167261142188228,1.3329548105705944,0.6380663692211938,0.5039360963008123,Energy,1
84
+ solar/10T/medium,TempoPFN,31.73545226443265,31.73545226443265,2.5354660658571637,1.1083795980449829,2.376593457700527,1.55737387319841,8.328948817077988,5.6334227841013895,1.2500789187217176,0.5626300030254013,0.4351566003689724,Energy,1
85
+ solar/10T/short,TempoPFN,23.310370868023877,23.310370868023877,2.089681713540479,0.9104617574218964,3.088664790072517,1.572870007326998,5.75677115505075,4.828081489372758,1.4044883551066518,0.6078881723531575,0.4620934487157311,Energy,1
86
+ solar/D/short,TempoPFN,128838.1800486618,128838.1800486618,250.6613594890511,0.9782785201693562,1.044265071376977,0.4281686406100531,5.697261783834341,358.94035723036467,0.5185465709334212,0.3621202960611567,0.2800056850496608,Energy,1
87
+ solar/H/long,TempoPFN,875.8664436334144,875.8664436334144,13.072592888280615,0.9956236178493916,4.22992235524107,1.4246637276180487,7.401661699548643,29.59504086216835,1.0261141373748248,0.4532506793040977,0.3659397066604597,Energy,1
88
+ solar/H/medium,TempoPFN,969.7141119221412,969.7141119221412,12.987529225593066,0.9916768572605807,4.270079450594054,1.4254524509643212,7.490816829317306,31.140233010081044,1.1217438746826656,0.4678411157474254,0.3867255305832896,Energy,1
89
+ solar/H/short,TempoPFN,687.7860367694727,687.7860367694727,11.829851353484486,0.9153026829231484,3.298143647835289,1.4320977852338796,6.105712982423255,26.22567514420692,0.9666871244547572,0.4360522627061222,0.338930893375135,Energy,1
90
+ solar/W/short,TempoPFN,1837604.379562044,1837604.379562044,1080.8457458941605,1.1818483501196966,0.2449082423300638,0.2096709326072068,7.4416251763844,1355.5826716073218,0.2767083080725079,0.2206276340779138,0.160040923371332,Energy,1
91
+ sz_taxi/15T/long,TempoPFN,17.335278668091167,17.335278668091167,2.8752947771990742,0.5435739715185243,8150111586328.506,0.4184713652630097,4.380632638022632,4.163565619525069,0.3853721873400869,0.266132142205346,0.2125054942462848,Transport,1
92
+ sz_taxi/15T/medium,TempoPFN,17.73923068576389,17.73923068576389,2.88089839576656,0.5629010694446301,12249100752073.3,0.4169614441536142,4.14479162506023,4.211796610208509,0.3919779681855994,0.2681156770449617,0.2116726820657642,Transport,1
93
+ sz_taxi/15T/short,TempoPFN,17.57054233201694,17.57054233201694,2.8061145880283456,0.5602273611005313,959958212480.042,0.4052746826217428,4.012449587804747,4.191723074347462,0.3919440456333012,0.2623837225491499,0.2066414499332957,Transport,1
94
+ sz_taxi/H/short,TempoPFN,7.529553437844301,7.529553437844301,1.8911764356825087,0.5763698315699721,1.1282697131491115,0.3028206124743695,4.102259224272375,2.744003177447924,0.2556208258093829,0.1761747530810148,0.1390065711740581,Transport,1
95
+ temperature_rain/D/short,TempoPFN,186.81455040605852,186.81455040605852,5.9566782469540085,1.385070112921999,28.13298940652906,1.5239891474214418,25.51129792837534,13.668011940514925,1.60908637363345,0.7012585181375972,0.572396651054853,Nature,1
96
+ us_births/D/short,TempoPFN,188935.12,188935.12,266.77953125,0.3928744382386066,0.0256256341934204,0.0252893901831033,3.346242220924745,434.6666768916154,0.040746473841947,0.0250084162636546,0.0202561474459704,Healthcare,1
97
+ us_births/M/short,TempoPFN,53081706.66666666,53081706.66666666,5884.473958333333,0.6657763993191178,0.0181355277697245,0.018380189816312,3.8133237509773754,7285.719365077594,0.0226293946745239,0.0182771359947467,0.0136913144696876,Healthcare,1
98
+ us_births/W/short,TempoPFN,1710668.857142857,1710668.857142857,1002.7292131696428,0.9115505789348626,0.0135757678321429,0.0135867246503487,5.849062097782842,1307.9254019793548,0.0177549590839956,0.0136119507467408,0.0106687511144078,Healthcare,1
gift_eval/submission/config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "model": "TempoPFN",
3
+ "model_type": "pretrained",
4
+ "model_dtype": "float32",
5
+ "model_link": "https://github.com/automl/TempoPFN/tree/main"
6
+ }
pyproject.toml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "TempoPFN"
3
+ version = "0.1.0"
4
+ description = "Univariate Time Series Forecasting Using Linear RNNs"
5
+ authors = [
6
+ ]
7
+ readme = "README.md"
8
+ license = { file = "LICENSE" }
9
+ requires-python = ">=3.10,<3.13"
10
+
11
+ dependencies = [
12
+ "torch>=2.5.0",
13
+ "torchmetrics",
14
+ "triton==3.2.0",
15
+ "numpy",
16
+ "gpytorch",
17
+ "flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main",
18
+ "scikit-learn",
19
+ ]
20
+
21
+ classifiers = [
22
+ "Intended Audience :: Science/Research",
23
+ "Intended Audience :: Developers",
24
+ "License :: OSI Approved :: Apache Software License",
25
+ "Programming Language :: Python",
26
+ "Topic :: Software Development",
27
+ "Topic :: Scientific/Engineering",
28
+ "Operating System :: POSIX",
29
+ "Operating System :: Unix",
30
+ "Operating System :: MacOS",
31
+ "Programming Language :: Python :: 3.9",
32
+ "Programming Language :: Python :: 3.10",
33
+ "Programming Language :: Python :: 3.11",
34
+ ]
35
+
36
+ [project.optional-dependencies]
37
+ dev = [
38
+ "wandb",
39
+ "ujson",
40
+ "build",
41
+ "pre-commit",
42
+ "pytest>=8",
43
+ "ruff",
44
+ "mypy",
45
+ "commitizen",
46
+ "black",
47
+ "matplotlib",
48
+ "gluonts",
49
+ "pyo",
50
+ "statsmodels"
51
+ ]
52
+
53
+ [build-system]
54
+ requires = ["setuptools>=68.2.2", "wheel>=0.41.2"]
55
+ build-backend = "setuptools.build_meta"
56
+
57
+ package-dir = {"" = "src"}
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
File without changes
src/data/augmentations.py ADDED
@@ -0,0 +1,1318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from collections import Counter
4
+ from pathlib import Path
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from joblib import Parallel, delayed
11
+ from torch.quasirandom import SobolEngine
12
+ import torch.nn.functional as F
13
+
14
+
15
+ from src.gift_eval.data import Dataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def find_consecutive_nan_lengths(series: np.ndarray) -> list[int]:
    """Return the length of every run of consecutive NaNs in ``series``.

    Arrays with more than one dimension are flattened first, so the whole
    array is scanned as one long sequence. Runs are reported in the order in
    which they occur; an array without NaNs yields an empty list.
    """
    flat = series.flatten() if series.ndim > 1 else series

    # 1 where the value is NaN, 0 elsewhere, with a sentinel 0 on both ends so
    # that runs touching either boundary still produce a rising/falling edge.
    nan_flags = np.isnan(flat).astype(int)
    edges = np.diff(np.concatenate(([0], nan_flags, [0])))

    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return (run_ends - run_starts).tolist()
34
+
35
+
36
def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
    """
    Scan every dataset under ``gift_eval_path_str`` and collect NaN statistics.

    Each non-hidden sub-directory is treated as a dataset; if it contains
    sub-directories of its own, each one becomes a separate
    ``"<dataset>/<subdir>"`` entry. Datasets that fail to load are logged and
    skipped.

    Returns:
        dict with the keys
        ``"p_series_has_nan"`` (fraction of series containing at least one NaN),
        ``"nan_ratio_distribution"`` (one NaN ratio per series that has NaNs) and
        ``"nan_length_distribution"`` (Counter of consecutive-NaN run lengths).

    Raises:
        FileNotFoundError: if the given path does not exist.
        ValueError: if no series could be read at all.
    """
    logger.info(
        "--- Starting Dataset Analysis for Augmentation (Full Distribution) ---"
    )
    path = Path(gift_eval_path_str)
    if not path.exists():
        raise FileNotFoundError(
            f"Provided raw data path for augmentation analysis does not exist: {gift_eval_path_str}"
        )

    # Collect dataset identifiers from the directory layout.
    dataset_names = []
    for dataset_dir in path.iterdir():
        if dataset_dir.name.startswith(".") or not dataset_dir.is_dir():
            continue
        freq_names = [sub.name for sub in dataset_dir.iterdir() if sub.is_dir()]
        if not freq_names:
            dataset_names.append(dataset_dir.name)
            continue
        dataset_names.extend(f"{dataset_dir.name}/{freq}" for freq in freq_names)

    total_series_count = 0
    series_with_nans_count = 0
    nan_ratio_distribution = []
    all_consecutive_nan_lengths = Counter()

    for ds_name in sorted(dataset_names):
        try:
            ds = Dataset(name=ds_name, term="short", to_univariate=False)
            for series_data in ds.training_dataset:
                total_series_count += 1
                target = np.atleast_1d(series_data["target"])
                num_nans = np.isnan(target).sum()
                if not num_nans:
                    continue

                series_with_nans_count += 1
                nan_ratio_distribution.append(float(num_nans / target.size))
                all_consecutive_nan_lengths.update(
                    find_consecutive_nan_lengths(target)
                )
        except Exception as e:
            # Best effort: one broken dataset must not abort the whole analysis.
            logger.warning(
                f"Could not process {ds_name} for augmentation analysis: {e}"
            )

    if total_series_count == 0:
        raise ValueError(
            "No series were found during augmentation analysis. Check dataset path."
        )

    # total_series_count is guaranteed to be positive at this point.
    p_series_has_nan = series_with_nans_count / total_series_count

    logger.info("--- Augmentation Analysis Complete ---")
    # Print summary statistics
    logger.info(f"Total series analyzed: {total_series_count}")
    logger.info(f"Series with NaNs: {series_with_nans_count} ({p_series_has_nan:.4f})")
    logger.info(f"NaN ratio distribution: {Counter(nan_ratio_distribution)}")
    logger.info(f"Consecutive NaN lengths distribution: {all_consecutive_nan_lengths}")
    logger.info("--- End of Dataset Analysis for Augmentation ---")

    return {
        "p_series_has_nan": p_series_has_nan,
        "nan_ratio_distribution": nan_ratio_distribution,
        "nan_length_distribution": all_consecutive_nan_lengths,
    }
107
+
108
+
109
class NanAugmenter:
    """
    Applies realistic NaN augmentation by generating and caching NaN patterns on-demand
    during the first transform call for a given number of channels.

    Patterns are always generated for ``max_length`` time steps and cropped to the
    history length of the incoming batch. They are kept on the CPU and cached both
    in memory and on disk (``nan_patterns_<max_length>_<num_channels>.pt``).
    """

    def __init__(
        self,
        p_series_has_nan: float,
        nan_ratio_distribution: List[float],
        nan_length_distribution: Counter,
        num_patterns: int = 100000,
        n_jobs: int = -1,
        nan_patterns_path: Optional[str] = None,
    ):
        """
        Initializes the augmenter. No new NaN patterns are generated at this stage,
        but pattern files already present on disk are loaded into the cache.

        Args:
            p_series_has_nan (float): Probability that a series in a batch will be augmented.
            nan_ratio_distribution (List[float]): A list of NaN ratios observed in the dataset.
            nan_length_distribution (Counter): A Counter of consecutive NaN block lengths.
            num_patterns (int): The number of unique NaN patterns to generate per data shape.
            n_jobs (int): The number of CPU cores to use for parallel pattern generation (-1 for all cores).
            nan_patterns_path (Optional[str]): Path to a pattern file. Its parent directory is
                used to look up and store pattern files; when None, ``./data`` is used.
        """
        self.p_series_has_nan = p_series_has_nan
        self.nan_ratio_distribution = nan_ratio_distribution
        self.num_patterns = num_patterns
        self.n_jobs = n_jobs
        self.max_length = 2048
        self.nan_patterns_path = nan_patterns_path
        # Cache to store patterns: Dict[shape_tuple -> pattern_tensor]
        self.pattern_cache: Dict[Tuple[int, ...], torch.BoolTensor] = {}

        if not nan_length_distribution or sum(nan_length_distribution.values()) == 0:
            self._has_block_distribution = False
            logger.warning("NaN length distribution is empty. Augmentation disabled.")
        else:
            self._has_block_distribution = True
            total_blocks = sum(nan_length_distribution.values())
            self.dist_lengths = [int(length) for length in nan_length_distribution]
            self.dist_probs = [
                count / total_blocks for count in nan_length_distribution.values()
            ]

        if not self.nan_ratio_distribution:
            logger.warning("NaN ratio distribution is empty. Augmentation disabled.")

        # Try to load existing patterns from disk
        self._load_existing_patterns()

    def _load_existing_patterns(self):
        """Load existing NaN patterns from disk if they exist."""
        candidate_files: List[Path] = []
        if self.nan_patterns_path is not None:
            explicit_path = Path(self.nan_patterns_path).resolve()
            # If the explicit path exists, use it directly
            if explicit_path.is_file():
                candidate_files.append(explicit_path)
            # Also search the directory of the explicit path for matching files
            search_dir = explicit_path.parent
            search_dir.mkdir(exist_ok=True, parents=True)
        else:
            # Default to the ./data directory
            search_dir = Path("data")
            search_dir.mkdir(exist_ok=True)
        candidate_files.extend(
            search_dir.glob(f"nan_patterns_{self.max_length}_*.pt")
        )

        # De-duplicate candidate files while preserving order (dicts keep
        # insertion order; the first spelling of each path wins).
        unique_candidates: Dict[str, Path] = {}
        for candidate in candidate_files:
            unique_candidates.setdefault(str(candidate.resolve()), candidate)

        for pattern_file in unique_candidates.values():
            try:
                # The channel count is the last "_"-separated token of the stem,
                # e.g. "nan_patterns_2048_1" -> 1.
                parts = pattern_file.stem.split("_")
                if len(parts) < 4:
                    continue
                num_channels = int(parts[-1])

                patterns = torch.load(pattern_file, map_location="cpu")
                cache_key = (self.max_length, num_channels)
                self.pattern_cache[cache_key] = patterns

                logger.info(
                    f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}"
                )
            except (ValueError, RuntimeError, FileNotFoundError) as e:
                logger.warning(f"Failed to load patterns from {pattern_file}: {e}")

    def _get_pattern_file_path(self, num_channels: int) -> Path:
        """Resolve the target file path for storing/loading patterns for a given channel count."""
        # If user provided a file path, use its directory as the base directory
        if self.nan_patterns_path is not None:
            base_dir = Path(self.nan_patterns_path).resolve().parent
        else:
            base_dir = Path("data").resolve()
        base_dir.mkdir(exist_ok=True, parents=True)

        return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt"

    def _generate_nan_mask(self, series_shape: Tuple[int, ...]) -> np.ndarray:
        """
        Generates a single boolean NaN mask for a given series shape.

        A NaN ratio is sampled from ``nan_ratio_distribution``; non-overlapping blocks
        whose lengths follow the observed block-length distribution are then placed at
        random free positions of the flattened series until the target count is reached
        (or the attempt budget is exhausted).
        """
        series_size = int(np.prod(series_shape))
        sampled_ratio = np.random.choice(self.nan_ratio_distribution)
        n_nans_to_add = int(round(series_size * sampled_ratio))

        if n_nans_to_add == 0:
            return np.zeros(series_shape, dtype=bool)

        mask_flat = np.zeros(series_size, dtype=bool)
        nans_added = 0
        # Bound the loop: a sampled block may not fit anywhere any more.
        max_attempts = n_nans_to_add * 2
        attempts = 0
        while nans_added < n_nans_to_add and attempts < max_attempts:
            attempts += 1
            block_length = np.random.choice(self.dist_lengths, p=self.dist_probs)

            # Never exceed the target number of NaNs.
            if nans_added + block_length > n_nans_to_add:
                block_length = n_nans_to_add - nans_added
                if block_length <= 0:
                    break

            # Number of already-masked entries inside every window of
            # `block_length`; a window is free iff that count is zero.
            nan_counts_in_window = np.convolve(
                mask_flat, np.ones(block_length), mode="valid"
            )
            valid_starts = np.where(nan_counts_in_window == 0)[0]

            if valid_starts.size == 0:
                continue

            start_pos = np.random.choice(valid_starts)
            mask_flat[start_pos : start_pos + block_length] = True
            nans_added += block_length

        return mask_flat.reshape(series_shape)

    def _pregenerate_patterns(self, series_shape: Tuple[int, ...]) -> torch.BoolTensor:
        """
        Uses joblib to parallelize the generation of NaN masks for a given shape.

        Returns an empty tensor of shape ``(0, *series_shape)`` when either input
        distribution is empty, i.e. when augmentation is disabled.
        """
        if not self._has_block_distribution or not self.nan_ratio_distribution:
            return torch.empty(0, *series_shape, dtype=torch.bool)

        logger.info(
            f"Generating {self.num_patterns} NaN patterns for shape {series_shape}..."
        )

        with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel:
            masks_list = parallel(
                delayed(self._generate_nan_mask)(series_shape)
                for _ in range(self.num_patterns)
            )

        logger.info(f"Pattern generation complete for shape {series_shape}.")
        return torch.from_numpy(np.stack(masks_list)).bool()

    def _ensure_patterns(self, num_channels: int) -> torch.Tensor:
        """Return the cached patterns for ``num_channels``, loading or generating them if needed."""
        cache_key = (self.max_length, num_channels)
        cached = self.pattern_cache.get(cache_key)
        if cached is not None:
            return cached

        target_file = self._get_pattern_file_path(num_channels)
        patterns = None
        if target_file.exists():
            try:
                patterns = torch.load(target_file, map_location="cpu")
            except (RuntimeError, FileNotFoundError) as e:
                # Fall back to generating if loading fails
                logger.warning(
                    f"Failed to load NaN patterns from {target_file}: {e}. Regenerating."
                )
            else:
                logger.info(
                    f"Loaded NaN patterns from {target_file} for shape {cache_key}"
                )

        if patterns is None:
            patterns = self._pregenerate_patterns(cache_key)
            # Do not persist an empty tensor: it would be picked up by later
            # runs and silently disable augmentation even with valid statistics.
            if patterns.shape[0] > 0:
                torch.save(patterns, target_file)
                logger.info(
                    f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}"
                )

        self.pattern_cache[cache_key] = patterns
        return patterns

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies NaN patterns to a batch, generating them on-demand if the shape is new.

        Args:
            time_series_batch: Tensor of shape (batch, history_length, num_channels).

        Returns:
            The same tensor object. NOTE: augmented rows are written in place, so the
            caller's tensor is modified.
        """
        if self.p_series_has_nan == 0:
            return time_series_batch

        history_length, num_channels = time_series_batch.shape[1:]
        assert history_length <= self.max_length, (
            f"History length {history_length} exceeds maximum allowed {self.max_length}."
        )

        # 1. Fetch (or create) the patterns and crop them to the batch history length
        patterns = self._ensure_patterns(num_channels)[:, :history_length, :]

        # Early exit if patterns are empty (e.g., generation failed or was disabled)
        if patterns.numel() == 0:
            return time_series_batch

        batch_size = time_series_batch.shape[0]
        device = time_series_batch.device

        # 2. Vectorized decision on which series to augment
        augment_mask = torch.rand(batch_size, device=device) < self.p_series_has_nan
        indices_to_augment = torch.where(augment_mask)[0]
        num_to_augment = indices_to_augment.numel()

        if num_to_augment == 0:
            return time_series_batch

        # 3. Randomly sample patterns for each series being augmented.
        # The indices must live on the same device as `patterns` (CPU), not on
        # the batch's device: indexing a CPU tensor with CUDA indices fails.
        pattern_indices = torch.randint(
            0, patterns.shape[0], (num_to_augment,), device=patterns.device
        )
        # 4. Select patterns, move only the selected ones to the batch device,
        # and apply them in a single vectorized operation
        selected_patterns = patterns[pattern_indices].to(device)

        time_series_batch[indices_to_augment] = time_series_batch[
            indices_to_augment
        ].masked_fill(selected_patterns, float("nan"))

        return time_series_batch
356
+
357
+
358
class CensorAugmenter:
    """
    Applies censor augmentation by clipping the values of a series.

    Every series independently receives one of three treatments with equal
    probability: left unchanged, clipped from above, or clipped from below.
    The clipping thresholds are order statistics of the series itself.

    NOTE(review): the two thresholds are drawn so that low <= high, but no mode
    applies both clips at once — confirm whether a "both" mode was intended.
    """

    def __init__(self):
        """Initializes the CensorAugmenter (it has no configuration)."""

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies a vectorized censor augmentation to a batch of time series.

        Args:
            time_series_batch: Tensor of shape (batch, seq_len, 1).

        Returns:
            A new tensor of the same shape; the input is not modified. NaN
            entries stay NaN and never influence the clipping thresholds.
        """
        batch_size, seq_len, num_channels = time_series_batch.shape
        assert num_channels == 1
        series = time_series_batch.squeeze(-1)
        with torch.no_grad():
            device = series.device

            # Step 1: Choose an op mode for each series
            # (0 = unchanged, 1 = clip from above, 2 = clip from below)
            op_mode = torch.randint(0, 3, (batch_size, 1), device=device)

            # Step 2: Calculate potential thresholds for all series
            q1 = torch.rand(batch_size, device=device)
            q2 = torch.rand(batch_size, device=device)
            q_low = torch.minimum(q1, q2)
            q_high = torch.maximum(q1, q2)

            # torch.sort places NaNs last, so the first `n_valid` entries of each
            # sorted row are the non-NaN values. Indexing only into that prefix
            # keeps a NaN from ever becoming a threshold. Without NaNs this is
            # exactly `q * (seq_len - 1)`.
            sorted_series = torch.sort(series, dim=1).values
            n_valid = (~torch.isnan(series)).sum(dim=1)
            last_valid = (n_valid - 1).clamp(min=0)
            indices_low = (q_low * last_valid).long()
            indices_high = (q_high * last_valid).long()

            c_low = torch.gather(sorted_series, 1, indices_low.unsqueeze(1))
            c_high = torch.gather(sorted_series, 1, indices_high.unsqueeze(1))

            # Step 3: Compute results for all possible clipping operations.
            # Comparisons with NaN are False, so NaN entries (and all-NaN rows,
            # whose threshold is NaN) pass through untouched, whereas
            # torch.minimum/maximum would propagate a NaN threshold to every entry.
            clip_above = torch.where(series > c_high, c_high, series)
            clip_below = torch.where(series < c_low, c_low, series)

            # Step 4: Select the final result based on the op_mode
            augmented_batch = torch.where(
                op_mode == 1,
                clip_above,
                torch.where(op_mode == 2, clip_below, series),
            )

        return augmented_batch.unsqueeze(-1)
411
+
412
+
413
class QuantizationAugmenter:
    """
    Applies non-equidistant quantization using a Sobol sequence to generate
    uniformly distributed levels. This implementation is fully vectorized.

    Each augmented series is snapped to the nearest of ``n_levels`` values: its
    own minimum, its own maximum and ``n_levels - 2`` quasi-random values in
    between, where ``n_levels`` is sampled per series from ``level_range``.
    """

    def __init__(
        self,
        p_quantize: float,
        level_range: Tuple[int, int],
        seed: Optional[int] = None,
    ):
        """
        Initializes the augmenter.

        Args:
            p_quantize (float): Probability of applying quantization to a series.
            level_range (Tuple[int, int]): Inclusive range [min, max] to sample the
                number of quantization levels from.
            seed (Optional[int]): Seed for the Sobol sequence generator for reproducibility.
        """
        assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
        assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
        assert level_range[0] <= level_range[1], (
            "Min levels cannot be greater than max."
        )

        self.p_quantize = p_quantize
        self.level_range = level_range

        # Initialize a SobolEngine. The dimension is the max number of random
        # levels we might need to generate for a single series.
        # NOTE(review): with level_range[1] == 2 no engine is created and
        # transform() becomes a no-op — confirm this is intended.
        max_intermediate_levels = self.level_range[1] - 2
        if max_intermediate_levels > 0:
            # SobolEngine must be created on CPU
            self.sobol_engine = SobolEngine(
                dimension=max_intermediate_levels, scramble=True, seed=seed
            )
        else:
            self.sobol_engine = None

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies augmentation in a fully vectorized way on the batch's device.

        Args:
            time_series_batch: Tensor of shape (batch, length, 1) or (batch, length).

        Returns:
            A tensor of the same shape and dtype. The input object itself is
            returned when nothing is augmented; otherwise a modified copy.
            Constant series are left unchanged and NaN entries stay NaN.
        """
        if self.p_quantize == 0 or self.sobol_engine is None:
            return time_series_batch

        # Handle input shape (batch, length, 1)
        is_3d = time_series_batch.dim() == 3 and time_series_batch.shape[2] == 1
        series_2d = time_series_batch.squeeze(-1) if is_3d else time_series_batch

        n_series, _ = series_2d.shape
        device = series_2d.device

        # 1. Decide which series to augment
        augment_mask = torch.rand(n_series, device=device) < self.p_quantize
        # Plain int: it is used as a size and passed to SobolEngine.draw.
        n_augment = int(augment_mask.sum().item())
        if n_augment == 0:
            return time_series_batch

        series_to_augment = series_2d[augment_mask]

        # 2. Determine a variable n_levels for EACH series
        min_l, max_l = self.level_range
        n_levels_per_series = torch.randint(
            min_l, max_l + 1, size=(n_augment,), device=device
        )
        max_levels_in_batch = int(n_levels_per_series.max().item())

        # 3. Find min/max for each series, ignoring NaN entries
        nan_mask = torch.isnan(series_to_augment)
        if bool(nan_mask.any()):
            min_vals = series_to_augment.masked_fill(nan_mask, float("inf")).amin(
                dim=1, keepdim=True
            )
            max_vals = series_to_augment.masked_fill(nan_mask, float("-inf")).amax(
                dim=1, keepdim=True
            )
        else:
            min_vals = torch.amin(series_to_augment, dim=1, keepdim=True)
            max_vals = torch.amax(series_to_augment, dim=1, keepdim=True)
        value_range = max_vals - min_vals
        is_flat = value_range == 0

        # 4. Generate quasi-random levels using the Sobol sequence
        num_intermediate_levels = max_levels_in_batch - 2
        if num_intermediate_levels > 0:
            # Draw points from the Sobol engine (on CPU) and move to target device
            sobol_points = self.sobol_engine.draw(n_augment).to(device)
            # We only need the first `num_intermediate_levels` dimensions
            quasi_rand_points = sobol_points[:, :num_intermediate_levels]
        else:
            # Handle case where max_levels_in_batch is 2 (no intermediate points needed)
            quasi_rand_points = torch.empty(n_augment, 0, device=device)

        intermediate_levels = min_vals + value_range * quasi_rand_points
        level_dtype = intermediate_levels.dtype
        min_levels = min_vals.to(level_dtype)
        max_levels = max_vals.to(level_dtype)

        # A series with fewer levels than the batch maximum only uses its first
        # (n_levels - 2) intermediate slots. The unused slots are overwritten
        # with the series maximum: a duplicate level changes nothing, so every
        # series ends up with exactly {min, max, its own intermediates}.
        # (Masking the first n_levels entries of the *sorted* levels instead
        # would drop the maximum and bias such series towards small values.)
        slot_index = torch.arange(num_intermediate_levels, device=device).unsqueeze(0)
        slot_in_use = slot_index < (n_levels_per_series - 2).unsqueeze(1)
        intermediate_levels = torch.where(
            slot_in_use, intermediate_levels, max_levels
        )
        level_values = torch.cat(
            [min_levels, intermediate_levels, max_levels], dim=1
        )

        # 5. Find the closest level for every value
        series_expanded = series_to_augment.unsqueeze(2)
        levels_expanded = level_values.unsqueeze(1)
        diff = torch.abs(series_expanded - levels_expanded)
        closest_level_indices = torch.argmin(diff, dim=2)

        # 6. Gather the results from the level values
        quantized_subset = torch.gather(level_values, 1, closest_level_indices)

        # 7. For flat series and NaN entries, revert to the original values
        final_subset = torch.where(
            is_flat | nan_mask, series_to_augment, quantized_subset
        )

        # 8. Place augmented data back into a copy of the original batch.
        # The masked write needs matching dtypes, hence the explicit cast.
        augmented_batch_squeezed = series_2d.clone()
        augmented_batch_squeezed[augment_mask] = final_subset.to(series_2d.dtype)

        # Restore original shape before returning
        if is_3d:
            return augmented_batch_squeezed.unsqueeze(-1)
        return augmented_batch_squeezed
534
+
535
+
536
class MixUpAugmenter:
    """
    Mixup augmentation: replaces selected series with a convex combination of
    other series drawn from the same batch.

    Mixing weights come from a symmetric Dirichlet distribution. They are either
    constant over time ("static") or linearly interpolated between two Dirichlet
    samples along the time axis ("simplex" path).
    """

    def __init__(
        self,
        max_n_series_to_combine: int = 10,
        p_combine: float = 0.4,
        p_time_dependent: float = 0.5,
        randomize_k_per_series: bool = True,
        dirichlet_alpha_range: Tuple[float, float] = (0.1, 5.0),
    ):
        """
        Stores the mixup configuration after validating it.

        Args:
            max_n_series_to_combine (int): Upper bound for the number of source
                series k; k is drawn from [2, max_n_series_to_combine].
            p_combine (float): Per-series probability of being replaced by a mix.
            p_time_dependent (float): Per-mix probability of using the
                time-varying simplex path instead of static weights.
            randomize_k_per_series (bool): If True, k is drawn independently for
                every replaced series; otherwise a single k is drawn per call.
            dirichlet_alpha_range (Tuple[float, float]): [min, max] bounds for the
                Dirichlet concentration, which is sampled log-uniformly. Small
                values give mixes dominated by one series, large values give
                near-uniform weights.
        """
        alpha_lo = dirichlet_alpha_range[0]
        alpha_hi = dirichlet_alpha_range[1]

        assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
        assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
        assert 0.0 <= p_time_dependent <= 1.0, (
            "p_time_dependent must be between 0 and 1."
        )
        assert 0 < alpha_lo <= alpha_hi

        self.max_k = max_n_series_to_combine
        self.p_combine = p_combine
        self.p_time_dependent = p_time_dependent
        self.randomize_k = randomize_k_per_series
        self.alpha_range = dirichlet_alpha_range

    def _sample_alpha(self) -> float:
        """Draws a Dirichlet concentration log-uniformly from `alpha_range`."""
        exponent_lo = math.log10(self.alpha_range[0])
        exponent_hi = math.log10(self.alpha_range[1])
        exponent = exponent_lo + np.random.rand() * (exponent_hi - exponent_lo)
        return float(10**exponent)

    def _sample_k(self) -> int:
        """Draws the number of source series uniformly from [2, max_k]."""
        drawn = torch.randint(2, self.max_k + 1, (1,))
        return int(drawn.item())

    def _static_mix(
        self,
        source_series: torch.Tensor,
        alpha: float,
        return_weights: bool = False,
    ):
        """
        Blends the k source series with one time-constant Dirichlet weight vector.

        Returns the mixed series with a leading dimension of 1, plus the weight
        vector of shape [k] when `return_weights` is True.
        """
        n_sources = int(source_series.shape[0])
        concentration = torch.full(
            (n_sources,), float(alpha), device=source_series.device
        )
        mix_weights = torch.distributions.Dirichlet(concentration).sample()
        blended = (source_series * mix_weights.view(n_sources, 1, 1)).sum(
            dim=0, keepdim=True
        )
        return (blended, mix_weights) if return_weights else blended

    def _simplex_path_mix(
        self,
        source_series: torch.Tensor,
        alpha: float,
        return_weights: bool = False,
    ):
        """
        Blends the k source series with weights that move linearly from one
        Dirichlet sample to another over the time axis.

        Returns the mixed series with a leading dimension of 1, plus the weight
        matrix of shape [k, length] when `return_weights` is True.
        """
        n_sources, n_steps, _ = source_series.shape
        device = source_series.device

        # Two endpoints on the simplex, drawn from the same distribution.
        sampler = torch.distributions.Dirichlet(
            torch.full((n_sources,), float(alpha), device=device)
        )
        w_start = sampler.sample()
        w_end = sampler.sample()

        # Interpolation coefficient per time step, shaped [1, length] to
        # broadcast against the [k, 1] endpoint columns.
        ramp = torch.linspace(0, 1, n_steps, device=device).unsqueeze(0)
        time_varying_weights = w_start.unsqueeze(1) * (1 - ramp) + w_end.unsqueeze(
            1
        ) * ramp

        # [k, length] -> [k, length, 1] so the weights apply to every channel.
        blended = (source_series * time_varying_weights.unsqueeze(-1)).sum(
            dim=0, keepdim=True
        )

        return (blended, time_varying_weights) if return_weights else blended

    def transform(
        self, time_series_batch: torch.Tensor, return_debug_info: bool = False
    ):
        """
        Replaces a random subset of the batch with mixes of other batch members.

        The batch is returned untouched when `p_combine` is 0, when the batch
        holds no more than `max_k` series, or when no series is selected. With
        `return_debug_info=True` a `(batch, info)` tuple is returned, where
        `info` maps each replaced index to its sources, weights, alpha, k and
        mix type (empty when nothing was replaced).
        """

        def _passthrough():
            return (time_series_batch, {}) if return_debug_info else time_series_batch

        with torch.no_grad():
            if self.p_combine == 0:
                return _passthrough()

            batch_size, _, _ = time_series_batch.shape
            device = time_series_batch.device

            # Every target needs up to max_k *other* series to draw from.
            if batch_size <= self.max_k:
                return _passthrough()

            # Pick the series that will be overwritten.
            replace_mask = torch.rand(batch_size, device=device) < self.p_combine
            indices_to_replace = torch.where(replace_mask)[0]
            n_augment = indices_to_replace.numel()
            if n_augment == 0:
                return _passthrough()

            # Number of sources per target: independent draws or one shared draw.
            if self.randomize_k:
                k_values = torch.randint(2, self.max_k + 1, (n_augment,), device=device)
            else:
                k_values = torch.full((n_augment,), self._sample_k(), device=device)

            row_ids = torch.arange(batch_size, device=device)
            mixed_rows = []
            debug_info = {}

            for target_idx, k_entry in zip(indices_to_replace, k_values):
                current_k = k_entry.item()

                # Sources are drawn without replacement, excluding the target.
                pool = row_ids[row_ids != target_idx]
                shuffle = torch.randperm(pool.shape[0], device=device)
                source_indices = pool[shuffle[:current_k]]
                source_series = time_series_batch[source_indices]

                alpha = self._sample_alpha()

                use_simplex = torch.rand(1).item() < self.p_time_dependent
                mixer = self._simplex_path_mix if use_simplex else self._static_mix
                mixed_series, weights = mixer(
                    source_series, alpha=alpha, return_weights=True
                )
                mixed_rows.append(mixed_series)

                if return_debug_info:
                    debug_info[target_idx.item()] = {
                        "source_indices": source_indices.cpu().numpy(),
                        "weights": weights.cpu().numpy(),
                        "alpha": alpha,
                        "k": current_k,
                        "mix_type": "simplex" if use_simplex else "static",
                    }

            # Write the mixes into a copy so the caller's batch is left intact.
            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_replace] = torch.cat(mixed_rows, dim=0)

            result = augmented_batch.detach()
            return (result, debug_info) if return_debug_info else result
731
+
732
+
733
class TimeFlipAugmenter:
    """
    Reverses the time axis of a randomly chosen subset of series in a batch.
    """

    def __init__(self, p_flip: float = 0.5):
        """
        Stores the flip probability.

        Args:
            p_flip (float): Probability that any individual series in the batch
                is reversed in time. Defaults to 0.5.
        """
        assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1."
        self.p_flip = p_flip

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Returns the batch with some series reversed along dim=1.

        Args:
            time_series_batch (torch.Tensor): Batch of series; the docstring
                contract is (batch_size, seq_len, num_channels), and dim 1 is
                the axis that gets reversed.

        Returns:
            torch.Tensor: The input itself when nothing is flipped, otherwise a
            modified copy.
        """
        with torch.no_grad():
            if self.p_flip == 0:
                return time_series_batch

            n_series = time_series_batch.shape[0]
            selected = (
                torch.rand(n_series, device=time_series_batch.device) < self.p_flip
            )
            rows = torch.where(selected)[0]

            if rows.numel() == 0:
                return time_series_batch

            # Work on a copy so the caller's tensor is never modified in place.
            result = time_series_batch.clone()
            result[rows] = torch.flip(time_series_batch[rows], dims=[1])
            return result
785
+
786
+
787
class YFlipAugmenter:
    """
    Negates the values (mirrors around y = 0) of a randomly chosen subset of
    series in a batch. The time axis is left untouched.
    """

    def __init__(self, p_flip: float = 0.5):
        """
        Stores the flip probability.

        Args:
            p_flip (float): Probability that any individual series in the batch
                has its sign inverted. Defaults to 0.5.
        """
        assert 0.0 <= p_flip <= 1.0, "Probability must be between 0 and 1."
        self.p_flip = p_flip

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Returns the batch with some series multiplied by -1.

        Args:
            time_series_batch (torch.Tensor): Batch of series; selection happens
                along dim 0 and all values of a selected series are negated.

        Returns:
            torch.Tensor: The input itself when nothing is flipped, otherwise a
            modified copy.
        """
        with torch.no_grad():
            if self.p_flip == 0:
                return time_series_batch

            n_series = time_series_batch.shape[0]
            selected = (
                torch.rand(n_series, device=time_series_batch.device) < self.p_flip
            )
            rows = torch.where(selected)[0]

            if rows.numel() == 0:
                return time_series_batch

            # Work on a copy so the caller's tensor is never modified in place.
            result = time_series_batch.clone()
            result[rows] = -time_series_batch[rows]
            return result
839
+
840
+
841
class DifferentialAugmenter:
    """
    Applies calculus-inspired augmentations: Gaussian smoothing, first to fourth
    finite-difference derivatives of the smoothed signal, and a cumulative-sum
    "integral". All convolutions are depthwise nn.Conv1d layers with 'reflect'
    padding, and every result is rescaled to the min/max range of its input.

    The Gaussian kernel size and sigma for the initial smoothing are randomly
    sampled at every transform() call from user-defined ranges.
    """

    # The widest fixed kernel has 5 taps, i.e. a reflect padding of 2, which
    # needs at least 3 time steps.
    _MIN_SEQ_LEN = 3

    def __init__(
        self,
        p_transform: float,
        gaussian_kernel_size_range: Tuple[int, int] = (5, 51),
        gaussian_sigma_range: Tuple[float, float] = (2.0, 20.0),
    ):
        """
        Initializes the augmenter.

        Args:
            p_transform (float): The probability of applying an augmentation to any
                given time series in a batch.
            gaussian_kernel_size_range (Tuple[int, int]): The [min, max] inclusive
                range for the Gaussian kernel size. Only odd sizes are sampled.
            gaussian_sigma_range (Tuple[float, float]): The [min, max] inclusive
                range for the Gaussian sigma.

        Raises:
            ValueError: If a range has min > max, the kernel minimum is below 3,
                or the sigma minimum is not positive.
        """
        self.p_transform = p_transform
        self.kernel_size_range = gaussian_kernel_size_range
        self.sigma_range = gaussian_sigma_range

        # Validate ranges
        if not (
            self.kernel_size_range[0] <= self.kernel_size_range[1]
            and self.kernel_size_range[0] >= 3
        ):
            raise ValueError(
                "Invalid kernel size range. Ensure min <= max and min >= 3."
            )
        if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0):
            raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.")

        # Cache for fixed-kernel convolution layers (Sobel, Laplace, etc.)
        self.conv_cache: Dict[Tuple[int, torch.device], Dict[str, nn.Module]] = {}

    @staticmethod
    def _build_depthwise_conv(
        kernel_1d: torch.Tensor, num_channels: int, device: torch.device
    ) -> nn.Module:
        """Wraps a 1D kernel in a frozen depthwise Conv1d with reflect 'same' padding."""
        conv = nn.Conv1d(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_size=int(kernel_1d.numel()),
            padding="same",
            padding_mode="reflect",
            groups=num_channels,
            bias=False,
            device=device,
        )
        # One copy of the kernel per channel (groups == num_channels).
        conv.weight.data = kernel_1d.view(1, 1, -1).repeat(num_channels, 1, 1)
        conv.weight.requires_grad = False
        return conv

    def _create_fixed_kernel_layers(
        self, num_channels: int, device: torch.device
    ) -> dict:
        """
        Creates the nn.Conv1d layers for the fixed-kernel derivative operations.
        The caller caches the result per (num_channels, device).

        Returns:
            dict: Layers keyed by "sobel", "laplace", "d3" and "d4".
        """
        stencils = {
            "sobel": [-1, 0, 1],
            "laplace": [1, -2, 1],
            "d3": [-1, 2, 0, -2, 1],
            "d4": [1, -4, 6, -4, 1],
        }
        # Construct all layers first, then assign kernels, so that the layer
        # construction order matches the dictionary order above.
        return {
            name: self._build_depthwise_conv(
                torch.tensor(taps, device=device, dtype=torch.float32),
                num_channels,
                device,
            )
            for name, taps in stencils.items()
        }

    def _create_gaussian_layer(
        self, kernel_size: int, sigma: float, num_channels: int, device: torch.device
    ) -> nn.Module:
        """Creates a single Gaussian convolution layer with the given dynamic parameters."""
        ax = torch.arange(
            -(kernel_size // 2),
            kernel_size // 2 + 1,
            device=device,
            dtype=torch.float32,
        )
        gauss_kernel = torch.exp(-0.5 * (ax / sigma) ** 2)
        gauss_kernel /= gauss_kernel.sum()  # unit mass: smoothing keeps the level
        return self._build_depthwise_conv(gauss_kernel, num_channels, device)

    def _sample_kernel_size(self, seq_len: int) -> int:
        """
        Samples an odd Gaussian kernel size inside the configured range, capped
        so that the reflect padding stays valid for `seq_len`.
        """
        min_k, max_k = self.kernel_size_range
        half_lo = min_k // 2
        # (max_k - 1) // 2 keeps 2 * half + 1 <= max_k even when max_k is even;
        # fall back to half_lo if the range contains no odd number.
        half_hi = max((max_k - 1) // 2, half_lo)
        kernel_size = int(torch.randint(half_lo, half_hi + 1, (1,)).item()) * 2 + 1
        # Reflect padding requires kernel_size // 2 < seq_len.
        return min(kernel_size, 2 * seq_len - 1)

    def _rescale_signal(
        self, processed_signal: torch.Tensor, original_signal: torch.Tensor
    ) -> torch.Tensor:
        """
        Rescales the processed signal to match the min/max range of the original.

        Min and max are taken along dim=2. Where the original is (near) constant,
        the original values are returned instead of the processed ones.
        """
        original_min = torch.amin(original_signal, dim=2, keepdim=True)
        original_max = torch.amax(original_signal, dim=2, keepdim=True)
        processed_min = torch.amin(processed_signal, dim=2, keepdim=True)
        processed_max = torch.amax(processed_signal, dim=2, keepdim=True)

        original_range = original_max - original_min
        processed_range = processed_max - processed_min
        epsilon = 1e-8  # avoids division by zero for flat processed signals
        rescaled_signal = (
            (processed_signal - processed_min) / (processed_range + epsilon)
        ) * original_range + original_min
        return torch.where(original_range < epsilon, original_signal, rescaled_signal)

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies a random augmentation to a subset of the batch.

        Each selected series gets one of six operations chosen uniformly:
        Gaussian smoothing, Sobel, Laplace, cumulative sum (left or right),
        third or fourth derivative.

        Args:
            time_series_batch (torch.Tensor): Batch of shape
                (batch_size, seq_len, num_channels).

        Returns:
            torch.Tensor: The input itself when nothing is changed (p_transform
            is 0, no series selected, or seq_len < 3), otherwise a modified copy.
        """
        with torch.no_grad():
            if self.p_transform == 0:
                return time_series_batch

            batch_size, seq_len, num_channels = time_series_batch.shape
            device = time_series_batch.device

            # Too short for the 5-tap derivative kernels with reflect padding.
            if seq_len < self._MIN_SEQ_LEN:
                return time_series_batch

            augment_mask = torch.rand(batch_size, device=device) < self.p_transform
            indices_to_augment = torch.where(augment_mask)[0]
            num_to_augment = indices_to_augment.numel()

            if num_to_augment == 0:
                return time_series_batch

            # --- Randomly sample Gaussian parameters for this call ---
            kernel_size = self._sample_kernel_size(seq_len)

            min_s, max_s = self.sigma_range
            sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()

            # --- Get/Create Convolution Layers ---
            gauss_conv = self._create_gaussian_layer(
                kernel_size, sigma, num_channels, device
            )

            cache_key = (num_channels, device)
            if cache_key not in self.conv_cache:
                self.conv_cache[cache_key] = self._create_fixed_kernel_layers(
                    num_channels, device
                )
            fixed_layers = self.conv_cache[cache_key]

            # --- Apply Augmentations ---
            # Conv1d expects (batch, channels, time).
            subset_to_augment = time_series_batch[indices_to_augment]
            subset_permuted = subset_to_augment.permute(0, 2, 1)

            op_choices = torch.randint(0, 6, (num_to_augment,), device=device)

            # Derivatives are taken on the smoothed signal to limit noise gain.
            smoothed_subset = gauss_conv(subset_permuted)
            gauss_result = self._rescale_signal(smoothed_subset, subset_permuted)
            sobel_result = self._rescale_signal(
                fixed_layers["sobel"](smoothed_subset), subset_permuted
            )
            laplace_result = self._rescale_signal(
                fixed_layers["laplace"](smoothed_subset), subset_permuted
            )
            d3_result = self._rescale_signal(
                fixed_layers["d3"](smoothed_subset), subset_permuted
            )
            d4_result = self._rescale_signal(
                fixed_layers["d4"](smoothed_subset), subset_permuted
            )

            # Cumulative sum from the left or, per series at random, from the right.
            use_right_integral = torch.rand(num_to_augment, 1, 1, device=device) > 0.5
            flipped_subset = torch.flip(subset_permuted, dims=[2])
            right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
            left_integral = torch.cumsum(subset_permuted, dim=2)
            integral_result = torch.where(
                use_right_integral, right_integral, left_integral
            )
            integral_result_normalized = self._rescale_signal(
                integral_result, subset_permuted
            )

            # --- Assemble the results based on op_choices ---
            # List position == operation id drawn in op_choices.
            results_by_op = [
                gauss_result,
                sobel_result,
                laplace_result,
                integral_result_normalized,
                d3_result,
                d4_result,
            ]
            op_choices_view = op_choices.view(-1, 1, 1)
            augmented_subset = subset_permuted
            for op_id, op_result in enumerate(results_by_op):
                augmented_subset = torch.where(
                    op_choices_view == op_id, op_result, augmented_subset
                )

            augmented_subset_final = augmented_subset.permute(0, 2, 1)
            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_augment] = augmented_subset_final

            return augmented_batch
1106
+
1107
+
1108
class RandomConvAugmenter:
    """
    Applies a stack of 1-to-N random 1D convolutions to a time series batch.

    This augmenter is inspired by the principles of ROCKET and RandConv,
    randomizing nearly every aspect of the convolution process to create a
    highly diverse set of transformations. This version includes multiple
    kernel generation strategies, random padding modes, and optional non-linearities.
    """

    def __init__(
        self,
        p_transform: float = 0.5,
        kernel_size_range: Tuple[int, int] = (3, 31),
        dilation_range: Tuple[int, int] = (1, 8),
        layer_range: Tuple[int, int] = (1, 3),
        sigma_range: Tuple[float, float] = (0.5, 5.0),
        bias_range: Tuple[float, float] = (-0.5, 0.5),
    ):
        """
        Initializes the augmenter.

        Args:
            p_transform (float): Probability of applying the augmentation to a series.
            kernel_size_range (Tuple[int, int]): [min, max] range for kernel sizes.
                Must be odd numbers.
            dilation_range (Tuple[int, int]): [min, max] range for dilation factors.
            layer_range (Tuple[int, int]): [min, max] range for the number of
                stacked convolution layers.
            sigma_range (Tuple[float, float]): [min, max] range for the sigma of
                Gaussian kernels.
            bias_range (Tuple[float, float]): [min, max] range for the bias term.
        """
        assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, (
            "Kernel sizes must be odd."
        )

        self.p_transform = p_transform
        self.kernel_size_range = kernel_size_range
        self.dilation_range = dilation_range
        self.layer_range = layer_range
        self.sigma_range = sigma_range
        self.bias_range = bias_range
        self.padding_modes = ["reflect", "replicate", "circular"]

    def _rescale_signal(
        self, processed_signal: torch.Tensor, original_signal: torch.Tensor
    ) -> torch.Tensor:
        """
        Rescales the processed signal to match the min/max range of the original.

        Min and max are taken along the last dimension. Where the processed
        signal is (near) constant, the mean of the original is returned instead.
        """
        original_min = torch.amin(original_signal, dim=-1, keepdim=True)
        original_max = torch.amax(original_signal, dim=-1, keepdim=True)
        processed_min = torch.amin(processed_signal, dim=-1, keepdim=True)
        processed_max = torch.amax(processed_signal, dim=-1, keepdim=True)

        original_range = original_max - original_min
        processed_range = processed_max - processed_min
        epsilon = 1e-8  # avoids division by zero for flat processed signals

        is_flat = processed_range < epsilon

        rescaled_signal = (
            (processed_signal - processed_min) / (processed_range + epsilon)
        ) * original_range + original_min

        original_mean = torch.mean(original_signal, dim=-1, keepdim=True)
        flat_rescaled = original_mean.expand_as(original_signal)

        return torch.where(is_flat, flat_rescaled, rescaled_signal)

    @staticmethod
    def _build_conv_layer(
        num_channels: int,
        kernel_size: int,
        dilation: int,
        padding_mode: str,
        seq_len: int,
        device: torch.device,
    ) -> nn.Module:
        """
        Builds a depthwise 'same'-padded Conv1d, downgrading the padding mode to
        'replicate' when the requested mode cannot pad a series of `seq_len`.
        """
        # Kernel sizes are odd, so 'same' padding is symmetric with this width.
        pad = dilation * (kernel_size - 1) // 2
        # 'reflect' needs pad < seq_len, 'circular' needs pad <= seq_len;
        # 'replicate' has no such limit.
        if (padding_mode == "reflect" and pad >= seq_len) or (
            padding_mode == "circular" and pad > seq_len
        ):
            padding_mode = "replicate"

        return nn.Conv1d(
            in_channels=num_channels,
            out_channels=num_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",  # Let PyTorch handle padding calculation
            padding_mode=str(padding_mode),
            groups=num_channels,
            bias=True,
            device=device,
        )

    def _sample_kernel(self, kernel_size: int, device: torch.device) -> torch.Tensor:
        """
        Samples a 1D kernel of one of four random types.

        The result has `kernel_size` taps, except for the noisy Sobel type, which
        always has at least 3.
        """
        weight_type = torch.randint(0, 4, (1,)).item()
        if weight_type == 0:  # Gaussian kernel
            s_min, s_max = self.sigma_range
            sigma = (s_min + (s_max - s_min) * torch.rand(1)).item()
            ax = torch.arange(
                -(kernel_size // 2),
                kernel_size // 2 + 1,
                device=device,
                dtype=torch.float32,
            )
            return torch.exp(-0.5 * (ax / sigma) ** 2)
        if weight_type == 1:  # Standard normal kernel
            return torch.randn(kernel_size, device=device)
        if weight_type == 2:  # Polynomial kernel
            coeffs = torch.randn(3, device=device)  # a, b, c for ax^2+bx+c
            x_vals = torch.linspace(-1, 1, kernel_size, device=device)
            return coeffs[0] * x_vals**2 + coeffs[1] * x_vals + coeffs[2]

        # Noisy Sobel kernel, zero-padded up to the sampled size.
        sobel_base = torch.tensor([-1, 0, 1], dtype=torch.float32, device=device)
        noisy_sobel = sobel_base + torch.randn(3, device=device) * 0.1
        pad_total = max(kernel_size, 3) - 3
        pad_left = pad_total // 2
        pad_right = pad_total - pad_left
        return F.pad(noisy_sobel, (pad_left, pad_right), "constant", 0)

    def _apply_random_conv_stack(self, series: torch.Tensor) -> torch.Tensor:
        """
        Applies a randomly configured stack of convolutions to a single time series.

        Args:
            series (torch.Tensor): A single time series of shape (1, num_channels, seq_len).

        Returns:
            torch.Tensor: The augmented time series, same shape as the input.
        """
        num_channels = series.shape[1]
        seq_len = series.shape[-1]
        device = series.device

        num_layers = torch.randint(
            self.layer_range[0], self.layer_range[1] + 1, (1,)
        ).item()

        processed_series = series
        for i in range(num_layers):
            # 1. Sample kernel size (always odd)
            k_min, k_max = self.kernel_size_range
            kernel_size = torch.randint(k_min // 2, k_max // 2 + 1, (1,)).item() * 2 + 1

            # 2. Sample dilation
            d_min, d_max = self.dilation_range
            dilation = torch.randint(d_min, d_max + 1, (1,)).item()

            # 3. Sample bias
            b_min, b_max = self.bias_range
            bias_val = (b_min + (b_max - b_min) * torch.rand(1)).item()

            # 4. Sample padding mode
            padding_mode = np.random.choice(self.padding_modes)

            conv_layer = self._build_conv_layer(
                num_channels, kernel_size, dilation, padding_mode, seq_len, device
            )

            # 5. Sample kernel weights from a wider variety of types
            kernel = self._sample_kernel(kernel_size, device)
            if kernel.numel() != kernel_size:
                # The Sobel kernel was widened to 3 taps; the layer's 'same'
                # padding must be computed for the kernel actually used.
                kernel_size = int(kernel.numel())
                conv_layer = self._build_conv_layer(
                    num_channels, kernel_size, dilation, padding_mode, seq_len, device
                )

            # 6. Probabilistic normalization
            if torch.rand(1).item() < 0.8:  # 80% chance to normalize
                kernel /= torch.sum(torch.abs(kernel)) + 1e-8

            kernel = kernel.view(1, 1, -1).repeat(num_channels, 1, 1)

            conv_layer.weight.data = kernel
            conv_layer.bias.data.fill_(bias_val)
            conv_layer.weight.requires_grad = False
            conv_layer.bias.requires_grad = False

            # Apply convolution
            processed_series = conv_layer(processed_series)

            # 7. Optional non-linearity (not on the last layer)
            if i < num_layers - 1:
                activation_type = torch.randint(0, 3, (1,)).item()
                if activation_type == 1:
                    processed_series = F.relu(processed_series)
                elif activation_type == 2:
                    processed_series = torch.tanh(processed_series)
                # if 0, do nothing (linear)

        return processed_series

    def transform(self, time_series_batch: torch.Tensor) -> torch.Tensor:
        """
        Applies a random augmentation to a subset of the batch.

        Args:
            time_series_batch (torch.Tensor): Batch of shape
                (batch_size, seq_len, num_channels).

        Returns:
            torch.Tensor: The input itself when nothing is changed (p_transform
            is 0 or no series selected), otherwise a modified copy in which every
            selected series was convolved and rescaled to its original range.
        """
        with torch.no_grad():
            if self.p_transform == 0:
                return time_series_batch

            batch_size, seq_len, num_channels = time_series_batch.shape
            device = time_series_batch.device

            augment_mask = torch.rand(batch_size, device=device) < self.p_transform
            indices_to_augment = torch.where(augment_mask)[0]
            num_to_augment = indices_to_augment.numel()

            if num_to_augment == 0:
                return time_series_batch

            # Conv1d expects (batch, channels, time).
            subset_permuted = time_series_batch[indices_to_augment].permute(0, 2, 1)

            # Each series gets its own independently sampled convolution stack.
            augmented_subset_list = []
            for i in range(num_to_augment):
                original_series = subset_permuted[i : i + 1]
                augmented_series = self._apply_random_conv_stack(original_series)

                rescaled_series = self._rescale_signal(
                    augmented_series.squeeze(0), original_series.squeeze(0)
                )
                augmented_subset_list.append(rescaled_series.unsqueeze(0))

            augmented_subset = torch.cat(augmented_subset_list, dim=0)
            augmented_subset_final = augmented_subset.permute(0, 2, 1)

            augmented_batch = time_series_batch.clone()
            augmented_batch[indices_to_augment] = augmented_subset_final
            return augmented_batch
src/data/batch_composer.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import random
4
+ from typing import Dict, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+
10
+ from src.data.augmentations import (
11
+ NanAugmenter,
12
+ )
13
+ from src.data.constants import DEFAULT_NAN_STATS_PATH, LENGTH_CHOICES, LENGTH_WEIGHTS
14
+ from src.data.containers import BatchTimeSeriesContainer
15
+ from src.data.datasets import CyclicalBatchDataset
16
+ from src.data.frequency import Frequency
17
+ from src.data.scalers import MeanScaler, MedianScaler, MinMaxScaler, RobustScaler
18
+ from src.data.utils import sample_future_length
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class BatchComposer:
25
+ """
26
+ Composes batches from saved generator data according to specified proportions.
27
+ Manages multiple CyclicalBatchDataset instances and creates uniform or mixed batches.
28
+ """
29
+
30
def __init__(
    self,
    base_data_dir: str,
    generator_proportions: Optional[Dict[str, float]] = None,
    mixed_batches: bool = True,
    device: Optional[torch.device] = None,
    augmentations: Optional[Dict[str, bool]] = None,
    augmentation_probabilities: Optional[Dict[str, float]] = None,
    nan_stats_path: Optional[str] = None,
    nan_patterns_path: Optional[str] = None,
    global_seed: int = 42,
    chosen_scaler_name: Optional[str] = None,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Configure the composer, seed the random sources and load the datasets.

    Args:
        base_data_dir: Base directory containing one subdirectory per generator
        generator_proportions: Dict mapping generator names to relative weights;
            a built-in default mix is used when omitted or empty
        mixed_batches: If True, create mixed batches; if False, uniform batches
        device: Stored on the instance as given
        augmentations: Dict mapping augmentation names to booleans
        augmentation_probabilities: Dict mapping augmentation names to probabilities
        nan_stats_path: JSON file with NaN statistics, read only when
            ``nan_augmentation`` is enabled (a default path is used when None)
        nan_patterns_path: Forwarded to the NaN augmenter when it is created
        global_seed: Seed for the composer RNG, ``random`` and ``torch``
        chosen_scaler_name: Name of the scaler used in training; it is excluded
            from the random scaler augmentation
        rank: Rank of current process for distributed data loading
        world_size: Total number of processes for distributed data loading
    """
    # Plain configuration values are kept verbatim on the instance.
    self.base_data_dir = base_data_dir
    self.device = device
    self.mixed_batches = mixed_batches
    self.global_seed = global_seed
    self.rank = rank
    self.world_size = world_size
    self.nan_stats_path = nan_stats_path
    self.nan_patterns_path = nan_patterns_path

    # A missing (or empty) probability map falls back to the defaults.
    if augmentation_probabilities:
        self.augmentation_probabilities = augmentation_probabilities
    else:
        self.augmentation_probabilities = {
            "noise_augmentation": 0.3,
            "scaler_augmentation": 0.5,
        }

    # Scaler names are compared case-insensitively later on.
    if chosen_scaler_name is None:
        self.chosen_scaler_name = None
    else:
        self.chosen_scaler_name = chosen_scaler_name.lower()

    # Seed every random source this class relies on.
    self.rng = np.random.default_rng(global_seed)
    random.seed(global_seed)
    torch.manual_seed(global_seed)

    # Order matters: the datasets are derived from the normalized proportions.
    self._setup_augmentations(augmentations)
    self._setup_proportions(generator_proportions)
    self.datasets = self._initialize_datasets()

    summary = (
        f"Initialized BatchComposer with {len(self.datasets)} generators, "
        f"mixed_batches={mixed_batches}, proportions={self.generator_proportions}, "
        f"augmentations={self.augmentations}, "
        f"augmentation_probabilities={self.augmentation_probabilities}"
    )
    logger.info(summary)
97
+
98
+ def _setup_augmentations(self, augmentations: Optional[Dict[str, bool]]):
99
+ """Setup only the augmentations that should remain online (NaN)."""
100
+ default_augmentations = {
101
+ "nan_augmentation": False,
102
+ "scaler_augmentation": False,
103
+ "length_shortening": False,
104
+ }
105
+
106
+ self.augmentations = augmentations or default_augmentations
107
+
108
+ # Initialize NaN augmenter if needed
109
+ self.nan_augmenter = None
110
+ if self.augmentations.get("nan_augmentation", False):
111
+ stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
112
+ stats = json.load(open(stats_path_to_use, "r"))
113
+ self.nan_augmenter = NanAugmenter(
114
+ p_series_has_nan=stats["p_series_has_nan"],
115
+ nan_ratio_distribution=stats["nan_ratio_distribution"],
116
+ nan_length_distribution=stats["nan_length_distribution"],
117
+ nan_patterns_path=self.nan_patterns_path,
118
+ )
119
+
120
+ def _should_apply_scaler_augmentation(self) -> bool:
121
+ """
122
+ Decide whether to apply scaler augmentation for a single series based on
123
+ the boolean toggle and probability from the configuration.
124
+ """
125
+ if not self.augmentations.get("scaler_augmentation", False):
126
+ return False
127
+ probability = float(
128
+ self.augmentation_probabilities.get("scaler_augmentation", 0.0)
129
+ )
130
+ probability = max(0.0, min(1.0, probability))
131
+ return bool(self.rng.random() < probability)
132
+
133
def _choose_random_scaler(self) -> Optional[object]:
    """
    Pick a random scaler for augmentation, skipping the training scaler.

    The scaler named by ``self.chosen_scaler_name`` (if any) is removed from
    the candidate pool before one name is drawn from ``self.rng``.

    Returns a new instance of the drawn scaler, or None when no valid option
    exists.
    """
    excluded: Optional[str] = None
    if self.chosen_scaler_name is not None:
        excluded = self.chosen_scaler_name.strip().lower()

    # Candidate order is significant: it determines what each RNG draw maps to.
    candidates = [
        name
        for name in ("custom_robust", "minmax", "median", "mean")
        if name != excluded
    ]
    if not candidates:
        return None

    scaler_classes = {
        "custom_robust": RobustScaler,
        "minmax": MinMaxScaler,
        "median": MedianScaler,
        "mean": MeanScaler,
    }
    pick = str(self.rng.choice(candidates))
    scaler_cls = scaler_classes.get(pick)
    if scaler_cls is None:
        return None
    return scaler_cls()
162
+
163
+ def _setup_proportions(self, generator_proportions):
164
+ """Setup default or custom generator proportions."""
165
+ default_proportions = {
166
+ "forecast_pfn": 1.0,
167
+ "gp": 1.0,
168
+ "kernel": 1.0,
169
+ "sinewave": 1.0,
170
+ "sawtooth": 1.0,
171
+ "step": 0.1,
172
+ "anomaly": 1.0,
173
+ "spike": 2.0,
174
+ "cauker_univariate": 2.0,
175
+ "cauker_multivariate": 0.00,
176
+ "lmc": 0.00, # multivariate
177
+ "ou_process": 1.0,
178
+ "audio_financial_volatility": 0.1,
179
+ "audio_multi_scale_fractal": 0.1,
180
+ "audio_network_topology": 0.5,
181
+ "audio_stochastic_rhythm": 1.0,
182
+ "augmented_per_sample_2048": 3.0,
183
+ "augmented_temp_batch_2048": 3.0,
184
+ }
185
+ self.generator_proportions = generator_proportions or default_proportions
186
+
187
+ # Normalize proportions
188
+ total = sum(self.generator_proportions.values())
189
+ if total <= 0:
190
+ raise ValueError("Total generator proportions must be positive")
191
+ self.generator_proportions = {
192
+ k: v / total for k, v in self.generator_proportions.items()
193
+ }
194
+
195
def _initialize_datasets(self) -> Dict[str, CyclicalBatchDataset]:
    """
    Create one CyclicalBatchDataset per generator with a positive proportion.

    Generators whose dataset cannot be constructed are logged and skipped
    (best effort), so the result may hold fewer entries than configured.

    Returns:
        Dict mapping generator names to their loaded datasets.

    Raises:
        ValueError: If no dataset could be loaded at all.
    """
    loaded = {}

    for generator_name, proportion in self.generator_proportions.items():
        # Zero-weight generators are never sampled, so their data is not opened.
        if proportion <= 0:
            logger.info(f"Skipping {generator_name} (proportion = {proportion})")
            continue

        dataset_kwargs = dict(
            batches_dir=f"{self.base_data_dir}/{generator_name}",
            generator_type=generator_name,
            device=None,
            prefetch_next=True,
            prefetch_threshold=32,
            rank=self.rank,
            world_size=self.world_size,
        )

        try:
            loaded[generator_name] = CyclicalBatchDataset(**dataset_kwargs)
            logger.info(
                f"Loaded dataset for {generator_name} (proportion = {proportion})"
            )
        except Exception as e:
            # Deliberately broad: one broken generator must not abort the rest.
            logger.warning(f"Failed to load dataset for {generator_name}: {e}")
            continue

    if len(loaded) == 0:
        raise ValueError(
            f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0"
        )

    return loaded
232
+
233
def _convert_sample_to_tensors(
    self, sample: dict, future_length: Optional[int] = None
) -> Tuple[torch.Tensor, np.datetime64, Frequency]:
    """
    Convert a sample dict to a value tensor plus its metadata.

    Args:
        sample: Sample dict from CyclicalBatchDataset; reads the keys
            "values", "frequency", "start" and, optionally, "num_channels"
            and "generator_type"
        future_length: Accepted for interface compatibility; not used here

    Returns:
        Tuple of (values, start, frequency), where values has shape
        [1, seq_len, num_channels]
    """
    # Handle both old and new data formats
    num_channels = sample.get("num_channels", 1)
    values_data = sample["values"]
    generator_type = sample.get("generator_type", "unknown")

    if num_channels != 1:
        # Multivariate data (LMC) - new format: [[ch1_values], [ch2_values], ...]
        per_channel = [
            torch.tensor(channel_values, dtype=torch.float32)
            for channel_values in values_data
        ]
        # Stack channels: [1, seq_len, num_channels]
        values = torch.stack(per_channel, dim=-1).unsqueeze(0)
        logger.debug(
            f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}"
        )
    else:
        # Univariate data
        if isinstance(values_data[0], list):
            # New format: [[channel_values]]
            values = torch.tensor(values_data[0], dtype=torch.float32)
            logger.debug(
                f"{generator_type}: Using new univariate format, shape: {values.shape}"
            )
        else:
            # Old format: [values]
            values = torch.tensor(values_data, dtype=torch.float32)
        values = values.unsqueeze(0).unsqueeze(-1)  # Shape: [1, seq_len, 1]

    # Frequency: try the enum constructor first, then known aliases.
    freq_str = sample["frequency"]
    try:
        frequency = Frequency(freq_str)
    except ValueError:
        aliases = {
            "h": Frequency.H,
            "D": Frequency.D,
            "W": Frequency.W,
            "M": Frequency.M,
            "Q": Frequency.Q,
            "A": Frequency.A,
            "Y": Frequency.A,  # Annual
            "1min": Frequency.T1,
            "5min": Frequency.T5,
            "10min": Frequency.T10,
            "15min": Frequency.T15,
            "30min": Frequency.T30,
            "s": Frequency.S,
        }
        # Unknown strings silently default to hourly.
        frequency = aliases.get(freq_str, Frequency.H)

    # Start timestamp is normalized to np.datetime64.
    raw_start = sample["start"]
    if isinstance(raw_start, pd.Timestamp):
        start = raw_start.to_numpy()
    else:
        start = np.datetime64(raw_start)

    return values, start, frequency
306
+
307
+ def _effective_proportions_for_length(
308
+ self, total_length_for_batch: int
309
+ ) -> Dict[str, float]:
310
+ """
311
+ Build a simple, length-aware proportion map for the current batch.
312
+
313
+ Rules:
314
+ - For generators named 'augmented{L}', keep only the one matching the
315
+ chosen length L; zero out others.
316
+ - Keep non-augmented generators as-is.
317
+ - Drop generators that are unavailable (not loaded) or zero-weight.
318
+ - If nothing remains, fall back to 'augmented{L}' if available, else any dataset.
319
+ - Normalize the final map to sum to 1.
320
+ """
321
+
322
+ def augmented_length_from_name(name: str) -> Optional[int]:
323
+ if not name.startswith("augmented"):
324
+ return None
325
+ suffix = name[len("augmented") :]
326
+ if not suffix:
327
+ return None
328
+ try:
329
+ return int(suffix)
330
+ except ValueError:
331
+ return None
332
+
333
+ # 1) Adjust proportions with the length-aware rule
334
+ adjusted: Dict[str, float] = {}
335
+ for name, proportion in self.generator_proportions.items():
336
+ aug_len = augmented_length_from_name(name)
337
+ if aug_len is None:
338
+ adjusted[name] = proportion
339
+ else:
340
+ adjusted[name] = (
341
+ proportion if aug_len == total_length_for_batch else 0.0
342
+ )
343
+
344
+ # 2) Keep only available, positive-weight datasets
345
+ adjusted = {
346
+ name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0
347
+ }
348
+
349
+ # 3) Fallback if empty
350
+ if not adjusted:
351
+ preferred = f"augmented{total_length_for_batch}"
352
+ if preferred in self.datasets:
353
+ adjusted = {preferred: 1.0}
354
+ elif self.datasets:
355
+ # Choose any available dataset deterministically (first key)
356
+ first_key = next(iter(self.datasets.keys()))
357
+ adjusted = {first_key: 1.0}
358
+ else:
359
+ raise ValueError("No datasets available to create batch")
360
+
361
+ # 4) Normalize
362
+ total = sum(adjusted.values())
363
+ return {name: p / total for name, p in adjusted.items()}
364
+
365
+ def _compute_sample_counts_for_batch(
366
+ self, proportions: Dict[str, float], batch_size: int
367
+ ) -> Dict[str, int]:
368
+ """
369
+ Convert a proportion map into integer sample counts that sum to batch_size.
370
+
371
+ Strategy: allocate floor(batch_size * p) to each generator in order, and let the
372
+ last generator absorb any remainder to ensure the total matches exactly.
373
+ """
374
+ counts: Dict[str, int] = {}
375
+ remaining = batch_size
376
+ names = list(proportions.keys())
377
+ values = list(proportions.values())
378
+ for index, (name, p) in enumerate(zip(names, values)):
379
+ if index == len(names) - 1:
380
+ counts[name] = remaining
381
+ else:
382
+ n = int(batch_size * p)
383
+ counts[name] = n
384
+ remaining -= n
385
+ return counts
386
+
387
+ def _calculate_generator_samples(self, batch_size: int) -> Dict[str, int]:
388
+ """
389
+ Calculate the number of samples each generator should contribute.
390
+
391
+ Args:
392
+ batch_size: Total batch size
393
+
394
+ Returns:
395
+ Dict mapping generator names to sample counts
396
+ """
397
+ generator_samples = {}
398
+ remaining_samples = batch_size
399
+
400
+ generators = list(self.generator_proportions.keys())
401
+ proportions = list(self.generator_proportions.values())
402
+
403
+ # Calculate base samples for each generator
404
+ for i, (generator, proportion) in enumerate(zip(generators, proportions)):
405
+ if generator not in self.datasets:
406
+ continue
407
+
408
+ if i == len(generators) - 1: # Last generator gets remaining samples
409
+ samples = remaining_samples
410
+ else:
411
+ samples = int(batch_size * proportion)
412
+ remaining_samples -= samples
413
+ generator_samples[generator] = samples
414
+
415
+ return generator_samples
416
+
417
def create_batch(
    self,
    batch_size: int = 128,
    seed: Optional[int] = None,
    future_length: Optional[int] = None,
) -> Tuple[BatchTimeSeriesContainer, str]:
    """
    Create a batch of the specified size.

    When ``seed`` is given, a dedicated generator seeded with it drives every
    random choice the composer makes for this batch (mixed and uniform paths
    alike), and ``random`` is re-seeded too. The composer's long-lived RNG is
    left untouched and restored afterwards. Without a seed, the composer's
    own RNG is used and advances as usual.

    Args:
        batch_size: Size of the batch to create
        seed: Random seed for this batch
        future_length: Fixed future length to use. If None, samples from gift_eval range

    Returns:
        Tuple of (batch_container, generator_info)
    """
    if seed is None:
        batch_rng = self.rng
    else:
        batch_rng = np.random.default_rng(seed)
        random.seed(seed)

    # Helpers draw from self.rng, so point it at the per-batch generator for
    # the duration of the call; otherwise the seed would be ignored.
    composer_rng = self.rng
    self.rng = batch_rng
    try:
        if self.mixed_batches:
            return self._create_mixed_batch(batch_size, future_length)
        return self._create_uniform_batch(batch_size, batch_rng, future_length)
    finally:
        self.rng = composer_rng
444
+
445
def _create_mixed_batch(
    self, batch_size: int, future_length: Optional[int] = None
) -> Tuple[BatchTimeSeriesContainer, str]:
    """Create a mixed batch with samples from multiple generators, rejecting NaNs.

    Steps: pick one total length for the whole batch, derive the
    history/future split, work out how many series each generator contributes,
    collect NaN-free series (resizing longer ones), then split and optionally
    inject NaNs into the history part only.

    Args:
        batch_size: Number of series requested for the batch
        future_length: Fixed prediction length; sampled via
            ``sample_future_length`` when None

    Returns:
        Tuple of (batch_container, "MixedBatch")

    Raises:
        RuntimeError: If no valid sample could be collected from any generator

    NOTE(review): the batch can hold fewer than ``batch_size`` series when a
    generator does not yield enough NaN-free samples within 50 fetch rounds
    (only a warning is logged). Series shorter than the target length are
    not resized here.
    """

    # Choose total length for this batch; respect length_shortening flag.
    # When disabled, always use the maximum to avoid shortening.
    if self.augmentations.get("length_shortening", False):
        lengths = list(LENGTH_WEIGHTS.keys())
        probs = list(LENGTH_WEIGHTS.values())
        total_length_for_batch = int(self.rng.choice(lengths, p=probs))
    else:
        total_length_for_batch = int(max(LENGTH_CHOICES))

    if future_length is None:
        prediction_length = int(
            sample_future_length(
                range="gift_eval", total_length=total_length_for_batch
            )
        )
    else:
        prediction_length = future_length

    # Whatever is not predicted becomes history.
    history_length = total_length_for_batch - prediction_length

    # Calculate samples per generator using simple, per-batch length-aware proportions
    effective_props = self._effective_proportions_for_length(total_length_for_batch)
    generator_samples = self._compute_sample_counts_for_batch(
        effective_props, batch_size
    )

    all_values = []
    all_starts = []
    all_frequencies = []
    # Number of series actually collected per generator (bookkeeping only;
    # not read again within this method).
    actual_proportions = {}

    # Collect valid samples from each generator using batched fetches to reduce I/O overhead
    for generator_name, num_samples in generator_samples.items():
        if num_samples == 0 or generator_name not in self.datasets:
            continue

        dataset = self.datasets[generator_name]

        # Lists to hold valid samples for the current generator
        generator_values = []
        generator_starts = []
        generator_frequencies = []

        # Loop until we have collected the required number of VALID samples
        # (bounded so a generator yielding only NaN series cannot stall the batch)
        max_attempts = 50
        attempts = 0
        while len(generator_values) < num_samples and attempts < max_attempts:
            attempts += 1
            # Fetch a batch larger than needed to reduce round-trips
            need = num_samples - len(generator_values)
            fetch_n = max(need * 2, 8)
            samples = dataset.get_samples(fetch_n)

            for sample in samples:
                # Surplus samples from an over-sized fetch are discarded.
                if len(generator_values) >= num_samples:
                    break

                values, sample_start, sample_freq = self._convert_sample_to_tensors(
                    sample, future_length
                )

                # Skip if NaNs exist (we inject NaNs later in history only)
                if torch.isnan(values).any():
                    continue

                # Resize to target batch length when longer
                if total_length_for_batch < values.shape[1]:
                    strategy = self.rng.choice(["cut", "subsample"])  # 50/50
                    if strategy == "cut":
                        # Contiguous window at a random offset.
                        max_start_idx = values.shape[1] - total_length_for_batch
                        start_idx = int(self.rng.integers(0, max_start_idx + 1))
                        values = values[
                            :, start_idx : start_idx + total_length_for_batch, :
                        ]
                    else:
                        # Evenly spaced indices across the full series.
                        indices = np.linspace(
                            0,
                            values.shape[1] - 1,
                            total_length_for_batch,
                            dtype=int,
                        )
                        values = values[:, indices, :]

                # Optionally apply scaler augmentation according to configuration
                if self._should_apply_scaler_augmentation():
                    scaler = self._choose_random_scaler()
                    if scaler is not None:
                        values = scaler.scale(
                            values, scaler.compute_statistics(values)
                        )

                generator_values.append(values)
                generator_starts.append(sample_start)
                generator_frequencies.append(sample_freq)

        if len(generator_values) < num_samples:
            logger.warning(
                f"Generator {generator_name}: collected {len(generator_values)}/{num_samples} after {attempts} attempts"
            )

        # Add the collected valid samples to the main batch lists
        if generator_values:
            all_values.extend(generator_values)
            all_starts.extend(generator_starts)
            all_frequencies.extend(generator_frequencies)
            actual_proportions[generator_name] = len(generator_values)

    if not all_values:
        raise RuntimeError(
            "No valid samples could be collected from any generator."
        )

    # Concatenate the per-series tensors along the batch dimension.
    combined_values = torch.cat(all_values, dim=0)
    # Split into history and future
    combined_history = combined_values[:, :history_length, :]
    combined_future = combined_values[
        :, history_length : history_length + prediction_length, :
    ]

    # NaNs are injected into the history only; the future stays clean.
    if self.nan_augmenter is not None:
        combined_history = self.nan_augmenter.transform(combined_history)

    # Create container
    container = BatchTimeSeriesContainer(
        history_values=combined_history,
        future_values=combined_future,
        start=all_starts,
        frequency=all_frequencies,
    )

    return container, "MixedBatch"
581
+
582
def _create_uniform_batch(
    self,
    batch_size: int,
    batch_rng: np.random.Generator,
    future_length: Optional[int] = None,
) -> Tuple[BatchTimeSeriesContainer, str]:
    """
    Create a uniform batch with samples from a single generator.

    The generator is drawn from the loaded datasets, weighted by their
    configured proportions renormalized over the loaded set. Renormalizing
    matters when a configured generator failed to load: the remaining
    weights no longer sum to 1 and ``Generator.choice`` would reject them.

    Args:
        batch_size: Number of series to fetch from the selected generator
        batch_rng: Random generator used to select the generator
        future_length: Fixed prediction length; sampled via
            ``sample_future_length`` when None

    Returns:
        Tuple of (batch_container, name of the selected generator)

    Raises:
        ValueError: If the loaded datasets have no positive total weight
    """
    # Select generator based on proportions of the loaded datasets
    generators = list(self.datasets.keys())
    weights = np.asarray(
        [self.generator_proportions[gen] for gen in generators], dtype=float
    )
    weight_sum = float(weights.sum())
    if not weight_sum > 0:
        raise ValueError("Loaded generators have no positive total proportion")
    # Plain ``str`` so the returned name matches the annotated type.
    selected_generator = str(batch_rng.choice(generators, p=weights / weight_sum))

    # Sample future length
    if future_length is None:
        future_length = sample_future_length(range="gift_eval")

    # Get samples from selected generator
    dataset = self.datasets[selected_generator]
    samples = dataset.get_samples(batch_size)

    all_history_values = []
    all_future_values = []
    all_starts = []
    all_frequencies = []

    for sample in samples:
        values, sample_start, sample_freq = self._convert_sample_to_tensors(
            sample, future_length
        )

        # Keep at least one history step even for very short series.
        total_length = values.shape[1]
        history_length = max(1, total_length - future_length)

        # Optionally apply scaler augmentation according to configuration
        if self._should_apply_scaler_augmentation():
            scaler = self._choose_random_scaler()
            if scaler is not None:
                values = scaler.scale(values, scaler.compute_statistics(values))

        # Split along the time axis into history and future parts.
        hist_vals = values[:, :history_length, :]
        fut_vals = values[:, history_length : history_length + future_length, :]

        all_history_values.append(hist_vals)
        all_future_values.append(fut_vals)
        all_starts.append(sample_start)
        all_frequencies.append(sample_freq)

    # Combine samples along the batch dimension
    combined_history = torch.cat(all_history_values, dim=0)
    combined_future = torch.cat(all_future_values, dim=0)

    # Create container
    container = BatchTimeSeriesContainer(
        history_values=combined_history,
        future_values=combined_future,
        start=all_starts,
        frequency=all_frequencies,
    )

    return container, selected_generator
644
+
645
def get_dataset_info(self) -> Dict[str, dict]:
    """Return each loaded dataset's ``get_info()`` result, keyed by generator name."""
    return {
        generator_name: dataset.get_info()
        for generator_name, dataset in self.datasets.items()
    }
651
+
652
def get_generator_info(self) -> Dict[str, object]:
    """
    Get information about the composer configuration.

    Returns:
        Dict with the batch mode, the normalized generator proportions, the
        names and number of loaded generators, the augmentation toggles and
        probabilities, and whether a NaN augmenter is active.
    """
    return {
        "mixed_batches": self.mixed_batches,
        "generator_proportions": self.generator_proportions,
        "active_generators": list(self.datasets.keys()),
        "total_generators": len(self.datasets),
        "augmentations": self.augmentations,
        "augmentation_probabilities": self.augmentation_probabilities,
        "nan_augmenter_enabled": self.nan_augmenter is not None,
    }
663
+
664
+
665
class ComposedDataset(torch.utils.data.Dataset):
    """
    Map-style PyTorch Dataset whose items are whole batches built by a BatchComposer.
    """

    def __init__(
        self,
        batch_composer: BatchComposer,
        num_batches_per_epoch: int = 100,
        batch_size: int = 128,
    ):
        """
        Store the composer and the epoch layout.

        Args:
            batch_composer: The BatchComposer instance that creates the batches
            num_batches_per_epoch: Value reported by ``len()``
            batch_size: Size requested for each composed batch
        """
        self.batch_composer = batch_composer
        self.num_batches_per_epoch = num_batches_per_epoch
        self.batch_size = batch_size

    def __len__(self) -> int:
        return self.num_batches_per_epoch

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Compose and return the batch for ``idx``.

        Args:
            idx: Batch index; added to the composer's global seed to form the
                per-batch seed

        Returns:
            BatchTimeSeriesContainer
        """
        composer = self.batch_composer
        batch_seed = composer.global_seed + idx
        container, _generator_info = composer.create_batch(
            batch_size=self.batch_size, seed=batch_seed
        )
        return container
src/data/constants.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import date
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+
6
# Calendar bounds, given both as `datetime.date` objects and as the
# equivalent `np.datetime64` values.
DEFAULT_START_DATE = date(1700, 1, 1)
DEFAULT_END_DATE = date(2200, 1, 1)
BASE_START_DATE = np.datetime64(DEFAULT_START_DATE)
BASE_END_DATE = np.datetime64(DEFAULT_END_DATE)

# Maximum years to prevent timestamp overflow
MAX_YEARS = 500

# Supported total series lengths (history + future). The largest entry is the
# batch length used when length shortening is disabled.
LENGTH_CHOICES = [128, 256, 512, 1024, 1536, 2048]

# Fallback location of the NaN statistics JSON, relative to the working directory.
DEFAULT_NAN_STATS_PATH: str = "./data/nan_stats.json"

# Sampling probability for each total length when length shortening is
# enabled; keys mirror LENGTH_CHOICES and the values sum to 1.
LENGTH_WEIGHTS: Dict[int, float] = {
    128: 0.05,
    256: 0.10,
    512: 0.10,
    1024: 0.10,
    1536: 0.15,
    2048: 0.50,
}
src/data/containers.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from src.data.frequency import Frequency
8
+
9
+
10
@dataclass
class BatchTimeSeriesContainer:
    """
    Batch of multivariate time series split into a history and a future part.

    Attributes:
        history_values: Tensor of historical observations.
            Shape: [batch_size, seq_len, num_channels]
        future_values: Tensor of future observations to predict.
            Shape: [batch_size, pred_len, num_channels]
        start: Timestamp of the first history value.
            Type: List[np.datetime64]
        frequency: Frequency of the time series.
            Type: List[Frequency]
        history_mask: Optional boolean/float tensor indicating missing entries in history_values across channels.
            Shape: [batch_size, seq_len]
        future_mask: Optional boolean/float tensor indicating missing entries in future_values across channels.
            Shape: [batch_size, pred_len]
    """

    history_values: torch.Tensor
    future_values: torch.Tensor
    start: List[np.datetime64]
    frequency: List[Frequency]

    history_mask: Optional[torch.Tensor] = None
    future_mask: Optional[torch.Tensor] = None

    def __post_init__(self):
        """Check field types and that the tensor shapes agree with each other."""
        # --- Type checks ---
        if not isinstance(self.history_values, torch.Tensor):
            raise TypeError("history_values must be a torch.Tensor")
        if not isinstance(self.future_values, torch.Tensor):
            raise TypeError("future_values must be a torch.Tensor")

        starts_ok = isinstance(self.start, list) and all(
            isinstance(item, np.datetime64) for item in self.start
        )
        if not starts_ok:
            raise TypeError("start must be a List[np.datetime64]")

        frequencies_ok = isinstance(self.frequency, list) and all(
            isinstance(item, Frequency) for item in self.frequency
        )
        if not frequencies_ok:
            raise TypeError("frequency must be a List[Frequency]")

        # Unpacking requires history_values to be exactly three-dimensional.
        batch_size, seq_len, num_channels = self.history_values.shape
        future_shape = self.future_values.shape
        pred_len = future_shape[1]

        # --- History/future agreement ---
        if future_shape[0] != batch_size:
            raise ValueError("Batch size mismatch between history and future_values")
        if future_shape[2] != num_channels:
            raise ValueError("Channel size mismatch between history and future_values")

        # --- Optional masks ---
        history_mask = self.history_mask
        if history_mask is not None:
            if not isinstance(history_mask, torch.Tensor):
                raise TypeError("history_mask must be a Tensor or None")
            if history_mask.shape[:2] != (batch_size, seq_len):
                raise ValueError(
                    f"Shape mismatch in history_mask: {self.history_mask.shape[:2]} vs {(batch_size, seq_len)}"
                )

        future_mask = self.future_mask
        if future_mask is not None:
            if not isinstance(future_mask, torch.Tensor):
                raise TypeError("future_mask must be a Tensor or None")
            # Either one flag per time step or one flag per value is accepted.
            per_step = future_mask.shape == (batch_size, pred_len)
            per_value = future_mask.shape == self.future_values.shape
            if not per_step and not per_value:
                raise ValueError(
                    f"Shape mismatch in future_mask: expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
                )

    def to_device(
        self, device: torch.device, attributes: Optional[List[str]] = None
    ) -> None:
        """
        Move tensors to the target device, replacing the attributes in place.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move. If None, move
                every tensor attribute that is set.

        Raises:
            ValueError: If an invalid attribute name is specified.
        """
        tensor_fields = ("history_values", "future_values", "history_mask", "future_mask")
        # Snapshot taken up front, before any attribute is reassigned.
        snapshot = {name: getattr(self, name) for name in tensor_fields}

        if attributes is None:
            attributes = [
                name for name, tensor in snapshot.items() if tensor is not None
            ]

        for attr in attributes:
            if attr not in snapshot:
                raise ValueError(f"Invalid attribute: {attr}")
            tensor = snapshot[attr]
            if tensor is None:
                continue
            setattr(self, attr, tensor.to(device))

    def to(self, device: torch.device, attributes: Optional[List[str]] = None):
        """
        PyTorch-style alias for ``to_device`` that returns the container itself.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move. If None, move all tensors.
        """
        self.to_device(device, attributes)
        return self

    @property
    def batch_size(self) -> int:
        return self.history_values.shape[0]

    @property
    def history_length(self) -> int:
        return self.history_values.shape[1]

    @property
    def future_length(self) -> int:
        return self.future_values.shape[1]

    @property
    def num_channels(self) -> int:
        return self.history_values.shape[2]
138
+
139
+
140
@dataclass
class TimeSeriesContainer:
    """
    Container for batch of time series data without explicit history/future split.

    This container is used for storing generated synthetic time series data where
    the entire series is treated as a single entity, typically for further processing
    or splitting into history/future components later.

    Attributes:
        values: np.ndarray of time series values.
            Shape: [batch_size, seq_len, num_channels] for multivariate series
                   [batch_size, seq_len] for univariate series
        start: List of start timestamps for each series in the batch.
            Type: List[np.datetime64], length should match batch_size
        frequency: List of frequency for each series in the batch.
            Type: List[Frequency], length should match batch_size
    """

    # This file previously defined this class twice, verbatim; this is the
    # single remaining definition.

    values: np.ndarray
    start: List[np.datetime64]
    frequency: List[Frequency]

    def __post_init__(self):
        """
        Validate field types, array rank and per-series list lengths.

        Raises:
            TypeError: If a field has the wrong type.
            ValueError: If ``values`` is not 2-D or 3-D, or if ``start`` or
                ``frequency`` does not hold one entry per series.
        """
        # --- Numpy Type Checks ---
        if not isinstance(self.values, np.ndarray):
            raise TypeError("values must be a np.ndarray")
        if not isinstance(self.start, list) or not all(
            isinstance(x, np.datetime64) for x in self.start
        ):
            raise TypeError("start must be a List[np.datetime64]")
        if not isinstance(self.frequency, list) or not all(
            isinstance(x, Frequency) for x in self.frequency
        ):
            raise TypeError("frequency must be a List[Frequency]")

        # --- Shape and Length Consistency Checks ---
        if len(self.values.shape) < 2 or len(self.values.shape) > 3:
            raise ValueError(
                f"values must have 2 or 3 dimensions [batch_size, seq_len] or [batch_size, seq_len, num_channels], got shape {self.values.shape}"
            )

        batch_size = self.values.shape[0]

        if len(self.start) != batch_size:
            raise ValueError(
                f"Length of start ({len(self.start)}) must match batch_size ({batch_size})"
            )
        if len(self.frequency) != batch_size:
            raise ValueError(
                f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})"
            )

    @property
    def batch_size(self) -> int:
        return self.values.shape[0]

    @property
    def seq_length(self) -> int:
        return self.values.shape[1]

    @property
    def num_channels(self) -> int:
        # Univariate (2-D) arrays are reported as a single channel.
        return self.values.shape[2] if len(self.values.shape) == 3 else 1
src/data/datasets.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ from typing import List, Optional
5
+
6
+ import pyarrow.feather as feather
7
+ import torch
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class CyclicalBatchDataset:
    """
    Dataset class that loads saved batches from continuous generation script.
    Maintains a pointer and provides cyclical access to individual samples.
    Includes enhanced logging to track data shard cycling during training.
    Supports per-rank file sharding for large-scale distributed training.
    """

    # Columns that every batch file must contain. "num_channels" is optional
    # and defaults to 1 for files that do not store it.
    _REQUIRED_COLUMNS = (
        "series_id",
        "values",
        "length",
        "generator_type",
        "start",
        "frequency",
        "generation_timestamp",
    )

    def __init__(
        self,
        batches_dir: str,
        generator_type: str,
        device: Optional[torch.device] = None,
        prefetch_next: bool = True,
        prefetch_threshold: int = 32,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the cyclical batch dataset.

        Args:
            batches_dir: Directory containing the batch arrow files
            generator_type: Type of generator (for logging)
            device: Device to load tensors to
            prefetch_next: Whether to prefetch the next batch
            prefetch_threshold: Number of remaining samples to trigger prefetching
            rank: Rank of the current process (for file sharding)
            world_size: Total number of processes (for file sharding)

        Raises:
            ValueError: If no batch files are available to this rank.
        """
        self.batches_dir = batches_dir
        self.generator_type = generator_type
        self.device = device
        self.prefetch_next = prefetch_next
        self.prefetch_threshold = prefetch_threshold
        self.rank = rank
        self.world_size = world_size

        self.batch_files = self._find_batch_files()
        if not self.batch_files:
            raise ValueError(f"No batch files found in {batches_dir}")

        # --- State tracking ---
        self.current_batch_idx = 0
        self.current_sample_idx = 0
        self.current_batch_data = None
        self.next_batch_data = None
        self.prefetching_in_progress = False

        # --- Logging and cycle tracking ---
        self.visited_batch_indices = set()
        self.full_cycles_completed = 0

        # Load first batch and update tracking
        self._load_current_batch()
        self.visited_batch_indices.add(self.current_batch_idx)

        logger.info(
            f"Initialized '{self.generator_type}' dataset with {len(self.batch_files)} batches. "
            f"Current batch file: '{os.path.basename(self.batch_files[self.current_batch_idx])}' "
            f"has {len(self.current_batch_data)} samples."
        )

    def _find_batch_files(self) -> List[str]:
        """
        Find and sort batch files with per-rank sharding for distributed training.

        Each rank gets a disjoint subset of files to minimize I/O contention
        when scaling to hundreds of GPUs.

        Returns:
            The (shuffled) list of file paths assigned to this rank; empty if
            the directory has no matching files or the rank receives none.
        """
        import glob

        pattern = os.path.join(self.batches_dir, "batch_*.arrow")
        all_files = sorted(glob.glob(pattern))  # Sort for deterministic sharding

        if not all_files:
            return []

        # Shard files across ranks: each rank gets every world_size-th file
        # Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
        rank_files = [
            f for i, f in enumerate(all_files) if i % self.world_size == self.rank
        ]

        # Shuffle only within this rank's shard for variety
        random.shuffle(rank_files)

        logger.info(
            f"[Rank {self.rank}] '{self.generator_type}': Sharded {len(all_files)} files → "
            f"{len(rank_files)} files for this rank ({len(rank_files) / len(all_files) * 100:.1f}%)"
        )

        return rank_files

    def _load_batch_from_file(self, batch_file: str) -> List[dict]:
        """
        Load a batch from arrow file.

        Args:
            batch_file: Path of the feather/arrow file to read.

        Returns:
            One dict per row with the required columns plus "num_channels".

        Raises:
            Exception: Any read/decoding error is logged and re-raised.
        """
        try:
            table = feather.read_table(batch_file)
            num_rows = len(table)
            if num_rows == 0:
                return []

            # Convert each column once instead of materialising one Arrow
            # scalar per cell; the resulting Python values are the same.
            columns = {
                name: table[name].to_pylist() for name in self._REQUIRED_COLUMNS
            }
            if "num_channels" in table.column_names:
                columns["num_channels"] = table["num_channels"].to_pylist()
            else:
                columns["num_channels"] = [1] * num_rows

            names = list(columns)
            return [dict(zip(names, row)) for row in zip(*columns.values())]
        except Exception as e:
            logger.error(f"Error loading batch from {batch_file}: {e}")
            raise

    def _load_current_batch(self):
        """Load the file at ``current_batch_idx`` and rewind the sample pointer."""
        # Release the previous batch before reading the next one so that two
        # batches are not held in memory at the same time.
        self.current_batch_data = None
        batch_file = self.batch_files[self.current_batch_idx]
        self.current_batch_data = self._load_batch_from_file(batch_file)
        self.current_sample_idx = 0
        logger.debug(
            f"Loaded batch {self.current_batch_idx} for {self.generator_type} "
            f"with {len(self.current_batch_data)} samples"
        )

    def _trigger_smart_prefetch(self):
        """Trigger prefetching when batch is almost exhausted."""
        if not self.prefetch_next or len(self.batch_files) <= 1:
            return
        remaining_samples = self.get_remaining_samples_in_current_batch()
        should_prefetch = (
            remaining_samples <= self.prefetch_threshold
            and self.next_batch_data is None
            and not self.prefetching_in_progress
        )
        if should_prefetch:
            self._prefetch_next_batch()

    def _prefetch_next_batch(self):
        """Prefetch the next batch; a failed read is logged and ignored."""
        if self.prefetching_in_progress:
            return
        self.prefetching_in_progress = True
        next_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)
        next_batch_file = self.batch_files[next_batch_idx]
        try:
            self.next_batch_data = self._load_batch_from_file(next_batch_file)
            logger.debug(
                f"Prefetched next batch {next_batch_idx} for {self.generator_type}"
            )
        except Exception as e:
            # Best effort: the batch is loaded again when it is actually needed.
            logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
            self.next_batch_data = None
        finally:
            self.prefetching_in_progress = False

    def _advance_to_next_batch(self):
        """Advance to the next batch and log the transition."""
        # Release the exhausted batch before switching.
        self.current_batch_data = None

        previous_batch_idx = self.current_batch_idx
        self.current_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)

        if getattr(self, "next_batch_data", None) is not None:
            # Reuse the prefetched batch instead of reading the file again.
            self.current_batch_data = self.next_batch_data
            self.next_batch_data = None
        else:
            self._load_current_batch()

        self.current_sample_idx = 0
        self.prefetching_in_progress = False

        # --- Enhanced Logging Logic ---
        self.visited_batch_indices.add(self.current_batch_idx)

        # Calculate progress
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        progress_percent = (visited_count / total_files) * 100

        # Log the shard cycle event
        logger.info(
            f"\nDATA SHARD CYCLED for '{self.generator_type}': "
            f"Moved from file index {previous_batch_idx} to {self.current_batch_idx}. "
            f"Unique files visited: {visited_count}/{total_files} ({progress_percent:.1f}%)."
        )

        # Check if a full cycle has been completed
        if visited_count == total_files:
            self.full_cycles_completed += 1
            logger.info(
                f"🎉 FULL CYCLE #{self.full_cycles_completed} COMPLETED for '{self.generator_type}'! "
                f"All {total_files} data files have been visited at least once. "
                "Resetting visited set to track the next cycle."
            )
            # Reset for the next cycle count
            self.visited_batch_indices.clear()
            self.visited_batch_indices.add(self.current_batch_idx)

    def get_sample(self) -> dict:
        """
        Get the current sample and advance pointer.

        Returns:
            The next sample dict, wrapping around to the first file after the
            last one.

        Raises:
            RuntimeError: If no batch could be loaded, or if every batch file
                assigned to this dataset is empty.
        """
        if getattr(self, "current_batch_data", None) is None:
            self._load_current_batch()
        if self.current_batch_data is None:
            raise RuntimeError("No batch data loaded")

        # Skip exhausted batches. A loop (instead of a single advance) also
        # steps over empty files; it is bounded by the number of files so a
        # shard consisting only of empty files fails loudly instead of
        # spinning forever.
        advances = 0
        while self.current_sample_idx >= len(self.current_batch_data):
            if advances >= len(self.batch_files):
                raise RuntimeError(
                    f"All {len(self.batch_files)} batch files for "
                    f"'{self.generator_type}' are empty"
                )
            self._advance_to_next_batch()
            advances += 1

        self._trigger_smart_prefetch()
        sample = self.current_batch_data[self.current_sample_idx]
        self.current_sample_idx += 1
        return sample

    def get_samples(self, num_samples: int) -> List[dict]:
        """Get ``num_samples`` consecutive samples."""
        return [self.get_sample() for _ in range(num_samples)]

    def get_total_samples_in_current_batch(self) -> int:
        """Get total samples in current batch (0 if none is loaded)."""
        batch = getattr(self, "current_batch_data", None)
        return 0 if batch is None else len(batch)

    def get_remaining_samples_in_current_batch(self) -> int:
        """Get remaining samples in current batch (0 if none is loaded)."""
        batch = getattr(self, "current_batch_data", None)
        if batch is None:
            return 0
        return max(0, len(batch) - self.current_sample_idx)

    def get_info(self) -> dict:
        """Get extended dataset info, including cycle progress."""
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        return {
            "generator_type": self.generator_type,
            "total_batch_files": total_files,
            "current_batch_idx": self.current_batch_idx,
            "current_sample_idx": self.current_sample_idx,
            "current_batch_size": self.get_total_samples_in_current_batch(),
            "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
            "unique_files_visited": visited_count,
            "cycle_progress_percent": (visited_count / total_files) * 100
            if total_files > 0
            else 0,
            "full_cycles_completed": self.full_cycles_completed,
        }
src/data/filter.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from scipy import signal
4
+ from statsmodels.tsa.stattools import acf
5
+
6
+
7
def lempel_ziv_complexity(binary_sequence: np.ndarray) -> int:
    """
    Count the distinct phrases found by a greedy Lempel-Ziv style parse.

    Starting at the current position, symbols are appended (as their ``str``
    form) until the phrase has not been seen before; that phrase is recorded
    and parsing resumes right after it. If the remaining tail never yields a
    new phrase, the start position moves forward by one symbol instead.

    Args:
        binary_sequence: Sequence of symbols (typically 0/1 values).

    Returns:
        Number of distinct phrases recorded.
    """
    seen_phrases = set()
    total = len(binary_sequence)
    phrase_count = 0
    start = 0
    while start < total:
        phrase = ""
        cursor = start
        found_new = False
        while cursor < total:
            phrase += str(binary_sequence[cursor])
            cursor += 1
            if phrase not in seen_phrases:
                seen_phrases.add(phrase)
                phrase_count += 1
                found_new = True
                break
        # A new phrase consumes its symbols; otherwise slide forward by one.
        start = cursor if found_new else start + 1
    return phrase_count
25
+
26
+
27
def is_low_quality(
    series: torch.Tensor,
    autocorr_threshold: float = 0.2,
    snr_threshold: float = 0.5,
    complexity_threshold: float = 0.4,
) -> bool:
    """
    Decide whether a series looks noise-like and therefore non-forecastable.

    A series is flagged when it is shorter than 20 points, has (numerically)
    zero variance, or when all of the following hold on the detrended data:
    weak autocorrelation, a low signal-to-noise proxy, and a high normalized
    Lempel-Ziv complexity.

    Args:
        series: Tensor holding the series; it is squeezed and moved to CPU.
        autocorr_threshold: Maximum |ACF| below which the series counts as
            uncorrelated.
        snr_threshold: SNR proxy below which the series counts as noisy.
        complexity_threshold: Normalized complexity above which the series
            counts as random-like.

    Returns:
        True if the series should be treated as low quality.
    """
    values = series.squeeze().detach().cpu().numpy()

    # Degenerate inputs: too short or constant.
    if values.size < 20 or np.var(values) < 1e-10:
        return True

    detrended = signal.detrend(values)

    # --- Autocorrelation strength: largest |ACF| over the non-zero lags ---
    autocorr_strength = 0.0
    try:
        lag_count = min(len(detrended) // 4, 40)
        if lag_count >= 1:
            lagged_acf = acf(detrended, nlags=lag_count, fft=True)[1:]
            autocorr_strength = float(np.max(np.abs(lagged_acf)))
    except Exception:
        # Best effort: a failed ACF estimate counts as "no autocorrelation".
        autocorr_strength = 0.0

    # --- SNR proxy: variance of a moving average vs. variance of the rest ---
    window = max(3, min(len(values) // 10, 15))
    kernel = np.ones(window) / window
    smoothed = np.convolve(values, kernel, mode="valid")
    residual = values[window - 1 :] - smoothed
    signal_var = float(np.var(smoothed))
    noise_var = float(np.var(residual))
    if noise_var > 1e-8:
        snr_proxy = signal_var / noise_var
    else:
        snr_proxy = 1.0

    # --- Complexity of the above/below-median pattern ---
    threshold = float(np.median(detrended))
    above_median = (detrended > threshold).astype(np.uint8)
    phrase_count = lempel_ziv_complexity(above_median)
    normalized_complexity = phrase_count / max(1, len(above_median))

    if not (autocorr_strength < autocorr_threshold):
        return False
    return bool(
        (snr_proxy < snr_threshold)
        and (normalized_complexity > complexity_threshold)
    )
src/data/frequency.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Comprehensive frequency management module for time series forecasting.
3
+
4
+ This module centralizes all frequency-related functionality including:
5
+ - Frequency enum with helper methods
6
+ - Frequency parsing and validation
7
+ - Pandas frequency string conversion
8
+ - Safety checks for date ranges
9
+ - Frequency selection utilities
10
+ - All frequency constants and mappings
11
+ """
12
+
13
import logging
import re
from enum import Enum
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
from numpy.random import Generator

from src.data.constants import BASE_END_DATE, BASE_START_DATE, MAX_YEARS
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class Frequency(Enum):
    """
    Supported time series frequencies.

    Member values are short frequency codes; the methods translate a member
    into pandas frequency strings and expose per-frequency metadata stored in
    the module-level lookup tables.
    """

    A = "A"  # Annual
    Q = "Q"  # Quarterly
    M = "M"  # Monthly
    W = "W"  # Weekly
    D = "D"  # Daily
    H = "h"  # Hourly
    S = "s"  # Seconds
    T1 = "1min"  # 1 minute
    T5 = "5min"  # 5 minutes
    T10 = "10min"  # 10 minutes
    T15 = "15min"  # 15 minutes
    T30 = "30min"  # 30 minutes

    def to_pandas_freq(self, for_date_range: bool = True) -> str:
        """
        Convert to pandas frequency string.

        Args:
            for_date_range: If True, use strings suitable for pd.date_range().
                If False, use strings suitable for pd.PeriodIndex().

        Returns:
            Pandas frequency string
        """
        pandas_base, multiplier, _ = FREQUENCY_MAPPING[self]

        # Annual/quarterly/monthly are spelled differently for timestamps
        # ("…E" period-end aliases) and for periods (legacy aliases).
        if for_date_range:
            calendar_aliases = {
                Frequency.M: "ME",  # Month End
                Frequency.A: "YE",  # Year End
                Frequency.Q: "QE",  # Quarter End
            }
        else:
            calendar_aliases = {
                Frequency.M: "M",  # Month for periods
                Frequency.A: "Y",  # Year for periods (not YE)
                Frequency.Q: "Q",  # Quarter for periods (not QE)
            }
        if self in calendar_aliases:
            return calendar_aliases[self]

        # Everything else is "<multiplier><base>", e.g. "15min" or "h".
        return f"{multiplier}{pandas_base}" if multiplier else pandas_base

    def to_pandas_offset(self) -> str:
        """Get pandas offset string for time delta calculations."""
        return FREQUENCY_TO_OFFSET[self]

    def get_days_per_period(self) -> float:
        """Get approximate days per period for this frequency."""
        return FREQUENCY_MAPPING[self][2]

    def get_max_safe_length(self) -> int:
        """Get maximum safe sequence length to prevent timestamp overflow."""
        unbounded = float("inf")
        return ALL_FREQUENCY_MAX_LENGTHS.get(self, unbounded)

    def is_high_frequency(self) -> bool:
        """Check if this is a high frequency (minute/second level)."""
        sub_hourly = (
            Frequency.S,
            Frequency.T1,
            Frequency.T5,
            Frequency.T10,
            Frequency.T15,
            Frequency.T30,
        )
        return self in sub_hourly

    def is_low_frequency(self) -> bool:
        """Check if this is a low frequency (annual/quarterly/monthly)."""
        return self in (Frequency.A, Frequency.Q, Frequency.M)

    def get_seasonality(self) -> int:
        """Get typical seasonality for this frequency."""
        periods_per_season = {
            "S": 3600,  # 1 hour of seconds
            "T1": 60,  # 1 hour of minutes
            "T5": 12,  # 1 hour of 5-minute intervals
            "T10": 6,  # 1 hour of 10-minute intervals
            "T15": 4,  # 1 hour of 15-minute intervals
            "T30": 2,  # 1 hour of 30-minute intervals
            "H": 24,  # 1 day of hours
            "D": 7,  # 1 week of days
            "W": 52,  # 1 year of weeks
            "M": 12,  # 1 year of months
            "Q": 4,  # 1 year of quarters
            "A": 1,  # No clear seasonality for annual
        }
        return periods_per_season.get(self.name, 1)

    def get_gift_eval_weight(self) -> float:
        """Get GIFT eval dataset frequency weight."""
        default_weight = 0.1
        return GIFT_EVAL_FREQUENCY_WEIGHTS.get(self, default_weight)

    def get_length_range(self) -> Tuple[int, int, int, int]:
        """Get (min_length, max_length, optimal_start, optimal_end) for this frequency."""
        default_range = (50, 1000, 100, 500)
        return GIFT_EVAL_LENGTH_RANGES.get(self, default_range)
+
139
+
140
+ # ============================================================================
141
+ # Frequency Mappings and Constants
142
+ # ============================================================================
143
+
144
# Core frequency mapping: (pandas_base, prefix, days_per_period)
# ``prefix`` is the multiplier placed in front of ``pandas_base``
# (e.g. "15" + "min"); ``days_per_period`` is an approximate duration.
FREQUENCY_MAPPING: Dict[Frequency, Tuple[str, str, float]] = {
    Frequency.A: (
        "YE",
        "",
        365.25,
    ),  # Average days per year (accounting for leap years)
    Frequency.Q: ("Q", "", 91.3125),  # 365.25/4 - average days per quarter
    Frequency.M: ("M", "", 30.4375),  # 365.25/12 - average days per month
    Frequency.W: ("W", "", 7),
    Frequency.D: ("D", "", 1),
    Frequency.H: ("h", "", 1 / 24),
    Frequency.S: ("s", "", 1 / 86400),  # 24*60*60
    Frequency.T1: ("min", "1", 1 / 1440),  # 24*60
    Frequency.T5: ("min", "5", 1 / 288),  # 24*60/5
    Frequency.T10: ("min", "10", 1 / 144),  # 24*60/10
    Frequency.T15: ("min", "15", 1 / 96),  # 24*60/15
    Frequency.T30: ("min", "30", 1 / 48),  # 24*60/30
}

# Frequency to pandas offset mapping for calculating time deltas.
# Uses the current pandas aliases ("YS", "h", "min", "s"); the older
# spellings ("AS", "H", "T", "S") are deprecated since pandas 2.2, the same
# release that introduced the "ME"/"YE"/"QE" aliases used elsewhere here.
FREQUENCY_TO_OFFSET: Dict[Frequency, str] = {
    Frequency.A: "YS",  # Annual start
    Frequency.Q: "QS",  # Quarter start
    Frequency.M: "MS",  # Month start
    Frequency.W: "W",  # Weekly
    Frequency.D: "D",  # Daily
    Frequency.H: "h",  # Hourly
    Frequency.T1: "1min",  # 1 minute
    Frequency.T5: "5min",  # 5 minutes
    Frequency.T10: "10min",  # 10 minutes
    Frequency.T15: "15min",  # 15 minutes
    Frequency.T30: "30min",  # 30 minutes
    Frequency.S: "s",  # Seconds
}

# Maximum sequence lengths to avoid pandas OutOfBoundsDatetime errors:
# the number of periods that fit into MAX_YEARS for each frequency.
SHORT_FREQUENCY_MAX_LENGTHS = {
    Frequency.A: MAX_YEARS,
    Frequency.Q: MAX_YEARS * 4,
    Frequency.M: MAX_YEARS * 12,
    Frequency.W: int(MAX_YEARS * 52.1775),
    Frequency.D: int(MAX_YEARS * 365.2425),
}

HIGH_FREQUENCY_MAX_LENGTHS = {
    Frequency.H: int(MAX_YEARS * 365.2425 * 24),
    Frequency.S: int(MAX_YEARS * 365.2425 * 24 * 60 * 60),
    Frequency.T1: int(MAX_YEARS * 365.2425 * 24 * 60),
    Frequency.T5: int(MAX_YEARS * 365.2425 * 24 * 12),
    Frequency.T10: int(MAX_YEARS * 365.2425 * 24 * 6),
    Frequency.T15: int(MAX_YEARS * 365.2425 * 24 * 4),
    Frequency.T30: int(MAX_YEARS * 365.2425 * 24 * 2),
}

# Combined max lengths for all frequencies
ALL_FREQUENCY_MAX_LENGTHS = {
    **SHORT_FREQUENCY_MAX_LENGTHS,
    **HIGH_FREQUENCY_MAX_LENGTHS,
}

# GIFT eval-based frequency weights from actual dataset analysis.
# Frequencies without an entry (currently T30) use the default weight
# applied by Frequency.get_gift_eval_weight().
GIFT_EVAL_FREQUENCY_WEIGHTS: Dict[Frequency, float] = {
    Frequency.H: 25.0,  # Hourly - most common
    Frequency.D: 23.4,  # Daily - second most common
    Frequency.W: 12.9,  # Weekly - third most common
    Frequency.T15: 9.7,  # 15-minute
    Frequency.T5: 9.7,  # 5-minute
    Frequency.M: 7.3,  # Monthly
    Frequency.T10: 4.8,  # 10-minute
    Frequency.S: 4.8,  # 10-second
    Frequency.T1: 1.6,  # 1-minute
    Frequency.Q: 0.8,  # Quarterly
    Frequency.A: 0.8,  # Annual
}

# GIFT eval-based length ranges derived from actual dataset analysis
# Format: (min_length, max_length, optimal_start, optimal_end)
GIFT_EVAL_LENGTH_RANGES: Dict[Frequency, Tuple[int, int, int, int]] = {
    # Low frequency ranges (based on actual GIFT eval data + logical extensions)
    Frequency.A: (25, 100, 30, 70),
    Frequency.Q: (25, 150, 50, 120),
    Frequency.M: (40, 1000, 100, 600),
    Frequency.W: (50, 3500, 100, 1500),
    # Medium frequency ranges
    Frequency.D: (150, 25000, 300, 7000),  # Daily: covers 1-year+ scenarios
    Frequency.H: (600, 35000, 700, 17000),
    # High frequency ranges (extended for shorter realistic scenarios)
    Frequency.T1: (200, 2500, 1200, 1800),  # 1-minute: day to few days
    Frequency.S: (7500, 9500, 7900, 9000),
    Frequency.T15: (1000, 140000, 50000, 130000),
    Frequency.T5: (200, 105000, 20000, 95000),
    Frequency.T10: (40000, 55000, 47000, 52000),
    Frequency.T30: (100, 50000, 10000, 40000),
}
239
+
240
+
241
+ # ============================================================================
242
+ # Frequency Parsing and Validation
243
+ # ============================================================================
244
+
245
+
246
def parse_frequency(freq_str: str) -> Frequency:
    """
    Parse frequency string to Frequency enum, robust to variations.

    Handles various frequency string formats:
    - Standard: "A", "Q", "M", "W", "D", "H", "S"
    - Pandas-style: "A-DEC", "W-SUN", "QE-MAR"
    - Period-start aliases: "MS", "QS", "YS", "AS" (same granularity as
      their period-end counterparts)
    - Minutes: "5T", "10min", "1T"
    - Case variations: "a", "h", "D"

    Args:
        freq_str: The frequency string to parse (e.g., "5T", "W-SUN", "M")

    Returns:
        Corresponding Frequency enum member. Minute frequencies with an
        unsupported multiplier fall back to ``Frequency.T1`` (with a warning).

    Raises:
        NotImplementedError: If the frequency string is not supported
    """
    # Handle minute-based frequencies BEFORE pandas standardization
    # because pandas converts "5T" to just "min", losing the multiplier
    minute_match = re.match(r"^(\d*)(?:T|min)$", freq_str, re.IGNORECASE)
    if minute_match:
        digits = minute_match.group(1)
        multiplier = int(digits) if digits else 1
        enum_key = f"T{multiplier}"
        try:
            return Frequency[enum_key]
        except KeyError:
            logger.warning(
                f"Unsupported minute frequency '{freq_str}' (multiplier: {multiplier}). "
                f"Falling back to '1min' ({Frequency.T1.value})."
            )
            return Frequency.T1

    # Now standardize frequency string for other cases
    try:
        offset = pd.tseries.frequencies.to_offset(freq_str)
        standardized_freq = offset.name
    except Exception:
        # pandas does not know this alias (or rejects a removed one);
        # fall back to matching the raw string against the table below.
        standardized_freq = freq_str

    # Handle other frequencies by their base (e.g., 'W-SUN' -> 'W', 'A-DEC' -> 'A')
    base_freq = standardized_freq.split("-")[0].upper()

    freq_map = {
        "A": Frequency.A,
        "Y": Frequency.A,  # Alias for Annual
        "YE": Frequency.A,  # Alias for Annual (year end)
        "AS": Frequency.A,  # Annual, anchored at period start
        "YS": Frequency.A,  # Annual, anchored at period start
        "Q": Frequency.Q,
        "QE": Frequency.Q,  # Alias for Quarterly (quarter end)
        "QS": Frequency.Q,  # Quarterly, anchored at period start
        "M": Frequency.M,
        "ME": Frequency.M,  # Alias for Monthly (month end)
        "MS": Frequency.M,  # Monthly, anchored at period start
        "W": Frequency.W,
        "D": Frequency.D,
        "H": Frequency.H,
        "S": Frequency.S,
    }

    if base_freq in freq_map:
        return freq_map[base_freq]

    raise NotImplementedError(f"Frequency '{standardized_freq}' is not supported.")
310
+
311
+
312
def validate_frequency_safety(
    start_date: np.datetime64, total_length: int, frequency: Frequency
) -> bool:
    """
    Report whether a series of ``total_length`` periods starting at
    ``start_date`` can be materialised with pandas at this frequency.

    The checks run from cheap to expensive: start within the pandas timestamp
    bounds (1677-09-21 to 2262-04-11), length within the per-frequency
    maximum, extra caps for annual/quarterly/monthly data, an approximate end
    date within bounds, and finally an actual ``pd.date_range`` call.

    Args:
        start_date: The proposed start date for the time series
        total_length: Total length of the time series
        frequency: The frequency of the time series

    Returns:
        True if the combination is safe, False otherwise
    """
    # Extra-conservative caps for calendar frequencies (~500 years each).
    low_frequency_caps = {
        Frequency.A: 500,
        Frequency.Q: 2000,
        Frequency.M: 6000,
    }

    try:
        pandas_freq = frequency.to_pandas_freq(for_date_range=True)

        # Convert numpy datetime64 to pandas Timestamp for date_range
        first_ts = pd.Timestamp(start_date)
        if first_ts < pd.Timestamp.min or first_ts > pd.Timestamp.max:
            return False

        if total_length > frequency.get_max_safe_length():
            return False

        if frequency.is_low_frequency():
            cap = low_frequency_caps.get(frequency)
            if cap is not None and total_length > cap:
                return False

        # Approximate the end date from the average period duration.
        span_days = total_length * frequency.get_days_per_period()
        if frequency in (Frequency.A, Frequency.Q):
            span_days *= 1.1  # 10% safety margin

        last_ts = first_ts + pd.Timedelta(days=span_days)
        if last_ts < pd.Timestamp.min or last_ts > pd.Timestamp.max:
            return False

        # Final validation: let pandas build the actual range.
        pd.date_range(start=first_ts, periods=total_length, freq=pandas_freq)
    except (pd.errors.OutOfBoundsDatetime, OverflowError, ValueError):
        return False

    return True
375
+
376
+
377
+ # ============================================================================
378
+ # Frequency Selection Utilities
379
+ # ============================================================================
380
+
381
+
382
def select_safe_random_frequency(total_length: int, rng: Generator) -> Frequency:
    """
    Draw a frequency that suits a series of ``total_length`` points, following
    the GIFT eval frequency distribution.

    Every frequency whose overflow limit and realistic length range admit
    ``total_length`` becomes a candidate. Its sampling weight is the GIFT eval
    base weight times a length multiplier: 5.0 inside the optimal range,
    otherwise between 0.3 and 1.5 depending on the distance to that range.
    When no frequency qualifies, a length-dependent fallback order is used.

    Args:
        total_length: The total length of the time series (history + future)
        rng: A numpy random number generator instance

    Returns:
        A randomly selected frequency that matches GIFT eval patterns
    """
    candidates = []
    weights = []

    for candidate in Frequency:
        # Timestamp overflow limit
        if total_length > candidate.get_max_safe_length():
            continue

        shortest, longest, sweet_lo, sweet_hi = candidate.get_length_range()

        # Realistic length range for this frequency
        if total_length < shortest or total_length > longest:
            continue

        candidates.append(candidate)

        if sweet_lo <= total_length <= sweet_hi:
            # Inside the optimal range - very strong preference
            multiplier = 5.0
        else:
            if total_length < sweet_lo:
                distance = (sweet_lo - total_length) / (sweet_lo - shortest)
            else:
                distance = (total_length - sweet_hi) / (longest - sweet_hi)
            # Graduated penalty: closer to optimal = higher score (0.3-1.5)
            multiplier = 0.3 + 1.2 * (1.0 - distance)

        weights.append(candidate.get_gift_eval_weight() * multiplier)

    if not candidates:
        # Nothing fits: pick the first overflow-safe entry of a fallback list
        # chosen by how long the series is.
        if total_length <= 100:
            # Very short series - prefer low frequencies
            fallback_order = (
                Frequency.A,
                Frequency.Q,
                Frequency.M,
                Frequency.W,
                Frequency.D,
            )
        elif total_length <= 1000:
            # Medium short series - prefer daily/weekly
            fallback_order = (Frequency.D, Frequency.W, Frequency.H, Frequency.M)
        else:
            # Longer series - prefer higher frequencies
            fallback_order = (Frequency.H, Frequency.D, Frequency.T15, Frequency.T5)

        return next(
            (
                option
                for option in fallback_order
                if total_length <= option.get_max_safe_length()
            ),
            Frequency.D,  # Last resort
        )

    if len(candidates) == 1:
        return candidates[0]

    # Weighted draw
    weight_array = np.array(weights)
    return rng.choice(candidates, p=weight_array / weight_array.sum())
477
+
478
+
479
def select_safe_start_date(
    total_length: int,
    frequency: Frequency,
    rng: Optional[Generator] = None,
    max_retries: int = 10,
) -> np.datetime64:
    """
    Select a safe start date that ensures the entire time series (history + future)
    will not exceed pandas' datetime bounds.

    Args:
        total_length: Total length of the time series (history + future)
        frequency: Time series frequency
        rng: Random number generator instance. If None, a fresh
            ``np.random.default_rng()`` is created for this call.
        max_retries: Maximum number of retry attempts

    Returns:
        A randomly sampled start date that passed validate_frequency_safety,
        or BASE_START_DATE if none of the ``max_retries`` samples passed.

    Raises:
        ValueError: If the required time span exceeds the available date
            window between BASE_START_DATE and BASE_END_DATE
    """
    # Create the generator per call: a default built in the signature would
    # be evaluated once at import time and shared by every caller.
    if rng is None:
        rng = np.random.default_rng()

    days_per_period = frequency.get_days_per_period()

    # Calculate approximate duration in days
    total_days = total_length * days_per_period

    # Define safe bounds: ensure end date doesn't exceed BASE_END_DATE
    latest_safe_start = BASE_END_DATE - np.timedelta64(int(total_days), "D")
    earliest_safe_start = BASE_START_DATE

    # Check if the required time span exceeds the available window
    if latest_safe_start < earliest_safe_start:
        available_days = (
            (BASE_END_DATE - BASE_START_DATE).astype("timedelta64[D]").astype(int)
        )
        available_years = available_days / 365.25
        required_years = total_days / 365.25
        raise ValueError(
            f"Required time span ({required_years:.1f} years, {total_days:.0f} days) "
            f"exceeds available date window ({available_years:.1f} years, {available_days} days). "
            f"Reduce total_length ({total_length}) or extend the date window."
        )

    # Convert to nanoseconds for random sampling
    earliest_ns = earliest_safe_start.astype("datetime64[ns]").astype(np.int64)
    latest_ns = latest_safe_start.astype("datetime64[ns]").astype(np.int64)

    for _ in range(max_retries):
        # Uniformly sample a start date within bounds (upper bound inclusive)
        random_ns = rng.integers(earliest_ns, latest_ns + 1)
        start_date = np.datetime64(int(random_ns), "ns")

        # Verify safety
        if validate_frequency_safety(start_date, total_length, frequency):
            return start_date

    # Default to base start date if no safe start date is found. This value
    # is not re-validated, so make the fallback visible in the logs.
    logger.warning(
        f"No safe start date found after {max_retries} attempts for frequency "
        f"{frequency} and length {total_length}; falling back to BASE_START_DATE."
    )
    return BASE_START_DATE
src/data/loaders.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from typing import Dict, Iterator, List, Optional
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+
9
+ from src.data.batch_composer import BatchComposer, ComposedDataset
10
+ from src.data.containers import BatchTimeSeriesContainer
11
+ from src.data.frequency import parse_frequency
12
+ from src.gift_eval.constants import ALL_DATASETS
13
+ from src.gift_eval.data import Dataset as GiftEvalDataset
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class GiftEvalDataLoader:
    """
    Data loader for GIFT-eval datasets, converting them to BatchTimeSeriesContainer format.
    Supports both training and validation modes.

    Every dataset is loaded for each term in ``TERMS`` and all batches of an
    epoch are built eagerly in ``_prepare_epoch_data``; iterating the loader
    then simply walks (and, in training mode, reshuffles) that list.
    """

    TERMS = ["short", "medium", "long"]

    def __init__(
        self,
        mode: str = "train",
        batch_size: int = 32,
        device: Optional[torch.device] = None,
        shuffle: bool = True,
        to_univariate: bool = False,
        max_context_length: Optional[int] = None,
        max_windows: int = 20,
        skip_datasets_with_nans: bool = False,
        datasets_to_use: Optional[List[str]] = None,
        dataset_storage_path: Optional[str] = None,
    ):
        """
        Initialize GIFT-eval data loader.

        Args:
            mode: Either "train" or "validation"
            batch_size: Number of samples per batch
            device: Device to load data to
            shuffle: Whether to shuffle data
            to_univariate: Whether to convert multivariate data to multiple univariate series
            max_context_length: Optional maximum total window length (context + forecast) to prevent memory issues
            max_windows: Number of windows to use for training/validation
            skip_datasets_with_nans: Whether to skip datasets/series that contain NaN values
            datasets_to_use: Optional list of dataset names to use. If None, uses all available datasets
            dataset_storage_path: Path on disk where GIFT-eval HuggingFace datasets are stored
        """
        subset_requested = datasets_to_use is not None and len(datasets_to_use) > 0

        if subset_requested:
            # Unknown names are reported and dropped rather than failing the run
            invalid_datasets = [ds for ds in datasets_to_use if ds not in ALL_DATASETS]
            if invalid_datasets:
                logger.warning(f"Invalid datasets requested: {invalid_datasets}")
                logger.warning(f"Available datasets: {ALL_DATASETS}")
                self.dataset_names = [
                    ds for ds in datasets_to_use if ds in ALL_DATASETS
                ]
            else:
                self.dataset_names = datasets_to_use
            logger.info(
                f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets"
            )
            logger.info(f"Selected datasets: {self.dataset_names}")
        else:
            self.dataset_names = ALL_DATASETS
            logger.info(
                f"Using all available datasets: {len(self.dataset_names)} datasets"
            )

        self.terms = self.TERMS
        self.mode = mode
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.to_univariate = to_univariate
        self.max_context_length = max_context_length
        self.skip_datasets_with_nans = skip_datasets_with_nans
        self.max_windows = max_windows
        self.dataset_storage_path = dataset_storage_path

        # Load all datasets and prepare data
        self._load_datasets()

        # Create iterator state
        self._current_idx = 0
        self._epoch_data = []
        self._prepare_epoch_data()

    def _load_datasets(self) -> None:
        """Load all specified GIFT-eval datasets for every term in ``self.terms``.

        Populates ``self.datasets`` and ``self.dataset_prediction_lengths``,
        both keyed by ``"<dataset_name>_<term>"``. A dataset that fails to load
        is logged and skipped (best effort).
        """
        self.datasets = {}
        self.dataset_prediction_lengths = {}

        for dataset_name in self.dataset_names:
            # m4_* datasets are restricted to a single window
            max_windows = 1 if dataset_name.startswith("m4_") else self.max_windows
            try:
                # Probe the dataset in multivariate form to learn its target dimension
                temp_dataset = GiftEvalDataset(
                    name=dataset_name,
                    term=self.terms[0],
                    to_univariate=False,
                    max_windows=max_windows,
                    storage_path=self.dataset_storage_path,
                )

                # Only convert when requested AND the data is actually multivariate
                to_univariate = self.to_univariate and temp_dataset.target_dim > 1

                for term in self.terms:
                    dataset_key = f"{dataset_name}_{term}"
                    dataset = GiftEvalDataset(
                        name=dataset_name,
                        term=term,
                        to_univariate=to_univariate,
                        max_windows=max_windows,
                        storage_path=self.dataset_storage_path,
                    )

                    self.datasets[dataset_key] = dataset
                    self.dataset_prediction_lengths[dataset_key] = (
                        dataset.prediction_length
                    )

                    logger.info(
                        f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
                        f"frequency: {dataset.freq}, target_dim: {dataset.target_dim}, "
                        f"min_length: {dataset._min_series_length}, windows: {dataset.windows}"
                    )

            except Exception as e:
                # Deliberate best effort: one broken dataset must not abort loading
                logger.warning(f"Failed to load dataset {dataset_name}: {str(e)}")
                continue

    def _contains_nan(self, data_entry: dict) -> bool:
        """Check if a data entry contains NaN values.

        Returns:
            True if the entry's "target" holds at least one NaN, or if it cannot
            be coerced to float32 (such entries are treated as unusable).
            False if there is no "target" or it is NaN-free.
        """
        target = data_entry.get("target")
        if target is None:
            return False

        try:
            target_np = np.asarray(target, dtype=np.float32)
        except (TypeError, ValueError):
            logger.warning(
                "NaN check: failed to coerce target to float32; skipping entry"
            )
            return True
        # Cast so callers get a plain bool rather than numpy.bool_
        return bool(np.isnan(target_np).any())

    @staticmethod
    def _as_channels_first(target) -> np.ndarray:
        """Return ``target`` as a float32 array, reshaping 1-D input to ``[1, time]``."""
        array = np.asarray(target, dtype=np.float32)
        if array.ndim == 1:
            array = array.reshape(1, -1)
        return array

    @staticmethod
    def _start_to_numpy(start):
        """Convert an entry's start (Period-like or Timestamp-like) via ``to_numpy()``."""
        if hasattr(start, "to_timestamp"):
            return start.to_timestamp().to_numpy()
        return pd.Timestamp(start).to_numpy()

    def _effective_length(self, seq_len: int) -> int:
        """Length of a series once truncated to ``max_context_length`` (if set)."""
        if self.max_context_length is None:
            return seq_len
        return min(seq_len, self.max_context_length)

    def _convert_to_container(
        self, data_entries: List[dict], prediction_length: int, dataset_freq: str
    ) -> BatchTimeSeriesContainer:
        """Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length.

        Each series is cut to its last ``max_context_length`` points (when set);
        the final ``prediction_length`` points become the future and the rest
        the history. Shorter histories are padded with NaN at the end and
        flagged False in the history mask.

        Raises:
            ValueError: If ``data_entries`` is empty or any entry is too short
                to provide at least one history point plus the full future.
        """
        if not data_entries:
            raise ValueError("Cannot build a container from an empty batch")

        # Convert every target once and reuse it for sizing and filling
        targets = [self._as_channels_first(entry["target"]) for entry in data_entries]
        window_lengths = [self._effective_length(t.shape[1]) for t in targets]

        if min(window_lengths) < prediction_length + 1:
            # Not enough length to build (history + future). Signal to caller.
            raise ValueError(
                "Entry too short after max_context_length truncation to form history+future window"
            )

        batch_size = len(targets)
        num_channels = targets[0].shape[0]  # channel count taken from the first entry
        max_history_len = max(window_lengths) - prediction_length

        history_values = np.full(
            (batch_size, max_history_len, num_channels), np.nan, dtype=np.float32
        )
        future_values = np.full(
            (batch_size, prediction_length, num_channels), np.nan, dtype=np.float32
        )
        history_mask = np.zeros((batch_size, max_history_len), dtype=bool)

        for i, (target, window_len) in enumerate(zip(targets, window_lengths)):
            window = target[:, -window_len:]
            history_len = window_len - prediction_length

            # Arrays are stored time-major, so transpose [C, T] -> [T, C]
            history_values[i, :history_len, :] = window[:, :history_len].T
            future_values[i, :, :] = window[:, history_len:].T
            history_mask[i, :history_len] = True

        # Each series keeps its own start timestamp instead of inheriting the
        # first entry's.
        # NOTE(review): the start is not shifted when a series is truncated to
        # max_context_length; doing so needs frequency-aware offsets.
        start_list = [self._start_to_numpy(entry["start"]) for entry in data_entries]

        # One frequency per dataset, replicated across the batch
        frequency_enum = parse_frequency(dataset_freq)
        frequency_list = [frequency_enum for _ in range(batch_size)]

        return BatchTimeSeriesContainer(
            history_values=torch.tensor(history_values, dtype=torch.float32),
            future_values=torch.tensor(future_values, dtype=torch.float32),
            start=start_list,
            frequency=frequency_list,
            history_mask=torch.tensor(history_mask, dtype=torch.bool)
            if self.mode == "train"
            else None,
        )

    def _prepare_epoch_data(self) -> None:
        """Prepare all batches for one epoch.

        Fills ``self._epoch_data`` with ``(dataset_key, container)`` tuples.
        Entries that are too short (after ``max_context_length`` truncation) to
        hold one history point plus the forecast horizon are dropped
        individually, so a single short series no longer discards its batch.
        """
        self._epoch_data = []

        for dataset_key, dataset in self.datasets.items():
            try:
                if self.mode == "train":
                    data = dataset.training_dataset
                else:
                    data = dataset.validation_dataset

                dataset_freq = dataset.freq
                prediction_length = self.dataset_prediction_lengths[dataset_key]
                min_window = prediction_length + 1

                valid_entries = []
                for entry in data:
                    # Skip if contains NaN and configured to do so
                    if self.skip_datasets_with_nans and self._contains_nan(entry):
                        continue

                    target = np.asarray(entry["target"])
                    seq_len = len(target) if target.ndim == 1 else target.shape[1]

                    # Same length requirement _convert_to_container enforces
                    if self._effective_length(seq_len) < min_window:
                        continue

                    valid_entries.append(entry)

                if not valid_entries:
                    logger.warning(f"No valid entries found for {dataset_key}")
                    continue

                for i in range(0, len(valid_entries), self.batch_size):
                    batch_entries = valid_entries[i : i + self.batch_size]
                    try:
                        batch_container = self._convert_to_container(
                            batch_entries, prediction_length, dataset_freq
                        )
                        self._epoch_data.append((dataset_key, batch_container))
                    except Exception as e:
                        logger.warning(
                            f"Failed to create batch for {dataset_key}: {str(e)}"
                        )
                        continue

            except Exception as e:
                logger.warning(
                    f"Failed to process dataset {dataset_key}: {str(e)}. "
                    f"Dataset may be too short for the required offset."
                )
                continue

        # Shuffle if in training mode
        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        logger.info(f"Prepared {len(self._epoch_data)} batches for {self.mode} mode")

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate through batches for one epoch."""
        # Reset index at the start of each epoch
        self._current_idx = 0

        # Reshuffle data for each new epoch if in training mode
        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get next batch, moving it to ``self.device`` when one is configured."""
        if not self._epoch_data:
            raise StopIteration("No valid data available")

        # Check if we've exhausted the epoch
        if self._current_idx >= len(self._epoch_data):
            raise StopIteration

        _, batch = self._epoch_data[self._current_idx]
        self._current_idx += 1

        if self.device is not None:
            batch.to_device(self.device)

        return batch

    def __len__(self) -> int:
        """Return number of batches per epoch."""
        return len(self._epoch_data)
370
+
371
+
372
class CyclicGiftEvalDataLoader:
    """
    Wrapper for GiftEvalDataLoader that provides cycling behavior for training.
    This allows training for a fixed number of iterations per epoch, cycling through
    the available data as needed.
    """

    def __init__(self, base_loader: GiftEvalDataLoader, num_iterations_per_epoch: int):
        """
        Initialize the cyclic data loader.

        Args:
            base_loader: The underlying GiftEvalDataLoader
            num_iterations_per_epoch: Number of iterations to run per epoch
        """
        self.base_loader = base_loader
        self.num_iterations_per_epoch = num_iterations_per_epoch
        self.dataset_names = base_loader.dataset_names
        self.device = base_loader.device

        # Iterator state exists from construction so that next() works even if
        # iter() was never called; the base iterator is created lazily.
        self._current_iteration = 0
        self._base_iter = None

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate for exactly num_iterations_per_epoch iterations."""
        self._current_iteration = 0
        self._base_iter = iter(self.base_loader)
        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get next batch, cycling through base loader as needed.

        Raises:
            StopIteration: After ``num_iterations_per_epoch`` batches, or
                immediately if the base loader yields no batches at all.
        """
        if self._current_iteration >= self.num_iterations_per_epoch:
            raise StopIteration

        if self._base_iter is None:
            self._base_iter = iter(self.base_loader)

        try:
            batch = next(self._base_iter)
        except StopIteration:
            # Restart the base iterator when exhausted
            self._base_iter = iter(self.base_loader)
            try:
                batch = next(self._base_iter)
            except StopIteration:
                # A fresh iterator that is already exhausted means there is no
                # data to cycle over; end the epoch explicitly.
                raise StopIteration("Base loader yielded no batches") from None

        self._current_iteration += 1
        return batch

    def __len__(self) -> int:
        """Return the configured number of iterations per epoch."""
        return self.num_iterations_per_epoch
416
+
417
+
418
def create_synthetic_dataloader(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: Optional[Dict[str, float]] = None,
    mixed_batches: bool = True,
    augmentations: Optional[Dict[str, bool]] = None,
    augmentation_probabilities: Optional[Dict[str, float]] = None,
    device: Optional[torch.device] = None,
    num_workers: int = 0,
    pin_memory: bool = True,
    global_seed: int = 42,
    nan_stats_path: Optional[str] = None,
    nan_patterns_path: Optional[str] = None,
    chosen_scaler_name: Optional[str] = None,
) -> torch.utils.data.DataLoader:
    """
    Build a PyTorch DataLoader over batches composed from saved generator output.

    A ``BatchComposer`` is wrapped in a ``ComposedDataset`` whose items are
    already complete batches, so the DataLoader itself runs with
    ``batch_size=1`` and a collate function that unwraps the single item.

    Args:
        base_data_dir: Base directory containing generator subdirectories
        batch_size: Size of each training batch
        num_batches_per_epoch: Number of batches per epoch
        generator_proportions: Dict mapping generator names to proportions
        mixed_batches: Whether to create mixed or uniform batches
        augmentations: Dict mapping augmentation names to booleans
        augmentation_probabilities: Dict mapping augmentation names to probabilities
        device: Target device (forwarded to the batch composer)
        num_workers: Number of DataLoader workers
        pin_memory: Whether to pin memory
        global_seed: Global random seed
        nan_stats_path: Path to nan stats file
        nan_patterns_path: Path to nan patterns file (forwarded to the batch composer)
        chosen_scaler_name: Name of the scaler used in training

    Returns:
        PyTorch DataLoader
    """

    def _unwrap_single_batch(items):
        """Return the lone pre-built BatchTimeSeriesContainer handed over by the DataLoader."""
        # The DataLoader runs with batch_size=1, so `items` holds exactly one
        # element, which is already a full batch of `batch_size` samples.
        return items[0]

    composed = ComposedDataset(
        batch_composer=BatchComposer(
            base_data_dir=base_data_dir,
            generator_proportions=generator_proportions,
            mixed_batches=mixed_batches,
            device=device,
            augmentations=augmentations,
            augmentation_probabilities=augmentation_probabilities,
            global_seed=global_seed,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
        ),
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    loader = torch.utils.data.DataLoader(
        composed,
        batch_size=1,  # Each dataset item is already a complete batch
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=_unwrap_single_batch,
        drop_last=False,
    )

    logger.info(
        f"Created DataLoader with {len(composed)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return loader
501
+
502
+
503
class SyntheticValidationDataset(torch.utils.data.Dataset):
    """
    Fixed synthetic validation set: a handful of batches generated once, at
    construction time, with the same composition machinery as the training data.
    """

    def __init__(
        self,
        base_data_dir: str,
        batch_size: int = 128,
        num_batches: int = 2,
        future_length: int = 512,
        generator_proportions: Optional[Dict[str, float]] = None,
        augmentations: Optional[Dict[str, bool]] = None,
        augmentation_probabilities: Optional[Dict[str, float]] = None,
        device: Optional[torch.device] = None,
        global_seed: int = 42,
        chosen_scaler_name: Optional[str] = None,
        nan_stats_path: Optional[str] = None,
        nan_patterns_path: Optional[str] = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the validation dataset.

        Args:
            base_data_dir: Base directory containing generator subdirectories
            batch_size: Size of each validation batch
            num_batches: Number of validation batches to generate
            future_length: Future length passed to the composer for every batch
            generator_proportions: Dict mapping generator names to proportions
            augmentations: Dict mapping augmentation names to booleans;
                "length_shortening" is always forced to False here
            augmentation_probabilities: Dict mapping augmentation names to probabilities
            device: Device to load tensors to
            global_seed: Global random seed (offset by 999999 for validation)
            chosen_scaler_name: Name of the scaler used in training
            nan_stats_path: Path to nan stats file
            nan_patterns_path: Path to nan patterns file
            rank: Forwarded to the batch composer
            world_size: Forwarded to the batch composer
        """
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.device = device

        # Validation always uses max-length windows, so length shortening is
        # switched off regardless of what the caller asked for.
        val_augmentations = {**(augmentations or {}), "length_shortening": False}
        validation_seed = global_seed + 999999

        self.batch_composer = BatchComposer(
            base_data_dir=base_data_dir,
            generator_proportions=generator_proportions,
            mixed_batches=True,  # Use mixed batches for validation
            device=device,
            global_seed=validation_seed,
            augmentations=val_augmentations,
            augmentation_probabilities=augmentation_probabilities,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
            rank=rank,
            world_size=world_size,
        )

        # One fixed seed per batch keeps validation reproducible across runs
        self.validation_batches = [
            self.batch_composer.create_batch(
                batch_size=batch_size,
                future_length=future_length,
                seed=validation_seed + batch_idx,
            )[0]
            for batch_idx in range(num_batches)
        ]

        logger.info(
            f"Created {num_batches} fixed validation batches with batch_size={batch_size}"
        )

    def __len__(self) -> int:
        """Return the number of validation batches requested at construction."""
        return self.num_batches

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Fetch one of the pre-generated validation batches.

        Args:
            idx: Batch index

        Returns:
            BatchTimeSeriesContainer

        Raises:
            IndexError: If ``idx`` is at or beyond the number of stored batches
        """
        if idx >= len(self.validation_batches):
            raise IndexError(f"Batch index {idx} out of range")

        selected = self.validation_batches[idx]

        # Move to device if needed
        if self.device is not None:
            selected.to_device(self.device)

        return selected
600
+
601
+
602
def create_synthetic_dataset(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: Optional[Dict[str, float]] = None,
    mixed_batches: bool = True,
    augmentations: Optional[Dict[str, bool]] = None,
    augmentation_probabilities: Optional[Dict[str, float]] = None,
    global_seed: int = 42,
    nan_stats_path: Optional[str] = None,
    nan_patterns_path: Optional[str] = None,
    chosen_scaler_name: Optional[str] = None,
    rank: int = 0,
    world_size: int = 1,
) -> ComposedDataset:
    """
    Build the ComposedDataset used for training on saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories.
        batch_size: Size of each training batch.
        num_batches_per_epoch: Number of batches per epoch.
        generator_proportions: Dict mapping generator names to proportions.
        mixed_batches: Whether to create mixed or uniform batches.
        augmentations: Dict mapping augmentation names to booleans.
        augmentation_probabilities: Dict mapping augmentation names to probabilities.
        global_seed: Global random seed.
        nan_stats_path: Path to nan stats file.
        nan_patterns_path: Path to nan patterns file.
        chosen_scaler_name: Name of the scaler to use.
        rank: Forwarded to the batch composer.
        world_size: Forwarded to the batch composer.

    Returns:
        A ComposedDataset instance.
    """
    composed = ComposedDataset(
        batch_composer=BatchComposer(
            base_data_dir=base_data_dir,
            generator_proportions=generator_proportions,
            mixed_batches=mixed_batches,
            device=None,  # Device is handled in the training loop
            augmentations=augmentations,
            augmentation_probabilities=augmentation_probabilities,
            global_seed=global_seed,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
            rank=rank,
            world_size=world_size,
        ),
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    logger.info(
        f"Created ComposedDataset with {len(composed)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return composed
src/data/scalers.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional
3
+
4
+ import torch
5
+
6
+
7
class BaseScaler(ABC):
    """
    Interface every time series scaler implements.

    A scaler derives per-series statistics from the history (optionally
    honouring a mask of observed values), applies them to data, and can undo
    the transformation again.
    """

    @abstractmethod
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Derive the scaling statistics from historical data."""

    @abstractmethod
    def scale(
        self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Transform ``data`` using previously computed ``statistics``."""

    @abstractmethod
    def inverse_scale(
        self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Map ``scaled_data`` back to the original scale."""
41
+
42
+
43
class RobustScaler(BaseScaler):
    """
    Scaler that centres on the median and divides by the interquartile range.
    """

    def __init__(self, epsilon: float = 1e-6, min_scale: float = 1e-3):
        for label, value in (("epsilon", epsilon), ("min_scale", min_scale)):
            if value <= 0:
                raise ValueError(f"{label} must be positive")
        self.epsilon = epsilon
        self.min_scale = min_scale

    def _denominator(self, iqr: torch.Tensor) -> torch.Tensor:
        """Return ``iqr + epsilon`` floored at ``min_scale``."""
        floor = torch.tensor(self.min_scale, device=iqr.device)
        return torch.maximum(iqr + self.epsilon, floor)

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute per-series, per-channel median and IQR from the history.

        Only finite values at positions where ``history_mask`` is truthy are
        used. A channel without usable values keeps the defaults (median 0,
        IQR 1); a channel with a single value gets IQR ``min_scale``.

        Returns:
            Dict with "median" and "iqr", each of shape [batch, 1, channels].
        """
        n_series, _, n_channels = history_values.shape
        device = history_values.device

        medians = torch.zeros(n_series, 1, n_channels, device=device)
        iqrs = torch.ones(n_series, 1, n_channels, device=device)
        floor = torch.tensor(self.min_scale, device=device)

        for b in range(n_series):
            for c in range(n_channels):
                samples = history_values[b, :, c]
                if history_mask is not None:
                    samples = samples[history_mask[b, :].bool()]
                samples = samples[torch.isfinite(samples)]

                if samples.numel() == 0:
                    continue

                medians[b, 0, c] = torch.median(samples)

                if samples.numel() == 1:
                    # A single observation has no spread to measure
                    iqrs[b, 0, c] = self.min_scale
                    continue

                try:
                    spread = torch.quantile(samples, 0.75) - torch.quantile(
                        samples, 0.25
                    )
                except Exception:
                    # Fall back to the standard deviation if quantile fails
                    spread = torch.std(samples)
                iqrs[b, 0, c] = torch.maximum(spread, floor)

        return {"median": medians, "iqr": iqrs}

    def scale(
        self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Apply robust scaling: (data - median) / max(iqr + epsilon, min_scale),
        with the result clamped to [-50, 50].
        """
        centred = data - statistics["median"]
        return torch.clamp(centred / self._denominator(statistics["iqr"]), -50.0, 50.0)

    def inverse_scale(
        self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Undo robust scaling for 3D or 4D tensors.

        For 4D input the statistics gain a trailing singleton axis so they
        broadcast over the extra last dimension. The clamp applied by
        ``scale`` is not reversed.
        """
        offset = statistics["median"]
        factor = self._denominator(statistics["iqr"])

        if scaled_data.ndim == 4:
            factor = factor.unsqueeze(-1)
            offset = offset.unsqueeze(-1)

        return scaled_data * factor + offset
143
+
144
+
145
class MinMaxScaler(BaseScaler):
    """
    Min-Max scaler that normalizes data to the range [-1, 1].
    """

    def __init__(self, epsilon: float = 1e-8):
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        self.epsilon = epsilon

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute min and max statistics from historical data.

        Only finite values at positions where ``history_mask`` is truthy are
        considered, matching the other scalers in this module. Without this
        filter a single NaN in the history would turn both statistics into NaN.
        A channel without usable values keeps the defaults (min 0, max 1); a
        (near-)constant channel gets ``max = min + 1``.

        Returns:
            Dict with "min" and "max", each of shape [batch, 1, channels].
        """
        batch_size, seq_len, num_channels = history_values.shape
        device = history_values.device

        mins = torch.zeros(batch_size, 1, num_channels, device=device)
        maxs = torch.ones(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            for c in range(num_channels):
                channel_data = history_values[b, :, c]

                if history_mask is not None:
                    mask = history_mask[b, :].bool()
                    valid_data = channel_data[mask]
                else:
                    valid_data = channel_data

                # Drop NaN/Inf so they cannot propagate into min/max
                valid_data = valid_data[torch.isfinite(valid_data)]

                if valid_data.numel() == 0:
                    continue

                min_val = torch.min(valid_data)
                max_val = torch.max(valid_data)

                mins[b, 0, c] = min_val
                maxs[b, 0, c] = max_val

                # Avoid a degenerate (zero-width) range for constant series
                if torch.abs(max_val - min_val) < self.epsilon:
                    maxs[b, 0, c] = min_val + 1.0

        return {"min": mins, "max": maxs}

    def scale(
        self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Apply min-max scaling to range [-1, 1].
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        normalized = (data - min_val) / (max_val - min_val + self.epsilon)
        return normalized * 2.0 - 1.0

    def inverse_scale(
        self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Apply inverse min-max scaling, compatible with 3D or 4D tensors.

        For 4D input the statistics gain a trailing singleton axis so they
        broadcast over the extra last dimension.
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        if scaled_data.ndim == 4:
            min_val = min_val.unsqueeze(-1)
            max_val = max_val.unsqueeze(-1)

        normalized = (scaled_data + 1.0) / 2.0
        return normalized * (max_val - min_val + self.epsilon) + min_val
218
+
219
+
220
class MeanScaler(BaseScaler):
    """
    Centring-only scaler: subtracts each channel's mean.

    The spread of the data is left untouched.
    """

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the per-series, per-channel mean of the history.

        Only finite values at positions where ``history_mask`` is truthy
        contribute. A channel without usable values keeps a mean of 0.

        Returns:
            Dict with "mean" of shape [batch, 1, channels].
        """
        n_series, _, n_channels = history_values.shape
        means = torch.zeros(n_series, 1, n_channels, device=history_values.device)

        for b in range(n_series):
            for c in range(n_channels):
                observed = (
                    history_values[b, :, c]
                    if history_mask is None
                    else history_values[b, :, c][history_mask[b, :].bool()]
                )
                # NaN and Inf must not leak into the mean
                observed = observed[torch.isfinite(observed)]

                if observed.numel() > 0:
                    means[b, 0, c] = torch.mean(observed)

        return {"mean": means}

    def scale(
        self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Centre the data: data - mean.
        """
        return data - statistics["mean"]

    def inverse_scale(
        self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Undo the centring: scaled_data + mean.

        Works for 3D tensors and for 4D tensors, where the mean gains a
        trailing singleton axis so it broadcasts over the extra last dimension.
        """
        offset = statistics["mean"]
        if scaled_data.ndim == 4:
            offset = offset.unsqueeze(-1)
        return scaled_data + offset
289
+
290
+
291
class MedianScaler(BaseScaler):
    """
    Center each channel by subtracting its median over the history window.

    Like ``MeanScaler`` this only shifts the data, but the median makes the
    shift less sensitive to outliers.
    """

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute one median per (batch item, channel) from the history.

        Masked-out positions and non-finite values are excluded. Channels
        without any usable value keep a median of 0.
        """
        batch_size, _, num_channels = history_values.shape
        medians = torch.zeros(
            batch_size, 1, num_channels, device=history_values.device
        )

        for b in range(batch_size):
            for c in range(num_channels):
                candidates = history_values[b, :, c]

                # Restrict to observed time steps if a mask was supplied.
                if history_mask is not None:
                    candidates = candidates[history_mask[b, :].bool()]

                # Remove NaN/Inf before taking the median.
                candidates = candidates[torch.isfinite(candidates)]
                if candidates.numel() == 0:
                    continue

                medians[b, 0, c] = torch.median(candidates)

        return {"median": medians}

    def scale(
        self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Subtract the stored median from ``data``.
        """
        return data - statistics["median"]

    def inverse_scale(
        self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Add the stored median back onto ``scaled_data``.

        Works for 3D tensors and for 4D tensors, where the median gets an
        extra trailing axis so it broadcasts over the last dimension.
        """
        shift = statistics["median"]
        if scaled_data.ndim == 4:
            shift = shift.unsqueeze(-1)
        return scaled_data + shift
src/data/time_features.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.fft as fft
7
+ import torch
8
+ from gluonts.time_feature import time_features_from_frequency_str
9
+ from gluonts.time_feature._base import (
10
+ day_of_month,
11
+ day_of_month_index,
12
+ day_of_week,
13
+ day_of_week_index,
14
+ day_of_year,
15
+ hour_of_day,
16
+ hour_of_day_index,
17
+ minute_of_hour,
18
+ minute_of_hour_index,
19
+ month_of_year,
20
+ month_of_year_index,
21
+ second_of_minute,
22
+ second_of_minute_index,
23
+ week_of_year,
24
+ week_of_year_index,
25
+ )
26
+ from gluonts.time_feature.holiday import (
27
+ BLACK_FRIDAY,
28
+ CHRISTMAS_DAY,
29
+ CHRISTMAS_EVE,
30
+ CYBER_MONDAY,
31
+ EASTER_MONDAY,
32
+ EASTER_SUNDAY,
33
+ GOOD_FRIDAY,
34
+ INDEPENDENCE_DAY,
35
+ LABOR_DAY,
36
+ MEMORIAL_DAY,
37
+ NEW_YEARS_DAY,
38
+ NEW_YEARS_EVE,
39
+ THANKSGIVING,
40
+ SpecialDateFeatureSet,
41
+ exponential_kernel,
42
+ squared_exponential_kernel,
43
+ )
44
+ from gluonts.time_feature.seasonality import get_seasonality
45
+ from scipy.signal import find_peaks
46
+
47
+ from src.data.constants import BASE_END_DATE, BASE_START_DATE
48
+ from src.data.frequency import (
49
+ Frequency,
50
+ validate_frequency_safety,
51
+ )
52
+ from src.utils.utils import device
53
+
54
# Configure logging (note: this configures the root logger at import time)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.DEBUG
)
logger = logging.getLogger(__name__)


# Frequency-specific calendar features. Each category provides the GluonTS
# "normalized" feature functions and their "*_index" counterparts.
ENHANCED_TIME_FEATURES = {
    # Seconds / minutes
    "high_freq": dict(
        normalized=[
            second_of_minute,
            minute_of_hour,
            hour_of_day,
            day_of_week,
            day_of_month,
        ],
        index=[
            second_of_minute_index,
            minute_of_hour_index,
            hour_of_day_index,
            day_of_week_index,
        ],
    ),
    # Hourly / daily
    "medium_freq": dict(
        normalized=[
            hour_of_day,
            day_of_week,
            day_of_month,
            day_of_year,
            month_of_year,
        ],
        index=[
            hour_of_day_index,
            day_of_week_index,
            day_of_month_index,
            week_of_year_index,
        ],
    ),
    # Weekly / monthly and everything else
    "low_freq": dict(
        normalized=[day_of_week, day_of_month, month_of_year, week_of_year],
        index=[day_of_week_index, month_of_year_index, week_of_year_index],
    ),
}

# Holidays shared by every set below (end-of-year block).
_YEAR_END_HOLIDAYS = [CHRISTMAS_EVE, CHRISTMAS_DAY, NEW_YEARS_EVE]

# Named holiday collections selectable via ``holiday_set``.
HOLIDAY_FEATURE_SETS = {
    "us_business": [
        NEW_YEARS_DAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        *_YEAR_END_HOLIDAYS,
    ],
    "us_retail": [
        NEW_YEARS_DAY,
        EASTER_SUNDAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        BLACK_FRIDAY,
        CYBER_MONDAY,
        *_YEAR_END_HOLIDAYS,
    ],
    "christian": [
        NEW_YEARS_DAY,
        GOOD_FRIDAY,
        EASTER_SUNDAY,
        EASTER_MONDAY,
        *_YEAR_END_HOLIDAYS,
    ],
}
137
+
138
+
139
class TimeFeatureGenerator:
    """
    Enhanced time feature generator that leverages full GluonTS capabilities.

    Combines the standard GluonTS features for a frequency, frequency-specific
    calendar features, holiday features and sine/cosine seasonality features
    into a single ``[time_steps, num_features]`` array.
    """

    def __init__(
        self,
        use_enhanced_features: bool = True,
        use_holiday_features: bool = True,
        holiday_set: str = "us_business",
        holiday_kernel: str = "exponential",
        holiday_kernel_alpha: float = 1.0,
        use_index_features: bool = True,
        k_max: int = 15,
        include_seasonality_info: bool = True,
        use_auto_seasonality: bool = False,  # New parameter
        max_seasonal_periods: int = 3,  # New parameter
    ):
        """
        Initialize enhanced time feature generator.

        Parameters
        ----------
        use_enhanced_features : bool
            Whether to use frequency-specific enhanced features
        use_holiday_features : bool
            Whether to include holiday features
        holiday_set : str
            Which holiday set to use ('us_business', 'us_retail', 'christian').
            An unknown name disables holiday features.
        holiday_kernel : str
            Holiday kernel type ('indicator', 'exponential', 'squared_exponential')
        holiday_kernel_alpha : float
            Kernel parameter for exponential kernels
        use_index_features : bool
            Whether to include index-based features alongside normalized ones
        k_max : int
            Maximum number of time features to pad to (stored only; this class
            itself never pads or truncates)
        include_seasonality_info : bool
            Whether to include seasonality information as features
        use_auto_seasonality : bool
            Whether to use automatic FFT-based seasonality detection
        max_seasonal_periods : int
            Maximum number of seasonal periods to detect automatically
        """
        self.use_enhanced_features = use_enhanced_features
        self.use_holiday_features = use_holiday_features
        self.holiday_set = holiday_set
        self.use_index_features = use_index_features
        self.k_max = k_max
        self.include_seasonality_info = include_seasonality_info
        self.use_auto_seasonality = use_auto_seasonality
        self.max_seasonal_periods = max_seasonal_periods

        # Initialize holiday feature set
        self.holiday_feature_set = None
        if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
            kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
            self.holiday_feature_set = SpecialDateFeatureSet(
                HOLIDAY_FEATURE_SETS[holiday_set], kernel_func
            )

    def _get_holiday_kernel(self, kernel_type: str, alpha: float):
        """Return the kernel callable for ``kernel_type`` (indicator by default)."""
        if kernel_type == "exponential":
            return exponential_kernel(alpha)
        elif kernel_type == "squared_exponential":
            return squared_exponential_kernel(alpha)
        else:
            # Default indicator kernel: 1.0 exactly on the holiday, else 0.0
            return lambda x: float(x == 0)

    def _get_feature_category(self, freq_str: str) -> str:
        """Map a frequency string to a key of ``ENHANCED_TIME_FEATURES``."""
        if freq_str in ["s", "1min", "5min", "10min", "15min"]:
            return "high_freq"
        elif freq_str in ["h", "D"]:
            return "medium_freq"
        else:
            return "low_freq"

    def _compute_enhanced_features(
        self, period_index: pd.PeriodIndex, freq_str: str
    ) -> np.ndarray:
        """
        Compute the frequency-specific calendar features.

        Returns an array of shape ``[len(period_index), num_features]``; the
        second dimension is 0 when enhanced features are disabled or every
        feature function failed.
        """
        if not self.use_enhanced_features:
            return np.array([]).reshape(len(period_index), 0)

        category = self._get_feature_category(freq_str)
        feature_config = ENHANCED_TIME_FEATURES[category]

        features = []

        # Add normalized features
        for feat_func in feature_config["normalized"]:
            try:
                feat_values = feat_func(period_index)
                features.append(feat_values)
            except Exception:
                # Best effort: a feature that cannot be computed for this
                # index is skipped rather than failing the whole batch.
                continue

        # Add index features if enabled
        if self.use_index_features:
            for feat_func in feature_config["index"]:
                try:
                    feat_values = feat_func(period_index)
                    # Normalize index features to [0, 1] range.
                    # NOTE(review): the divisor is the maximum observed in
                    # this window, not the feature's natural maximum, so the
                    # scaling depends on which timestamps the window covers.
                    if feat_values.max() > 0:
                        feat_values = feat_values / feat_values.max()
                    features.append(feat_values)
                except Exception:
                    continue

        if features:
            return np.stack(features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def _compute_holiday_features(self, date_range: pd.DatetimeIndex) -> np.ndarray:
        """
        Compute holiday features as ``[len(date_range), num_holidays]``.

        Returns a zero-width array when holiday features are disabled,
        unavailable, or their computation raises.
        """
        if not self.use_holiday_features or self.holiday_feature_set is None:
            return np.array([]).reshape(len(date_range), 0)

        try:
            holiday_features = self.holiday_feature_set(date_range)
            return holiday_features.T  # Transpose to get [time, features] shape
        except Exception:
            return np.array([]).reshape(len(date_range), 0)

    def _detect_auto_seasonality(self, time_series_values: np.ndarray) -> list:
        """
        Detect seasonal periods automatically using FFT analysis.

        The series is linearly detrended, Hann-windowed and zero-padded to
        twice its length; spectral peaks above 5% of the largest magnitude
        are converted to periods measured in samples.

        Parameters
        ----------
        time_series_values : np.ndarray
            Time series values for seasonality detection

        Returns
        -------
        list
            List of detected seasonal periods (integers, duplicates removed).
            Empty when detection is disabled, fewer than 10 non-NaN values
            are available, no peak is found, or any step raises.
        """
        if not self.use_auto_seasonality or len(time_series_values) < 10:
            return []

        try:
            # Remove NaN values
            values = time_series_values[~np.isnan(time_series_values)]
            if len(values) < 10:
                return []

            # Simple linear detrending
            x = np.arange(len(values))
            coeffs = np.polyfit(x, values, 1)
            trend = np.polyval(coeffs, x)
            detrended = values - trend

            # Apply Hann window to reduce spectral leakage
            window = np.hanning(len(detrended))
            windowed = detrended * window

            # Zero padding for better frequency resolution
            padded_length = len(windowed) * 2
            padded_values = np.zeros(padded_length)
            padded_values[: len(windowed)] = windowed

            # Compute FFT
            fft_values = fft.rfft(padded_values)
            fft_magnitudes = np.abs(fft_values)
            freqs = np.fft.rfftfreq(padded_length)

            # Exclude DC component
            fft_magnitudes[0] = 0.0

            # Find peaks with threshold (5% of max magnitude)
            threshold = 0.05 * np.max(fft_magnitudes)
            peak_indices, _ = find_peaks(fft_magnitudes, height=threshold)

            if len(peak_indices) == 0:
                return []

            # Sort by magnitude and take top periods
            sorted_indices = peak_indices[
                np.argsort(fft_magnitudes[peak_indices])[::-1]
            ]
            top_indices = sorted_indices[: self.max_seasonal_periods]

            # Convert frequencies to periods
            periods = []
            for idx in top_indices:
                if freqs[idx] > 0:
                    # ``rfftfreq(padded_length)`` is already expressed in
                    # cycles per sample for the padded signal, so 1 / freq is
                    # the period in samples. Zero padding only refines the
                    # frequency grid and needs no further correction.
                    period = round(1.0 / freqs[idx])
                    if 2 <= period <= len(values) // 2:  # Reasonable period range
                        periods.append(period)

            return list(set(periods))  # Remove duplicates

        except Exception:
            # Detection is optional; any numerical failure means "no periods".
            return []

    def _compute_seasonality_features(
        self,
        period_index: pd.PeriodIndex,
        freq_str: str,
        time_series_values: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Compute sine/cosine seasonality features.

        One sin/cos pair is produced for the seasonality GluonTS associates
        with ``freq_str`` (when greater than 1) and, if auto-detection is on
        and values are supplied, one pair per detected period. The phase is
        the position within ``period_index``.
        """
        if not self.include_seasonality_info:
            return np.array([]).reshape(len(period_index), 0)

        all_seasonal_features = []

        # Original frequency-based seasonality
        try:
            seasonality = get_seasonality(freq_str)
            if seasonality > 1:
                positions = np.arange(len(period_index))
                sin_feat = np.sin(2 * np.pi * positions / seasonality)
                cos_feat = np.cos(2 * np.pi * positions / seasonality)
                all_seasonal_features.extend([sin_feat, cos_feat])
        except Exception:
            # Unknown frequency strings simply contribute no features.
            pass

        # Automatic seasonality detection
        if self.use_auto_seasonality and time_series_values is not None:
            auto_periods = self._detect_auto_seasonality(time_series_values)
            for period in auto_periods:
                try:
                    positions = np.arange(len(period_index))
                    sin_feat = np.sin(2 * np.pi * positions / period)
                    cos_feat = np.cos(2 * np.pi * positions / period)
                    all_seasonal_features.extend([sin_feat, cos_feat])
                except Exception:
                    continue

        if all_seasonal_features:
            return np.stack(all_seasonal_features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def compute_features(
        self,
        period_index: pd.PeriodIndex,
        date_range: pd.DatetimeIndex,
        freq_str: str,
        time_series_values: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Compute all time features for given period index.

        Feature groups are concatenated in this order: standard GluonTS
        features, enhanced features, holiday features, seasonality features.

        Parameters
        ----------
        period_index : pd.PeriodIndex
            Period index for computing features
        date_range : pd.DatetimeIndex
            Corresponding datetime index for holiday features
        freq_str : str
            Frequency string
        time_series_values : np.ndarray, optional
            Time series values for automatic seasonality detection

        Returns
        -------
        np.ndarray
            Time features array of shape [time_steps, num_features]; a single
            all-zero column when no feature group produced anything.
        """
        all_features = []

        # Standard GluonTS features
        try:
            standard_features = time_features_from_frequency_str(freq_str)
            if standard_features:
                std_feat = np.stack(
                    [feat(period_index) for feat in standard_features], axis=-1
                )
                all_features.append(std_feat)
        except Exception:
            # Best effort: fall through to the remaining feature groups.
            pass

        # Enhanced features
        enhanced_feat = self._compute_enhanced_features(period_index, freq_str)
        if enhanced_feat.shape[1] > 0:
            all_features.append(enhanced_feat)

        # Holiday features
        holiday_feat = self._compute_holiday_features(date_range)
        if holiday_feat.shape[1] > 0:
            all_features.append(holiday_feat)

        # Seasonality features (including auto-detected)
        seasonality_feat = self._compute_seasonality_features(
            period_index, freq_str, time_series_values
        )
        if seasonality_feat.shape[1] > 0:
            all_features.append(seasonality_feat)

        if all_features:
            combined_features = np.concatenate(all_features, axis=-1)
        else:
            combined_features = np.zeros((len(period_index), 1))

        return combined_features
443
+
444
+
445
def compute_batch_time_features(
    start: List[np.datetime64],
    history_length: int,
    future_length: int,
    batch_size: int,
    frequency: List[Frequency],
    K_max: int = 6,
    time_feature_config: Optional[Dict[str, Any]] = None,
):
    """
    Build history and future time-feature tensors for a batch.

    Parameters
    ----------
    start : array-like, shape (batch_size,)
        Start timestamp of each batch item.
    history_length : int
        Number of history time steps.
    future_length : int
        Number of future (target) time steps.
    batch_size : int
        Number of items to process from ``start`` / ``frequency``.
    frequency : array-like, shape (batch_size,)
        Frequency of each time series.
    K_max : int, optional
        Feature count every item is padded or truncated to (default: 6).
    time_feature_config : dict, optional
        Keyword arguments forwarded to ``TimeFeatureGenerator``.

    Returns
    -------
    tuple
        (history_time_features, target_time_features), float tensors on the
        module-level ``device`` with shapes (batch_size, history_length, K_max)
        and (batch_size, future_length, K_max).
    """
    generator = TimeFeatureGenerator(**(time_feature_config or {}))
    total_length = history_length + future_length

    history_batch = []
    future_batch = []
    for i in range(batch_size):
        item_frequency = frequency[i]
        freq_str = item_frequency.to_pandas_freq(for_date_range=True)
        period_freq_str = item_frequency.to_pandas_freq(for_date_range=False)

        # Fall back to the base start date when the requested one is unsafe.
        start_ts = pd.Timestamp(start[i])
        if not validate_frequency_safety(start_ts, total_length, item_frequency):
            logger.debug(
                f"Start date {start_ts} not safe for total_length={total_length}, frequency={item_frequency}. "
                f"Using BASE_START_DATE instead."
            )
            start_ts = BASE_START_DATE

        history_range = pd.date_range(
            start=start_ts, periods=history_length, freq=freq_str
        )
        step = pd.tseries.frequencies.to_offset(freq_str)

        # If the history would end past the allowed end date, shift the whole
        # window back so history + future fit, but never before the base start.
        if history_range[-1] > BASE_END_DATE:
            safe_start = BASE_END_DATE - step * total_length
            if safe_start < BASE_START_DATE:
                safe_start = BASE_START_DATE
            history_range = pd.date_range(
                start=safe_start, periods=history_length, freq=freq_str
            )

        # The future window starts one step after the last history stamp.
        future_range = pd.date_range(
            start=history_range[-1] + step, periods=future_length, freq=freq_str
        )

        history_features = generator.compute_features(
            history_range.to_period(period_freq_str), history_range, freq_str
        )
        future_features = generator.compute_features(
            future_range.to_period(period_freq_str), future_range, freq_str
        )

        history_batch.append(_pad_or_truncate_features(history_features, K_max))
        future_batch.append(_pad_or_truncate_features(future_features, K_max))

    # Stack into batch tensors
    history_tensor = torch.from_numpy(np.stack(history_batch, axis=0))
    future_tensor = torch.from_numpy(np.stack(future_batch, axis=0))
    return history_tensor.float().to(device), future_tensor.float().to(device)
550
+
551
+
552
+ def _pad_or_truncate_features(features: np.ndarray, K_max: int) -> np.ndarray:
553
+ """Pad with zeros or truncate features to K_max dimensions."""
554
+ seq_len, num_features = features.shape
555
+
556
+ if num_features < K_max:
557
+ # Pad with zeros
558
+ padding = np.zeros((seq_len, K_max - num_features))
559
+ features = np.concatenate([features, padding], axis=-1)
560
+ elif num_features > K_max:
561
+ # Truncate to K_max (keep most important features first)
562
+ features = features[:, :K_max]
563
+
564
+ return features
src/data/utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Optional, Tuple, Union
3
+
4
+
5
def sample_future_length(
    range: Union[Tuple[int, int], str] = "gift_eval",
    total_length: Optional[int] = None,
) -> int:
    """
    Sample a forecast length.

    - If `range` is a ``(min, max)`` pair (tuple or list), uniformly sample in
      [min, max]. When `total_length` is provided, the upper bound is capped at
      ``max(1, floor(0.45 * total_length))``; if `min` exceeds that cap, the
      cap itself is returned.
    - If `range` is "gift_eval", sample from a pre-defined weighted set. When
      `total_length` is provided, candidates above the cap are removed before
      sampling; if no candidate fits, the cap itself is returned so the result
      never exceeds it.

    Raises
    ------
    ValueError
        If `range` is neither a pair nor "gift_eval".
    """
    # Compute the cap when total_length is provided
    cap: Optional[int] = None
    if total_length is not None:
        cap = max(1, int(0.45 * int(total_length)))

    # Lists are accepted as well because pairs loaded from YAML/JSON configs
    # arrive as lists rather than tuples.
    if isinstance(range, (tuple, list)):
        min_len, max_len = range
        if cap is not None:
            max_len = min(max_len, cap)
            # Ensure valid bounds
            if min_len > max_len:
                return max_len
        return random.randint(min_len, max_len)

    if range == "gift_eval":
        # Gift eval forecast lengths with their frequencies (sampling weights)
        GIFT_EVAL_FORECAST_LENGTHS = {
            48: 5,
            720: 38,
            480: 38,
            30: 3,
            300: 16,
            8: 2,
            120: 3,
            450: 8,
            80: 8,
            12: 2,
            900: 10,
            180: 3,
            600: 10,
            60: 3,
            210: 3,
            195: 3,
            140: 3,
            130: 3,
            14: 1,
            18: 1,
            13: 1,
            6: 1,
        }

        lengths = list(GIFT_EVAL_FORECAST_LENGTHS.keys())
        weights = list(GIFT_EVAL_FORECAST_LENGTHS.values())

        if cap is not None:
            fitting = [
                (length_candidate, weight)
                for length_candidate, weight in zip(lengths, weights)
                if length_candidate <= cap
            ]
            if not fitting:
                # Even the shortest benchmark horizon is too long for this
                # series; fall back to the cap instead of an oversized value.
                return cap
            lengths = [length_candidate for length_candidate, _ in fitting]
            weights = [weight for _, weight in fitting]

        return random.choices(lengths, weights=weights)[0]

    raise ValueError(f"Invalid range: {range}")
src/gift_eval/__init__.py ADDED
File without changes
src/gift_eval/aggregate_results.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ import pandas as pd
8
+
9
+ from src.gift_eval.constants import (
10
+ ALL_DATASETS,
11
+ DATASET_PROPERTIES,
12
+ MED_LONG_DATASETS,
13
+ PRETTY_NAMES,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def get_all_datasets_full_name() -> List[str]:
    """
    List every expected ``<dataset>/<freq>/<term>`` identifier.

    Each dataset gets a "short" entry; "medium" and "long" entries are added
    only for datasets named in ``MED_LONG_DATASETS``. Dataset keys are
    lower-cased and mapped through ``PRETTY_NAMES``; names without an explicit
    ``/<freq>`` suffix take their frequency from ``DATASET_PROPERTIES``.
    """
    med_long_names = set(MED_LONG_DATASETS.split())
    full_names: List[str] = []

    for name in ALL_DATASETS:
        if "/" in name:
            raw_key, ds_freq = name.split("/")
            ds_key = PRETTY_NAMES.get(raw_key.lower(), raw_key.lower())
        else:
            ds_key = PRETTY_NAMES.get(name.lower(), name.lower())
            ds_freq = DATASET_PROPERTIES[ds_key]["frequency"]

        # "short" always applies; the longer terms only for selected datasets.
        applicable_terms = ["short"]
        if name in med_long_names:
            applicable_terms += ["medium", "long"]

        full_names.extend(f"{ds_key}/{ds_freq}/{term}" for term in applicable_terms)

    return full_names
41
+
42
+
43
def aggregate_results(
    result_root_dir: str | Path,
) -> pd.DataFrame | None:
    """Combine all per-dataset ``results.csv`` files below a directory.

    The merged table is sorted by ``dataset``, de-duplicated on that column
    (first occurrence kept), compared against the expected dataset list for a
    logged completion summary, and written to
    ``<result_root_dir>/all_results.csv``.

    Returns the combined dataframe, or ``None`` when no result file exists or
    none of them contains rows.
    """
    result_root_dir = Path(result_root_dir)

    logger.info(f"Aggregating results in: {result_root_dir}")

    # Results are written per-dataset as <result_root_dir>/<dataset_name>/results.csv
    result_files = glob.glob(f"{result_root_dir}/**/results.csv", recursive=True)
    if not result_files:
        logger.error("No result files found!")
        return None

    # Load every non-empty CSV; unreadable files are logged and skipped.
    frames: List[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
        except pd.errors.EmptyDataError:
            logger.warning(f"Skipping empty file: {file}")
            continue
        except Exception as e:
            logger.error(f"Error reading {file}: {str(e)}")
            continue

        if len(df) > 0:
            frames.append(df)
        else:
            logger.warning(f"Empty file: {file}")

    if not frames:
        logger.warning("No valid CSV files found to combine")
        return None

    combined_df = pd.concat(frames, ignore_index=True).sort_values("dataset")

    # Report and drop repeated datasets, keeping the first occurrence.
    if len(combined_df) != len(set(combined_df.dataset)):
        repeated = combined_df.dataset.duplicated()
        duplicate_datasets = combined_df.dataset[repeated].tolist()
        logger.warning(f"Warning: Duplicate datasets found: {duplicate_datasets}")
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info(
            f"Removed duplicates, {len(combined_df)} unique datasets remaining"
        )

    logger.info(f"Combined results: {len(combined_df)} datasets")

    # Compare what was produced with what was expected.
    all_datasets_full_name = get_all_datasets_full_name()
    expected = set(all_datasets_full_name)
    completed_experiments_clean = [
        exp for exp in combined_df.dataset.tolist() if exp in expected
    ]
    done = set(completed_experiments_clean)
    missing_or_failed_experiments = [
        exp for exp in all_datasets_full_name if exp not in done
    ]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info(f"Total expected datasets: {len(all_datasets_full_name)}")
    logger.info(f"Completed experiments: {len(completed_experiments_clean)}")
    logger.info(f"Missing/failed experiments: {len(missing_or_failed_experiments)}")

    logger.info("Completed experiments:")
    for rank, exp in enumerate(completed_experiments_clean, start=1):
        logger.info(f" {rank:3d}: {exp}")

    if missing_or_failed_experiments:
        logger.info("Missing or failed experiments:")
        for rank, exp in enumerate(missing_or_failed_experiments, start=1):
            logger.info(f" {rank:3d}: {exp}")

    completion_rate = (
        len(completed_experiments_clean) / len(all_datasets_full_name) * 100
    )
    logger.info(f"Completion rate: {completion_rate:.1f}%")

    # Persist the merged table next to the per-dataset results.
    output_file = result_root_dir / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info(f"Combined results saved to: {output_file}")

    return combined_df
137
+
138
+
139
if __name__ == "__main__":
    # Command-line entry point: aggregate every results.csv below a directory.
    cli = argparse.ArgumentParser(
        description="Aggregate GIFT-Eval results from multiple CSV files"
    )
    cli.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )
    args = cli.parse_args()
    args.result_root_dir = Path(args.result_root_dir)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
    )
    logger.info(f"Searching in directory: {args.result_root_dir}")

    aggregate_results(result_root_dir=args.result_root_dir)
src/gift_eval/constants.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ from gluonts.ev.metrics import (
6
+ MAE,
7
+ MAPE,
8
+ MASE,
9
+ MSE,
10
+ MSIS,
11
+ ND,
12
+ NRMSE,
13
+ RMSE,
14
+ SMAPE,
15
+ MeanWeightedSumQuantileLoss,
16
+ )
17
+
18
logger = logging.getLogger(__name__)

# Environment setup
# NOTE(review): presumably required for deterministic cuBLAS behaviour on
# CUDA; confirm against the evaluation setup.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


# Resolve the JSON file next to this module instead of relative to the current
# working directory, so the properties load no matter where the process starts.
DATASET_PROPERTIES_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "dataset_properties.json"
)


try:
    with open(DATASET_PROPERTIES_PATH, "r", encoding="utf-8") as f:
        DATASET_PROPERTIES = json.load(f)
except (OSError, ValueError) as e:
    # Missing/unreadable file or invalid JSON: continue with no properties.
    DATASET_PROPERTIES = {}
    logger.warning(
        f"Could not load dataset properties from {DATASET_PROPERTIES_PATH}: {e}. Domain and num_variates will fall back to defaults."
    )
35
+
36
+
37
# Datasets
# Space-separated dataset names; a "/<freq>" suffix selects a specific frequency.
SHORT_DATASETS = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
MED_LONG_DATASETS = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
# Sorted so the order is reproducible; iterating a raw set of strings varies
# between interpreter runs because of string hash randomization.
ALL_DATASETS = sorted(set(SHORT_DATASETS.split() + MED_LONG_DATASETS.split()))
41
+
42
# Evaluation terms
TERMS = ["short", "medium", "long"]

# Pretty names mapping (following GIFT eval standard)
PRETTY_NAMES = dict(
    saugeenday="saugeen",
    temperature_rain_with_missing="temperature_rain",
    kdd_cup_2018_with_missing="kdd_cup_2018",
    car_parts_with_missing="car_parts",
)


# GluonTS metric objects to evaluate.
METRICS = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(
        quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    ),
]


# Metric column names, listed in the same order as METRICS.
STANDARD_METRIC_NAMES = [
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
]
src/gift_eval/data.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from collections.abc import Iterable, Iterator
18
+ from enum import Enum
19
+ from functools import cached_property
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+ import datasets
24
+ import pyarrow.compute as pc
25
+ from gluonts.dataset import DataEntry
26
+ from gluonts.dataset.common import ProcessDataEntry
27
+ from gluonts.dataset.split import TestData, TrainingDataset, split
28
+ from gluonts.itertools import Map
29
+ from gluonts.time_feature import norm_freq_str
30
+ from gluonts.transform import Transformation
31
+ from pandas.tseries.frequencies import to_offset
32
+ from toolz import compose
33
+
34
# Fraction of the shortest series that is reserved for evaluation windows.
TEST_SPLIT = 0.1
# Upper bound on the number of evaluation windows when no explicit limit is given.
MAX_WINDOW = 20

# Base prediction length per normalised frequency string for "m4" datasets.
M4_PRED_LENGTH_MAP = {
    "A": 6, "Q": 8, "M": 18, "W": 13,
    "D": 14, "H": 48, "h": 48, "Y": 6,
}

# Base prediction length per normalised frequency string for all other datasets.
PRED_LENGTH_MAP = {
    "M": 12, "W": 8, "D": 30,
    "H": 48, "h": 48, "T": 48,
    "S": 60, "s": 60, "min": 48,
}

# NOTE(review): not referenced in the visible part of this module; presumably
# kept for TFB-style benchmarks -- confirm before removing.
TFB_PRED_LENGTH_MAP = {
    "A": 6, "Y": 6, "H": 48, "h": 48,
    "Q": 8, "D": 14, "M": 18, "W": 13,
    "U": 8, "T": 8, "min": 8, "us": 8,
}
74
+
75
+
76
class Term(Enum):
    """Forecast horizon category of an evaluation run."""

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"

    @property
    def multiplier(self) -> int:
        """Factor applied to the base prediction length for this term."""
        # Built inside the property: a dict assigned in the class body would be
        # turned into an enum member.
        factors = {"short": 1, "medium": 10, "long": 15}
        return factors[self.value]
89
+
90
+
91
def itemize_start(data_entry: DataEntry) -> DataEntry:
    """Replace ``data_entry["start"]`` with the result of its ``.item()`` call.

    The entry is modified in place and returned.
    """
    # NOTE(review): presumably unwraps a NumPy scalar into a plain Python
    # object because the dataset is loaded in "numpy" format -- confirm.
    start_value = data_entry["start"]
    data_entry["start"] = start_value.item()
    return data_entry
94
+
95
+
96
class MultivariateToUnivariate(Transformation):
    """Split every entry into one entry per element of ``entry[field]``."""

    def __init__(self, field):
        # Key of the data entry whose first axis is iterated over.
        self.field = field

    def __call__(
        self, data_it: Iterable[DataEntry], is_train: bool = False
    ) -> Iterator:
        """Yield one shallow copy of each entry per element of ``self.field``.

        The copy's ``item_id`` gets the suffix ``_dim<index>``.
        ``is_train`` is accepted but not used.
        """
        for data_entry in data_it:
            base_id = data_entry["item_id"]
            components = list(data_entry[self.field])
            for dim_index, component in enumerate(components):
                # Shallow copy: all other fields are shared with the source entry.
                univariate_entry = data_entry.copy()
                univariate_entry[self.field] = component
                univariate_entry["item_id"] = base_id + "_dim" + str(dim_index)
                yield univariate_entry
111
+
112
+
113
class Dataset:
    """A dataset loaded from disk and exposed as GluonTS train/validation/test splits.

    The Hugging Face dataset found at ``storage_path / name`` is loaded in
    "numpy" format and wrapped so that every row is converted to a GluonTS
    data entry. Prediction length and number of evaluation windows are derived
    from the dataset frequency, the requested term and the series lengths.
    """

    def __init__(
        self,
        name: str,
        term: Term | str = Term.SHORT,
        to_univariate: bool = False,
        storage_path: Optional[str] = None,
        max_windows: Optional[int] = None,
    ):
        """Load the dataset and prepare the GluonTS view.

        Args:
            name: Dataset directory below ``storage_path``.
            term: Forecast horizon, as a ``Term`` or its string value.
            to_univariate: Split multivariate targets into one series per
                dimension.
            storage_path: Directory that contains the datasets. Required,
                although it has a ``None`` default for signature reasons.
            max_windows: Upper bound on evaluation windows; ``MAX_WINDOW`` is
                used when omitted.

        Raises:
            TypeError: If ``storage_path`` is ``None``.
        """
        if storage_path is None:
            # Path(None) raises the same exception type with an opaque message;
            # fail early and say what is missing.
            raise TypeError(
                "storage_path is required: pass the directory that contains the datasets"
            )
        storage_path = Path(storage_path)
        self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(
            "numpy"
        )
        process = ProcessDataEntry(
            self.freq,
            one_dim_target=self.target_dim == 1,
        )

        self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
        if to_univariate:
            self.gluonts_dataset = MultivariateToUnivariate("target").apply(
                self.gluonts_dataset
            )

        self.term = Term(term)
        self.name = name
        self.max_windows = max_windows if max_windows is not None else MAX_WINDOW

    @cached_property
    def prediction_length(self) -> int:
        """Base prediction length for the frequency times the term multiplier.

        Raises:
            KeyError: If no base length is defined for the dataset frequency.
        """
        freq = norm_freq_str(to_offset(self.freq).name)
        # Drop a trailing "E" (e.g. "ME" -> "M") so the maps can be keyed on the
        # short frequency names.
        if freq.endswith("E"):
            freq = freq[:-1]
        length_map = M4_PRED_LENGTH_MAP if "m4" in self.name else PRED_LENGTH_MAP
        try:
            base_length = length_map[freq]
        except KeyError:
            raise KeyError(
                f"No prediction length defined for frequency '{freq}' "
                f"(dataset '{self.name}')"
            ) from None
        return self.term.multiplier * base_length

    @cached_property
    def freq(self) -> str:
        """Frequency string stored in the first row of the dataset."""
        return self.hf_dataset[0]["freq"]

    @cached_property
    def target_dim(self) -> int:
        """Number of target dimensions, read from the first row."""
        target = self.hf_dataset[0]["target"]
        return target.shape[0] if len(target.shape) > 1 else 1

    @cached_property
    def past_feat_dynamic_real_dim(self) -> int:
        """Number of ``past_feat_dynamic_real`` dimensions (0 when absent)."""
        first_row = self.hf_dataset[0]
        if "past_feat_dynamic_real" not in first_row:
            return 0
        past_feat_dynamic_real = first_row["past_feat_dynamic_real"]
        if len(past_feat_dynamic_real.shape) > 1:
            return past_feat_dynamic_real.shape[0]
        return 1

    @cached_property
    def windows(self) -> int:
        """Number of evaluation windows.

        "m4" datasets always use one window. Otherwise ``TEST_SPLIT`` of the
        shortest series is divided by the prediction length, rounded up and
        clamped to ``[1, max_windows]``.
        """
        if "m4" in self.name:
            return 1
        w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)
        return min(max(1, w), self.max_windows)

    @cached_property
    def _min_series_length(self) -> int:
        """Length of the shortest target series in the dataset."""
        target_column = self.hf_dataset.data.column("target")
        if self.hf_dataset[0]["target"].ndim > 1:
            # Multivariate: measure only the first dimension of every row.
            lengths = pc.list_value_length(
                pc.list_flatten(pc.list_slice(target_column, 0, 1))
            )
        else:
            lengths = pc.list_value_length(target_column)
        # Reduce inside NumPy and hand back a plain int, as annotated.
        return int(lengths.to_numpy().min())

    @cached_property
    def sum_series_length(self) -> int:
        """Total number of target observations over all rows (and dimensions)."""
        target_column = self.hf_dataset.data.column("target")
        if self.hf_dataset[0]["target"].ndim > 1:
            lengths = pc.list_value_length(pc.list_flatten(target_column))
        else:
            lengths = pc.list_value_length(target_column)
        # Reduce inside NumPy and hand back a plain int, as annotated.
        return int(lengths.to_numpy().sum())

    @property
    def training_dataset(self) -> TrainingDataset:
        """Data before the last ``windows + 1`` prediction-length blocks."""
        training_dataset, _ = split(
            self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)
        )
        return training_dataset

    @property
    def validation_dataset(self) -> TrainingDataset:
        """Data before the last ``windows`` prediction-length blocks."""
        validation_dataset, _ = split(
            self.gluonts_dataset, offset=-self.prediction_length * self.windows
        )
        return validation_dataset

    @property
    def test_data(self) -> TestData:
        """Test instances: ``windows`` windows spaced one prediction length apart."""
        _, test_template = split(
            self.gluonts_dataset, offset=-self.prediction_length * self.windows
        )
        test_data = test_template.generate_instances(
            prediction_length=self.prediction_length,
            windows=self.windows,
            distance=self.prediction_length,
        )
        return test_data
src/gift_eval/dataset_properties.json ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "m4_yearly": {
3
+ "domain": "Econ/Fin",
4
+ "frequency": "A",
5
+ "num_variates": 1
6
+ },
7
+ "m4_quarterly": {
8
+ "domain": "Econ/Fin",
9
+ "frequency": "Q",
10
+ "num_variates": 1
11
+ },
12
+ "m4_monthly": {
13
+ "domain": "Econ/Fin",
14
+ "frequency": "M",
15
+ "num_variates": 1
16
+ },
17
+ "m4_weekly": {
18
+ "domain": "Econ/Fin",
19
+ "frequency": "W",
20
+ "num_variates": 1
21
+ },
22
+ "m4_daily": {
23
+ "domain": "Econ/Fin",
24
+ "frequency": "D",
25
+ "num_variates": 1
26
+ },
27
+ "m4_hourly": {
28
+ "domain": "Econ/Fin",
29
+ "frequency": "H",
30
+ "num_variates": 1
31
+ },
32
+ "electricity": {
33
+ "domain": "Energy",
34
+ "frequency": "W",
35
+ "num_variates": 1
36
+ },
37
+ "ett1": {
38
+ "domain": "Energy",
39
+ "frequency": "W",
40
+ "num_variates": 7
41
+ },
42
+ "ett2": {
43
+ "domain": "Energy",
44
+ "frequency": "W",
45
+ "num_variates": 7
46
+ },
47
+ "solar": {
48
+ "domain": "Energy",
49
+ "frequency": "W",
50
+ "num_variates": 1
51
+ },
52
+ "hospital": {
53
+ "domain": "Healthcare",
54
+ "frequency": "M",
55
+ "num_variates": 1
56
+ },
57
+ "covid_deaths": {
58
+ "domain": "Healthcare",
59
+ "frequency": "D",
60
+ "num_variates": 1
61
+ },
62
+ "us_births": {
63
+ "domain": "Healthcare",
64
+ "frequency": "M",
65
+ "num_variates": 1
66
+ },
67
+ "saugeen": {
68
+ "domain": "Nature",
69
+ "frequency": "M",
70
+ "num_variates": 1
71
+ },
72
+ "temperature_rain": {
73
+ "domain": "Nature",
74
+ "frequency": "D",
75
+ "num_variates": 1
76
+ },
77
+ "kdd_cup_2018": {
78
+ "domain": "Nature",
79
+ "frequency": "D",
80
+ "num_variates": 1
81
+ },
82
+ "jena_weather": {
83
+ "domain": "Nature",
84
+ "frequency": "D",
85
+ "num_variates": 21
86
+ },
87
+ "car_parts": {
88
+ "domain": "Sales",
89
+ "frequency": "M",
90
+ "num_variates": 1
91
+ },
92
+ "restaurant": {
93
+ "domain": "Sales",
94
+ "frequency": "D",
95
+ "num_variates": 1
96
+ },
97
+ "hierarchical_sales": {
98
+ "domain": "Sales",
99
+ "frequency": "W-WED",
100
+ "num_variates": 1
101
+ },
102
+ "loop_seattle": {
103
+ "domain": "Transport",
104
+ "frequency": "D",
105
+ "num_variates": 1
106
+ },
107
+ "sz_taxi": {
108
+ "domain": "Transport",
109
+ "frequency": "H",
110
+ "num_variates": 1
111
+ },
112
+ "m_dense": {
113
+ "domain": "Transport",
114
+ "frequency": "D",
115
+ "num_variates": 1
116
+ },
117
+ "bitbrains_fast_storage": {
118
+ "domain": "Web/CloudOps",
119
+ "frequency": "H",
120
+ "num_variates": 2
121
+ },
122
+ "bitbrains_rnd": {
123
+ "domain": "Web/CloudOps",
124
+ "frequency": "H",
125
+ "num_variates": 2
126
+ },
127
+ "bizitobs_application": {
128
+ "domain": "Web/CloudOps",
129
+ "frequency": "10S",
130
+ "num_variates": 2
131
+ },
132
+ "bizitobs_service": {
133
+ "domain": "Web/CloudOps",
134
+ "frequency": "10S",
135
+ "num_variates": 2
136
+ },
137
+ "bizitobs_l2c": {
138
+ "domain": "Web/CloudOps",
139
+ "frequency": "H",
140
+ "num_variates": 7
141
+ },
142
+ "dd_benchmark_short": {
143
+ "domain": "Web/Observability",
144
+ "frequency": "Short",
145
+ "num_variates": 32
146
+ },
147
+ "dd_benchmark_long": {
148
+ "domain": "Web/Observability",
149
+ "frequency": "Long",
150
+ "num_variates": 32
151
+ }
152
+ }
src/gift_eval/evaluate.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import logging
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import matplotlib
10
+ import matplotlib.pyplot as plt
11
+ from gluonts.model.evaluation import evaluate_model
12
+ from gluonts.time_feature import get_seasonality
13
+ from linear_operator.utils.cholesky import NumericalWarning
14
+
15
+ from src.gift_eval.constants import (
16
+ ALL_DATASETS,
17
+ DATASET_PROPERTIES,
18
+ MED_LONG_DATASETS,
19
+ METRICS,
20
+ PRETTY_NAMES,
21
+ STANDARD_METRIC_NAMES,
22
+ )
23
+ from src.gift_eval.data import Dataset
24
+ from src.gift_eval.model_wrapper import TimeSeriesPredictor
25
+ from src.plotting.gift_eval_utils import create_plots_for_dataset
26
+
27
logger = logging.getLogger(__name__)

# Warnings configuration: silence noisy warning categories during evaluation.
for _ignored_category in (NumericalWarning, FutureWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Keep plotting-related loggers at WARNING so they do not flood the output.
matplotlib.set_loglevel("WARNING")
for _noisy_logger_name in ("matplotlib", "matplotlib.font_manager", "PIL"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
37
+
38
+
39
class WarningFilter(logging.Filter):
    """Logging filter that drops every record whose message contains a given text."""

    def __init__(self, text_to_filter: str) -> None:
        super().__init__()
        # Substring that marks a record as unwanted.
        self.text_to_filter = text_to_filter

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False (drop) when the formatted message contains the text."""
        message = record.getMessage()
        return message.find(self.text_to_filter) == -1
46
+
47
+
48
# Filter out gluonts warnings about mean predictions
_MEAN_NOT_STORED_TEXT = "The mean prediction is not stored in the forecast data"
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter(_MEAN_NOT_STORED_TEXT))
53
+
54
+
55
@dataclass
class DatasetMetadata:
    """Descriptive attributes of one dataset/term combination under evaluation."""

    # "<key>/<freq>/<term>" identifier written to the results file.
    full_name: str
    # Lower-cased, prettified dataset name.
    key: str
    # Frequency string reported for the dataset.
    freq: str
    # Evaluation term ("short", "medium" or "long").
    term: str
    # Seasonal period passed to the metric computation.
    season_length: int
    # Number of target dimensions of the stored dataset.
    target_dim: int
    # Whether the dataset was split into univariate series.
    to_univariate: bool
    # Forecast horizon in time steps.
    prediction_length: int
    # Number of evaluation windows.
    windows: int
66
+
67
+
68
@dataclass
class EvaluationItem:
    """Outcome of evaluating one dataset/term combination."""

    # Description of the evaluated dataset/term.
    dataset_metadata: DatasetMetadata
    # Metric values, looked up by metric name when results are written.
    metrics: Dict
    # (figure, file name) pairs to be saved when plotting is enabled.
    figures: List[Tuple[object, str]]
73
+
74
+
75
def construct_evaluation_data(
    dataset_name: str,
    dataset_storage_path: str,
    terms: Optional[List[str]] = None,
    max_windows: Optional[int] = None,
) -> List[Tuple[Dataset, DatasetMetadata]]:
    """Build datasets and rich metadata per term for a dataset name.

    Args:
        dataset_name: Dataset identifier, either "name" or "name/frequency".
        dataset_storage_path: Directory that contains the datasets.
        terms: Terms to build; defaults to short, medium and long.
        max_windows: Optional upper bound on evaluation windows.

    Returns:
        One ``(dataset, metadata)`` pair per term that applies to the dataset.
        Medium and long terms are skipped for datasets that are not listed in
        ``MED_LONG_DATASETS``.
    """
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets: List[Tuple[Dataset, DatasetMetadata]] = []

    if "/" in dataset_name:
        ds_key, ds_freq = dataset_name.split("/")
        ds_key = PRETTY_NAMES.get(ds_key.lower(), ds_key.lower())
    else:
        ds_key = dataset_name.lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
        ds_freq = DATASET_PROPERTIES[ds_key]["frequency"]

    # Split once instead of once per term.
    med_long_names = set(MED_LONG_DATASETS.split())

    for term in terms:
        # Skip medium/long terms for datasets that don't support them
        if term in ("medium", "long") and dataset_name not in med_long_names:
            continue

        # Probe once to determine dimensionality
        probe_dataset = Dataset(
            name=dataset_name,
            term=term,
            to_univariate=False,
            storage_path=dataset_storage_path,
            max_windows=max_windows,
        )

        to_univariate = probe_dataset.target_dim > 1

        if to_univariate:
            dataset = Dataset(
                name=dataset_name,
                term=term,
                to_univariate=True,
                storage_path=dataset_storage_path,
                max_windows=max_windows,
            )
        else:
            # The probe was built with exactly the arguments needed here, so
            # reuse it instead of loading the dataset from disk a second time.
            dataset = probe_dataset

        # Compute metadata
        season_length = get_seasonality(dataset.freq)
        actual_freq = ds_freq if ds_freq else dataset.freq
        metadata = DatasetMetadata(
            full_name=f"{ds_key}/{actual_freq}/{term}",
            key=ds_key,
            freq=actual_freq,
            term=term,
            season_length=season_length,
            target_dim=probe_dataset.target_dim,
            to_univariate=to_univariate,
            prediction_length=dataset.prediction_length,
            windows=dataset.windows,
        )

        sub_datasets.append((dataset, metadata))

    return sub_datasets
137
+
138
+
139
+ def _ensure_results_csv(csv_file_path: Path) -> None:
140
+ if not csv_file_path.exists():
141
+ csv_file_path.parent.mkdir(parents=True, exist_ok=True)
142
+ with open(csv_file_path, "w", newline="") as csvfile:
143
+ writer = csv.writer(csvfile)
144
+ header = (
145
+ ["dataset", "model"]
146
+ + [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
147
+ + [
148
+ "domain",
149
+ "num_variates",
150
+ ]
151
+ )
152
+ writer.writerow(header)
153
+
154
+
155
def write_results_to_disk(
    items: List[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append one CSV row per evaluation item and optionally save its figures.

    Results go to ``<output_dir>/<dataset_name>/results.csv``; figures go to
    ``<output_dir>/<dataset_name>/plots/<key>/<term>/``.
    """

    def _unwrap(raw):
        # Reduce a metric entry to a single value: single-element containers
        # are indexed, anything exposing ``item()`` is converted through it.
        if raw is None:
            return None
        holds_one_element = (
            hasattr(raw, "__len__")
            and not isinstance(raw, (str, bytes))
            and len(raw) == 1
        )
        if holds_one_element:
            return raw[0]
        if hasattr(raw, "item"):
            return raw.item()
        return raw

    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md = item.dataset_metadata

            # Metric values in the standard column order.
            metric_values: List[Optional[float]] = [
                _unwrap(item.metrics.get(metric_name, None))
                for metric_name in STANDARD_METRIC_NAMES
            ]

            # Domain and variate count come from the dataset properties, with
            # fallbacks for datasets that are not listed there.
            props = DATASET_PROPERTIES.get(md.key.lower(), {})
            fallback_variates = 1 if md.to_univariate else md.target_dim
            domain = props.get("domain", "unknown")
            num_variates = props.get("num_variates", fallback_variates)

            writer.writerow(
                [md.full_name, model_name] + metric_values + [domain, num_variates]
            )

            if not (create_plots and item.figures):
                continue
            plots_dir = output_dir / "plots" / md.key / md.term
            plots_dir.mkdir(parents=True, exist_ok=True)
            for fig, filename in item.figures:
                fig.savefig(plots_dir / filename, dpi=300, bbox_inches="tight")
                plt.close(fig)

    logger.info(
        f"Evaluation complete for dataset '{dataset_name}'. Results saved to {output_csv_path}"
    )
    if create_plots:
        logger.info(f"Plots saved under {output_dir / 'plots'}")
212
+
213
+
214
def evaluate_datasets(
    predictor: TimeSeriesPredictor,
    dataset: str,
    dataset_storage_path: str,
    terms: Optional[List[str]] = None,
    max_windows: Optional[int] = None,
    batch_size: int = 48,
    max_context_length: Optional[int] = 1024,
    create_plots: bool = False,
    max_plots_per_dataset: int = 10,
) -> List[EvaluationItem]:
    """Evaluate predictor on one dataset across the requested terms.

    Args:
        predictor: Reusable predictor; its dataset context is updated per term.
        dataset: Dataset identifier, either "name" or "name/frequency".
        dataset_storage_path: Directory that contains the datasets.
        terms: Terms to evaluate; defaults to short, medium and long.
        max_windows: Optional upper bound on evaluation windows.
        batch_size: Batch size passed to the predictor.
        max_context_length: Context limit passed to the predictor and plots.
        create_plots: Also run a forecast pass to create figures.
        max_plots_per_dataset: Maximum number of figures per term.

    Returns:
        One ``EvaluationItem`` per evaluated term.
    """
    # Resolve the default here rather than in the signature so no list object
    # is shared between calls.
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets = construct_evaluation_data(
        dataset_name=dataset,
        dataset_storage_path=dataset_storage_path,
        terms=terms,
        max_windows=max_windows,
    )

    results: List[EvaluationItem] = []
    for i, (sub_dataset, metadata) in enumerate(sub_datasets):
        # ``test_data`` is rebuilt on every property access; build it once.
        test_data = sub_dataset.test_data

        logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
        logger.info(f"  Dataset size: {len(test_data)}")
        logger.info(f"  Frequency: {sub_dataset.freq}")
        logger.info(f"  Term: {metadata.term}")
        logger.info(f"  Prediction length: {sub_dataset.prediction_length}")
        logger.info(f"  Target dimensions: {sub_dataset.target_dim}")
        logger.info(f"  Windows: {sub_dataset.windows}")

        # Update context on the reusable predictor
        predictor.set_dataset_context(
            prediction_length=sub_dataset.prediction_length,
            freq=sub_dataset.freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
        )

        res = evaluate_model(
            model=predictor,
            test_data=test_data,
            metrics=METRICS,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=metadata.season_length,
        )

        figs: List[Tuple[object, str]] = []
        if create_plots:
            # Plotting needs the forecasts themselves, so predict once more.
            forecasts = predictor.predict(test_data.input)
            figs = create_plots_for_dataset(
                forecasts=forecasts,
                test_data=test_data,
                dataset_metadata=metadata,
                max_plots=max_plots_per_dataset,
                max_context_length=max_context_length,
            )

        results.append(
            EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)
        )

    return results
277
+
278
+
279
+ def _expand_datasets_arg(datasets: List[str] | str) -> List[str]:
280
+ if datasets[0] == "all":
281
+ return list(ALL_DATASETS)
282
+ if isinstance(datasets, str):
283
+ datasets = [datasets]
284
+ for dataset in datasets:
285
+ if dataset not in ALL_DATASETS:
286
+ raise ValueError(f"Invalid dataset: {dataset}. Use one of {ALL_DATASETS}")
287
+ return datasets
288
+
289
+
290
def _run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets: List[str] | str,
    terms: List[str],
    dataset_storage_path: str,
    max_windows: Optional[int] = None,
    batch_size: int = 48,
    max_context_length: Optional[int] = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Shared evaluation workflow used by both entry points.

    Each requested dataset is evaluated and its results are written to disk
    before the next dataset starts.
    """
    results_root = Path(output_dir)

    # Settings that are the same for every dataset.
    evaluation_options = dict(
        predictor=predictor,
        dataset_storage_path=dataset_storage_path,
        terms=terms,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        create_plots=create_plots,
        max_plots_per_dataset=max_plots,
    )

    for ds_name in _expand_datasets_arg(datasets):
        evaluated_items = evaluate_datasets(dataset=ds_name, **evaluation_options)
        write_results_to_disk(
            items=evaluated_items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name,
            create_plots=create_plots,
        )
326
+
327
+
328
+ def evaluate_from_paths(
329
+ model_path: str,
330
+ config_path: str,
331
+ datasets: List[str] | str,
332
+ terms: List[str],
333
+ dataset_storage_path: str,
334
+ max_windows: Optional[int] = None,
335
+ batch_size: int = 48,
336
+ max_context_length: Optional[int] = 1024,
337
+ output_dir: str = "gift_eval_results",
338
+ model_name: str = "TimeSeriesModel",
339
+ create_plots: bool = False,
340
+ max_plots: int = 10,
341
+ ) -> None:
342
+ """Entry point: load model from disk and save metrics/plots to disk."""
343
+ # Validate inputs early
344
+ if not Path(model_path).exists():
345
+ raise FileNotFoundError(f"Model path does not exist: {model_path}")
346
+ if not Path(config_path).exists():
347
+ raise FileNotFoundError(f"Config path does not exist: {config_path}")
348
+
349
+ predictor = TimeSeriesPredictor.from_paths(
350
+ model_path=model_path,
351
+ config_path=config_path,
352
+ ds_prediction_length=1, # placeholder; set per dataset below
353
+ ds_freq="D", # placeholder; set per dataset below
354
+ batch_size=batch_size,
355
+ max_context_length=max_context_length,
356
+ )
357
+
358
+ _run_evaluation(
359
+ predictor=predictor,
360
+ datasets=datasets,
361
+ terms=terms,
362
+ dataset_storage_path=dataset_storage_path,
363
+ max_windows=max_windows,
364
+ batch_size=batch_size,
365
+ max_context_length=max_context_length,
366
+ output_dir=output_dir,
367
+ model_name=model_name,
368
+ create_plots=create_plots,
369
+ max_plots=max_plots,
370
+ )
371
+
372
+
373
def evaluate_in_memory(
    model,
    config: dict,
    datasets: List[str] | str,
    terms: List[str],
    dataset_storage_path: str,
    max_windows: Optional[int] = None,
    batch_size: int = 48,
    max_context_length: Optional[int] = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Entry point: evaluate an in-memory model and save metrics/plots to disk.

    Nothing is returned; results are written by the shared workflow.
    """
    # Prediction length and frequency are placeholders; the shared workflow
    # sets the real values per dataset.
    predictor = TimeSeriesPredictor.from_model(
        model=model,
        config=config,
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=batch_size,
        max_context_length=max_context_length,
    )

    run_options = dict(
        datasets=datasets,
        terms=terms,
        dataset_storage_path=dataset_storage_path,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        output_dir=output_dir,
        model_name=model_name,
        create_plots=create_plots,
        max_plots=max_plots,
    )
    _run_evaluation(predictor=predictor, **run_options)
410
+
411
+
412
+ def _parse_args() -> argparse.Namespace:
413
+ parser = argparse.ArgumentParser(
414
+ description="Evaluate TimeSeriesModel on GIFT-Eval datasets"
415
+ )
416
+
417
+ # Model configuration
418
+ parser.add_argument(
419
+ "--model_path",
420
+ type=str,
421
+ required=True,
422
+ help="Path to the trained model checkpoint",
423
+ )
424
+ parser.add_argument(
425
+ "--config_path",
426
+ type=str,
427
+ required=True,
428
+ help="Path to the model configuration YAML file",
429
+ )
430
+ parser.add_argument(
431
+ "--model_name",
432
+ type=str,
433
+ default="TimeSeriesModel",
434
+ help="Name identifier for the model",
435
+ )
436
+
437
+ # Dataset configuration
438
+ parser.add_argument(
439
+ "--datasets",
440
+ type=str,
441
+ default="all",
442
+ help="Comma-separated list of dataset names to evaluate (or 'all')",
443
+ )
444
+ parser.add_argument(
445
+ "--dataset_storage_path",
446
+ type=str,
447
+ required=True,
448
+ help="Path to the dataset storage directory",
449
+ )
450
+ parser.add_argument(
451
+ "--terms",
452
+ type=str,
453
+ default="short,medium,long",
454
+ help="Comma-separated list of prediction terms to evaluate",
455
+ )
456
+ parser.add_argument(
457
+ "--max_windows",
458
+ type=int,
459
+ default=None,
460
+ help="Maximum number of windows to use for evaluation",
461
+ )
462
+
463
+ # Inference configuration
464
+ parser.add_argument(
465
+ "--batch_size", type=int, default=48, help="Batch size for model inference"
466
+ )
467
+ parser.add_argument(
468
+ "--max_context_length",
469
+ type=int,
470
+ default=1024,
471
+ help="Maximum context length to use (None for no limit)",
472
+ )
473
+
474
+ # Output configuration
475
+ parser.add_argument(
476
+ "--output_dir",
477
+ type=str,
478
+ default="gift_eval_results",
479
+ help="Directory to save evaluation results",
480
+ )
481
+
482
+ # Plotting configuration
483
+ parser.add_argument(
484
+ "--create_plots",
485
+ action="store_true",
486
+ help="Create and save plots for each evaluation window",
487
+ )
488
+ parser.add_argument(
489
+ "--max_plots_per_dataset",
490
+ type=int,
491
+ default=10,
492
+ help="Maximum number of plots to create per dataset term",
493
+ )
494
+
495
+ args = parser.parse_args()
496
+ args.terms = args.terms.split(",")
497
+ args.datasets = args.datasets.split(",")
498
+ return args
499
+
500
+
501
+ def _configure_logging() -> None:
502
+ logging.basicConfig(
503
+ level=logging.INFO,
504
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
505
+ )
506
+
507
+
508
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the evaluation, and log
    # any failure before letting it propagate.
    _configure_logging()
    args = _parse_args()
    logger.info(f"Command Line Arguments: {vars(args)}")

    cli_options = dict(
        model_path=args.model_path,
        config_path=args.config_path,
        datasets=args.datasets,
        terms=args.terms,
        dataset_storage_path=args.dataset_storage_path,
        max_windows=args.max_windows,
        batch_size=args.batch_size,
        max_context_length=args.max_context_length,
        output_dir=args.output_dir,
        model_name=args.model_name,
        create_plots=args.create_plots,
        max_plots=args.max_plots_per_dataset,
    )
    try:
        evaluate_from_paths(**cli_options)
    except Exception as e:
        logger.error(f"Evaluation failed: {str(e)}")
        raise
src/gift_eval/model_wrapper.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Iterator, List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ import yaml
7
+ from gluonts.itertools import batcher
8
+ from gluonts.model.forecast import QuantileForecast
9
+ from gluonts.model.predictor import Predictor
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+
12
+ from src.data.containers import BatchTimeSeriesContainer
13
+ from src.data.frequency import parse_frequency
14
+ from src.data.scalers import RobustScaler
15
+ from src.models.model import TimeSeriesModel
16
+ from src.utils.utils import device
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class TimeSeriesPredictor(Predictor):
22
+ """
23
+ Unified predictor for TimeSeriesModel supporting both in-memory and file-based construction.
24
+
25
+ Use classmethods `from_model` and `from_paths` to construct instances.
26
+ Provides `set_dataset_context` to adjust dataset-specific parameters without reloading the model.
27
+ """
28
+
29
def __init__(
    self,
    model: TimeSeriesModel,
    config: dict,
    ds_prediction_length: int,
    ds_freq: str,
    batch_size: int = 32,
    max_context_length: Optional[int] = None,
    debug: bool = False,
) -> None:
    """Store the model, its config and the dataset-specific settings.

    Raises:
        ValueError: If the configured scaler type is not "custom_robust".
    """
    # Dataset-specific context (can be updated per dataset/term)
    self.ds_prediction_length = ds_prediction_length
    self.ds_freq = ds_freq
    self.batch_size = batch_size
    self.max_context_length = max_context_length
    self.debug = debug

    # Persistent model/config (unwrap DDP if needed)
    unwrapped_model = model.module if isinstance(model, DDP) else model
    self.model = unwrapped_model
    self.model.eval()
    self.config = config

    # Initialize scaler (using same type as model)
    model_section = self.config.get("TimeSeriesModel", {})
    scaler_type = model_section.get("scaler", "custom_robust")
    epsilon = self.config.get("TimeSeriesModel", {}).get("epsilon", 1e-3)
    if scaler_type != "custom_robust":
        raise ValueError(f"Unsupported scaler type: {scaler_type}")
    self.scaler = RobustScaler(epsilon=epsilon)
60
+
61
def set_dataset_context(
    self,
    prediction_length: Optional[int] = None,
    freq: Optional[str] = None,
    batch_size: Optional[int] = None,
    max_context_length: Optional[int] = None,
) -> None:
    """Update lightweight dataset-specific attributes without reloading the model.

    Only arguments that are not ``None`` are applied; a ``None`` argument
    leaves the corresponding attribute untouched.
    """
    requested = (
        ("ds_prediction_length", prediction_length),
        ("ds_freq", freq),
        ("batch_size", batch_size),
        ("max_context_length", max_context_length),
    )
    for attribute, new_value in requested:
        if new_value is None:
            continue
        setattr(self, attribute, new_value)
77
+
78
@classmethod
def from_model(
    cls,
    model: TimeSeriesModel,
    config: dict,
    ds_prediction_length: int,
    ds_freq: str,
    batch_size: int = 32,
    max_context_length: Optional[int] = None,
    debug: bool = False,
) -> "TimeSeriesPredictor":
    """Build a predictor around a model object that is already in memory."""
    constructor_args = dict(
        model=model,
        config=config,
        ds_prediction_length=ds_prediction_length,
        ds_freq=ds_freq,
        batch_size=batch_size,
        max_context_length=max_context_length,
        debug=debug,
    )
    return cls(**constructor_args)
98
+
99
@classmethod
def from_paths(
    cls,
    model_path: str,
    config_path: str,
    ds_prediction_length: int,
    ds_freq: str,
    batch_size: int = 32,
    max_context_length: Optional[int] = None,
    debug: bool = False,
) -> "TimeSeriesPredictor":
    """Alternate constructor that reads the config and checkpoint from disk.

    The YAML file at ``config_path`` is parsed with ``yaml.safe_load``; the
    resulting mapping is used both to build the model (via
    ``_load_model_from_path``) and as the predictor's ``config``.

    Args:
        model_path: Path of the checkpoint handed to ``_load_model_from_path``.
        config_path: Path of the YAML configuration file.
        ds_prediction_length, ds_freq, batch_size, max_context_length, debug:
            Forwarded unchanged to the class constructor.

    Returns:
        A new instance of ``cls`` wrapping the loaded model.
    """
    with open(config_path, "r") as config_file:
        loaded_config = yaml.safe_load(config_file)

    loaded_model = cls._load_model_from_path(
        config=loaded_config, model_path=model_path
    )

    init_kwargs = {
        "model": loaded_model,
        "config": loaded_config,
        "ds_prediction_length": ds_prediction_length,
        "ds_freq": ds_freq,
        "batch_size": batch_size,
        "max_context_length": max_context_length,
        "debug": debug,
    }
    return cls(**init_kwargs)
122
+
123
@staticmethod
def _load_model_from_path(config: dict, model_path: str) -> TimeSeriesModel:
    """Instantiate a ``TimeSeriesModel`` and restore its weights from a checkpoint.

    The model is built from ``config["TimeSeriesModel"]``, moved to the
    module-level ``device``, loaded from the checkpoint's
    ``"model_state_dict"`` entry and switched to eval mode.

    Args:
        config: Parsed configuration; must contain a ``"TimeSeriesModel"``
            mapping of constructor keyword arguments.
        model_path: Path of the checkpoint file.

    Returns:
        The restored model in eval mode.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    try:
        model_kwargs = config["TimeSeriesModel"]
        restored = TimeSeriesModel(**model_kwargs).to(device)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed here.
        state = torch.load(model_path, map_location=device)
        restored.load_state_dict(state["model_state_dict"])
        restored.eval()
        logger.info(f"Successfully loaded model from {model_path}")
        return restored
    except Exception as e:
        logger.error(f"Failed to load model from {model_path}: {str(e)}")
        raise
135
+
136
def predict(self, test_data_input) -> Iterator[QuantileForecast]:
    """
    Forecast every series in ``test_data_input``.

    Series are bucketed by their effective context length (the target length,
    capped at ``self.max_context_length`` when that is set) so each bucket can
    be stacked without padding. Buckets are processed in chunks of at most
    ``self.batch_size`` series.

    Args:
        test_data_input: Entries exposing ``entry["target"]`` as an array;
            any non-list iterable is materialised into a list first.

    Returns:
        A list of forecasts, one per input entry, in the input order.
    """
    if not isinstance(test_data_input, list) and hasattr(
        test_data_input, "__iter__"
    ):
        test_data_input = list(test_data_input)
    logger.debug(f"Processing {len(test_data_input)} time series")

    # Bucket (original position, entry) pairs by effective length.
    buckets = {}
    for position, entry in enumerate(test_data_input):
        series = entry["target"]
        # 1-D targets are [seq_len]; otherwise the layout is [num_channels, seq_len].
        n_steps = len(series) if series.ndim == 1 else series.shape[1]
        if self.max_context_length is not None:
            n_steps = min(n_steps, self.max_context_length)
        if n_steps not in buckets:
            buckets[n_steps] = []
        buckets[n_steps].append((position, entry))

    results = [None] * len(test_data_input)

    for bucket in buckets.values():
        for offset in range(0, len(bucket), self.batch_size):
            window = bucket[offset : offset + self.batch_size]
            window_forecasts = self._predict_batch([entry for _, entry in window])
            # Scatter the chunk's forecasts back to their original positions.
            for slot, (position, _) in enumerate(window):
                results[position] = window_forecasts[slot]

    # Every slot has been filled at this point.
    return results  # type: ignore[return-value]
183
+
184
def _predict_batch(self, test_data_batch: List) -> List[QuantileForecast]:
    """Run the model on one equal-length batch and return its forecasts.

    The forward pass runs under ``torch.no_grad()``; bfloat16 autocast is
    enabled only when the module-level ``device`` is a CUDA device.

    Args:
        test_data_batch: Entries that all share the same effective length.

    Returns:
        One forecast per entry, in batch order.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    logger.debug(f"Processing batch of size: {len(test_data_batch)}")

    try:
        container = self._convert_to_batch_container(test_data_batch)

        # `device` may be a torch.device or something string-like.
        backend = (
            device.type
            if isinstance(device, torch.device)
            else ("cuda" if "cuda" in str(device).lower() else "cpu")
        )
        use_autocast = backend == "cuda"

        with torch.autocast(
            device_type=backend, dtype=torch.bfloat16, enabled=use_autocast
        ), torch.no_grad():
            raw_output = self.model(container, drop_enc_allow=False)

        batch_forecasts = self._convert_to_forecasts(
            raw_output, test_data_batch, container
        )

        logger.debug(f"Generated {len(batch_forecasts)} forecasts")
        return batch_forecasts

    except Exception as e:
        logger.error(f"Error in batch prediction: {str(e)}")
        raise
216
+
217
def _convert_to_batch_container(
    self, test_data_batch: List
) -> BatchTimeSeriesContainer:
    """Pack a batch of entries into a ``BatchTimeSeriesContainer``.

    Each target is brought to ``[seq_len, num_channels]`` layout, trimmed to
    its last ``self.max_context_length`` steps when that limit is set, and the
    results are stacked without padding (so all entries must end up with the
    same shape). ``future_values`` is a zero placeholder of length
    ``self.ds_prediction_length``.

    Args:
        test_data_batch: Entries with ``"target"``, ``"start"`` and ``"freq"``.

    Returns:
        Container holding float32 history/future tensors on ``device`` plus
        per-series start dates and parsed frequencies.
    """
    context_cap = self.max_context_length
    histories, starts, freqs = [], [], []

    for entry in test_data_batch:
        series = entry["target"]
        # 1-D -> [seq_len, 1]; otherwise assumed [num_channels, seq_len] -> transpose.
        series = series.reshape(-1, 1) if series.ndim == 1 else series.T

        # Keep only the most recent steps when a context limit applies.
        if context_cap is not None and len(series) > context_cap:
            series = series[-context_cap:]

        histories.append(series)
        starts.append(entry["start"].to_timestamp().to_datetime64())
        freqs.append(parse_frequency(entry["freq"]))

    # Shapes are uniform by construction, so a plain stack suffices.
    stacked = np.stack(histories, axis=0)

    tensor_opts = {"dtype": torch.float32, "device": device}
    history_tensor = torch.tensor(stacked, **tensor_opts)
    # Placeholder only: there is no ground truth at prediction time.
    future_tensor = torch.zeros(
        (len(test_data_batch), self.ds_prediction_length, stacked.shape[2]),
        **tensor_opts,
    )

    return BatchTimeSeriesContainer(
        history_values=history_tensor,
        future_values=future_tensor,
        start=starts,
        frequency=freqs,
    )
272
+
273
def _convert_to_forecasts(
    self,
    model_output: dict,
    test_data_batch: List,
    batch_container: BatchTimeSeriesContainer,
) -> List[QuantileForecast]:
    """Convert raw model output into one ``QuantileForecast`` per entry.

    Predictions are mapped back to the original scale with
    ``self.scaler.inverse_scale`` and reshaped to the layout handed to
    ``QuantileForecast``:

    * quantile output ``[B, P, N, Q]`` -> ``[Q, P]`` (single channel) or
      ``[Q, P, N]`` (several channels), keyed by ``self.model.quantiles``;
    * point output ``[B, P, N]`` -> ``[1, P]`` or ``[1, P, N]``, keyed ``"0.5"``.

    Args:
        model_output: Dict with ``"result"`` and ``"scale_statistics"``.
        test_data_batch: The entries the batch was built from (same order).
        batch_container: Unused; kept so the call signature stays stable.

    Returns:
        Forecasts in batch order.
    """
    predictions = model_output["result"]
    scale_statistics = model_output["scale_statistics"]

    # The inverse transform is the same call for point and quantile output.
    predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
    is_quantile = predictions.ndim == 4
    if is_quantile:
        forecast_keys = [str(q) for q in self.model.quantiles]
    else:
        forecast_keys = ["0.5"]  # a point forecast is reported as the median

    # numpy has no bfloat16 dtype; upcast if autocast left the output in it.
    if predictions_unscaled.dtype == torch.bfloat16:
        predictions_unscaled = predictions_unscaled.float()
    # One device->host transfer for the whole batch instead of one per series.
    predictions_np = predictions_unscaled.detach().cpu().numpy()

    forecasts: List[QuantileForecast] = []
    for i, entry in enumerate(test_data_batch):
        # The forecast begins right after the last observed step of the FULL
        # series. The context may have been truncated to max_context_length
        # (keeping the most recent steps), so the truncated history length
        # must not be used here or the start date would be too early.
        # The last axis is time for both [seq_len] and [num_channels, seq_len].
        # NOTE(review): assumes entry["start"] supports integer addition in
        # units of the series frequency (e.g. a pandas Period) - confirm.
        full_length = int(entry["target"].shape[-1])
        forecast_start = entry["start"] + full_length

        pred_array = predictions_np[i]
        single_channel = pred_array.shape[1] == 1

        if is_quantile:
            if single_channel:
                # [P, 1, Q] -> [P, Q] -> [Q, P]
                forecast_arrays = pred_array.squeeze(1).T
            else:
                # [P, N, Q] -> [Q, P, N]
                forecast_arrays = pred_array.transpose(2, 0, 1)
        else:
            if single_channel:
                # [P, 1] -> [P] -> [1, P]
                forecast_arrays = pred_array.squeeze(1).reshape(1, -1)
            else:
                # [P, N] -> [1, P, N]
                forecast_arrays = pred_array.reshape(1, *pred_array.shape)

        forecasts.append(
            QuantileForecast(
                forecast_arrays=forecast_arrays,
                forecast_keys=list(forecast_keys),
                start_date=forecast_start,
            )
        )

    return forecasts
src/models/__init__.py ADDED
File without changes
src/models/blocks.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fla.models.gated_deltaproduct import GatedDeltaProductConfig
4
+ from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductBlock
5
+
6
+
7
class GatedDeltaProductEncoder(nn.Module):
    """
    Wrapper module holding a single ``GatedDeltaProductBlock``.

    The constructor translates this project's argument names into a
    ``GatedDeltaProductConfig`` and builds one block from it.
    """

    def __init__(
        self,
        layer_idx: int,
        token_embed_dim: int,
        num_heads: int = 4,
        attn_mode: str = "chunk",
        expand_v: float = 1.0,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        allow_neg_eigval: bool = True,
        use_forget_gate: bool = True,
        num_householder: int = 1,
        **kwargs,
    ):
        """
        Build the block configuration and the block itself.

        ``token_embed_dim`` becomes the block's ``hidden_size`` and the head
        dimension is ``token_embed_dim // num_heads``; ``hidden_ratio`` is
        fixed at 0.5. Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        block_settings = {
            "attn_mode": attn_mode,
            "hidden_size": token_embed_dim,
            "expand_v": expand_v,
            "use_gate": use_gate,
            "use_short_conv": use_short_conv,
            "conv_size": conv_size,
            "head_dim": token_embed_dim // num_heads,
            "hidden_ratio": 0.5,
            "num_heads": num_heads,
            "allow_neg_eigval": allow_neg_eigval,
            "use_forget_gate": use_forget_gate,
            "num_householder": num_householder,
        }
        block_config = GatedDeltaProductConfig(**block_settings)

        self.encoder_layer = GatedDeltaProductBlock(
            layer_idx=layer_idx, config=block_config
        )

    def forward(self, x, initial_state=None):
        """
        Run the wrapped block on ``x``.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
            initial_state: Optional state forwarded to the block unchanged.

        Returns:
            Tuple ``(output, last_hidden_state)``: the first two of the three
            values returned by the block (the third is discarded).
        """
        # NOTE(review): output_attentions=True is presumably what makes the
        # block return its state as the second value - confirm against fla.
        block_out = self.encoder_layer(
            x, output_attentions=True, initial_state=initial_state
        )
        hidden, last_hidden_state, _ = block_out
        return hidden, last_hidden_state
src/models/model.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from fla.modules import GatedMLP
4
+
5
+ from src.data.containers import BatchTimeSeriesContainer
6
+ from src.data.scalers import MinMaxScaler, RobustScaler
7
+ from src.data.time_features import compute_batch_time_features
8
+ from src.models.blocks import GatedDeltaProductEncoder
9
+ from src.utils.utils import device
10
+
11
+
12
def create_scaler(scaler_type: str, epsilon: float = 1e-3):
    """Return a new scaler for the given type name.

    Args:
        scaler_type: ``"custom_robust"`` (RobustScaler) or ``"min_max"``
            (MinMaxScaler).
        epsilon: Passed to the scaler constructor as ``epsilon``.

    Raises:
        ValueError: If ``scaler_type`` is neither of the supported names.
    """
    if scaler_type == "custom_robust":
        scaler_cls = RobustScaler
    elif scaler_type == "min_max":
        scaler_cls = MinMaxScaler
    else:
        raise ValueError(f"Unknown scaler: {scaler_type}")
    return scaler_cls(epsilon=epsilon)
20
+
21
+
22
def apply_channel_noise(values: torch.Tensor, noise_scale: float = 0.1):
    """Add Gaussian noise to channels that are constant along dim 1.

    A channel counts as constant when every step equals its first step.
    Channels that vary receive zero noise.

    Args:
        values: 3-D tensor; dim 1 is the axis checked for constancy.
        noise_scale: Multiplier applied to the standard-normal noise.

    Returns:
        A new tensor of the same shape as ``values``.
    """
    first_step = values[:, :1, :]
    # One flag per (dim-0 index, channel), True when the channel never changes.
    constant_channels = (values == first_step).all(dim=1)
    perturbation = (
        torch.randn_like(values) * noise_scale * constant_channels.unsqueeze(1)
    )
    return values + perturbation
27
+
28
+
29
class TimeSeriesModel(nn.Module):
    """Time series forecasting model combining embedding, encoding, and prediction.

    The forward pass scales the history with statistics computed from the
    history itself, embeds values and time features, appends one position per
    forecast step (carrying only the time-feature embedding), runs the stack
    of ``GatedDeltaProductEncoder`` layers over the concatenated sequence and
    projects the last ``future_length`` positions to the output.
    """

    def __init__(
        self,
        # Core architecture
        embed_size: int = 128,
        num_encoder_layers: int = 2,
        # Scaling and preprocessing
        scaler: str = "custom_robust",
        epsilon: float = 1e-3,
        scaler_clamp_value: float = None,
        handle_constants: bool = False,
        # Time features
        K_max: int = 6,
        time_feature_config: dict = None,
        encoding_dropout: float = 0.0,
        # Encoder configuration
        encoder_config: dict = None,
        # Loss configuration
        loss_type: str = "huber",  # "huber", "quantile"
        quantiles: list[float] = None,
        **kwargs,
    ):
        """Store hyper-parameters, validate them and build all sub-modules.

        Args:
            embed_size: Embedding width per channel token.
            num_encoder_layers: Number of encoder layers.
            scaler: Scaler type name passed to ``create_scaler``.
            epsilon: Epsilon passed to the scaler.
            scaler_clamp_value: If set, scaled history is clamped to
                ``[-value, value]``.
            handle_constants: If True, noise is added to constant history
                channels (see ``apply_channel_noise``).
            K_max: Input width of the time-feature projection.
            time_feature_config: Passed to ``compute_batch_time_features``.
            encoding_dropout: Probability of zeroing the positional embedding
                for a call (only when the forward pass allows it).
            encoder_config: Encoder keyword arguments; must contain
                ``"num_heads"``.
            loss_type: ``"huber"`` or ``"quantile"``.
            quantiles: Quantile levels; required when ``loss_type`` is
                ``"quantile"``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: On missing quantiles for quantile loss, or an invalid
                encoder configuration.
        """
        super().__init__()

        # Core parameters
        self.embed_size = embed_size
        self.num_encoder_layers = num_encoder_layers
        self.epsilon = epsilon
        self.scaler_clamp_value = scaler_clamp_value
        self.handle_constants = handle_constants
        self.encoding_dropout = encoding_dropout
        self.K_max = K_max
        self.time_feature_config = time_feature_config or {}
        self.encoder_config = encoder_config or {}

        # Store loss parameters
        self.loss_type = loss_type
        self.quantiles = quantiles
        if self.loss_type == "quantile" and self.quantiles is None:
            raise ValueError("Quantiles must be provided for quantile loss.")
        if self.quantiles:
            # Shaped [1, 1, 1, Q] so it broadcasts against [B, P, N, Q] in
            # the quantile loss.
            self.register_buffer(
                "qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1)
            )

        # Validate configuration before initialization
        self._validate_configuration()

        # Initialize components
        self.scaler = create_scaler(scaler, epsilon)
        self._init_embedding_layers()
        self._init_encoder_layers(self.encoder_config, num_encoder_layers)
        self._init_projection_layers()

    def _validate_configuration(self):
        """Check that ``encoder_config`` has ``num_heads`` and it divides ``embed_size``.

        Raises:
            ValueError: If either condition fails.
        """
        if "num_heads" not in self.encoder_config:
            raise ValueError("encoder_config must contain 'num_heads' parameter")

        if self.embed_size % self.encoder_config["num_heads"] != 0:
            raise ValueError(
                f"embed_size ({self.embed_size}) must be divisible by "
                f"num_heads ({self.encoder_config['num_heads']})"
            )

    def _init_embedding_layers(self):
        """Create the value embedding, the NaN embedding and the time-feature projection."""
        # Each scalar value is lifted to an embed_size vector.
        self.expand_values = nn.Linear(1, self.embed_size, bias=True)
        # Learnable vector substituted for the embedding of NaN inputs.
        self.nan_embedding = nn.Parameter(
            torch.randn(1, 1, 1, self.embed_size) / self.embed_size,
            requires_grad=True,
        )
        self.time_feature_projection = nn.Linear(self.K_max, self.embed_size)

    def _init_encoder_layers(self, encoder_config: dict, num_encoder_layers: int):
        """Build ``num_encoder_layers`` encoder layers from ``encoder_config``."""
        self.num_encoder_layers = num_encoder_layers

        # Work on a copy so the caller's dict is not mutated, and force the
        # encoder width to match the embedding width.
        encoder_config = encoder_config.copy()
        encoder_config["token_embed_dim"] = self.embed_size
        self.encoder_layers = nn.ModuleList(
            [
                GatedDeltaProductEncoder(layer_idx=layer_idx, **encoder_config)
                for layer_idx in range(self.num_encoder_layers)
            ]
        )

    def _init_projection_layers(self):
        """Create the output head and the per-layer learnable initial states."""
        # One output per quantile in quantile mode, a single value otherwise.
        if self.loss_type == "quantile":
            output_dim = len(self.quantiles)
        else:
            output_dim = 1
        self.final_output_layer = nn.Linear(self.embed_size, output_dim)

        self.mlp = GatedMLP(
            hidden_size=self.embed_size,
            hidden_ratio=4,
            hidden_act="swish",
            fuse_swiglu=True,
        )
        # Learnable initial hidden states, one per encoder layer. Each has a
        # leading dimension of 1 and is repeated to the effective batch size
        # during the forward pass.
        head_k_dim = self.embed_size // self.encoder_config["num_heads"]

        # Get expand_v from encoder_config, default to 1.0 if not present
        expand_v = self.encoder_config.get("expand_v", 1.0)
        head_v_dim = int(head_k_dim * expand_v)

        num_initial_hidden_states = self.num_encoder_layers
        self.initial_hidden_state = nn.ParameterList(
            [
                nn.Parameter(
                    torch.randn(
                        1, self.encoder_config["num_heads"], head_k_dim, head_v_dim
                    )
                    / head_k_dim,
                    requires_grad=True,
                )
                for _ in range(num_initial_hidden_states)
            ]
        )

    def _preprocess_data(self, data_container: BatchTimeSeriesContainer):
        """Unpack the container, read its shapes and optionally perturb constant channels.

        Returns:
            Dict with ``history_values``, ``future_values``, ``history_mask``,
            ``num_channels``, ``history_length``, ``future_length`` (0 when
            there are no future values) and ``batch_size``.
        """
        history_values = data_container.history_values
        future_values = data_container.future_values
        history_mask = data_container.history_mask

        batch_size, history_length, num_channels = history_values.shape
        future_length = future_values.shape[1] if future_values is not None else 0

        # Handle constants
        if self.handle_constants:
            history_values = apply_channel_noise(history_values)

        return {
            "history_values": history_values,
            "future_values": future_values,
            "history_mask": history_mask,
            "num_channels": num_channels,
            "history_length": history_length,
            "future_length": future_length,
            "batch_size": batch_size,
        }

    def _compute_scaling(
        self, history_values: torch.Tensor, history_mask: torch.Tensor = None
    ):
        """Compute scaling statistics from the history (the values are not scaled here)."""
        scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
        return scale_statistics

    def _apply_scaling_and_masking(
        self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None
    ):
        """Scale ``values``, then apply the optional mask and the optional clamp."""
        scaled_values = self.scaler.scale(values, scale_statistics)

        # mask gets a trailing axis so it broadcasts over the last dimension.
        if mask is not None:
            scaled_values = scaled_values * mask.unsqueeze(-1).float()

        if self.scaler_clamp_value is not None:
            scaled_values = torch.clamp(
                scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value
            )

        return scaled_values

    def _get_positional_embeddings(
        self,
        time_features: torch.Tensor,
        num_channels: int,
        batch_size: int,
        drop_enc_allow: bool = False,
    ):
        """Project time features to embeddings and expand them over channels.

        When ``drop_enc_allow`` is True, the whole embedding is replaced by
        zeros with probability ``self.encoding_dropout``. The random draw is
        made on every call, regardless of ``drop_enc_allow``.

        Returns:
            Tensor of shape [batch, seq_len, num_channels, embed_size].
        """
        seq_len = time_features.shape[1]

        if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
            return torch.zeros(
                batch_size, seq_len, num_channels, self.embed_size, device=device
            ).to(torch.float32)

        pos_embed = self.time_feature_projection(time_features)
        # Insert a channel axis and share the same embedding across channels.
        return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)

    def _compute_embeddings(
        self,
        scaled_history: torch.Tensor,
        history_pos_embed: torch.Tensor,
        history_mask: torch.Tensor | None = None,
    ):
        """Embed scaled history values and add the positional embeddings.

        NaN inputs are embedded as zero first and then overwritten with the
        learnable ``nan_embedding``.

        Returns:
            Tensor of shape [B, S, N*E] (channels flattened into the last axis).
        """

        nan_mask = torch.isnan(scaled_history)
        history_for_embedding = torch.nan_to_num(scaled_history, nan=0.0)
        channel_embeddings = self.expand_values(history_for_embedding.unsqueeze(-1))
        channel_embeddings[nan_mask] = self.nan_embedding.to(channel_embeddings.dtype)
        channel_embeddings = channel_embeddings + history_pos_embed

        # Suppress padded time steps completely so padding is a pure batching artifact
        # history_mask: [B, S] -> broadcast to [B, S, 1, 1]
        if history_mask is not None:
            mask_broadcast = (
                history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
            )
            channel_embeddings = channel_embeddings * mask_broadcast

        batch_size, seq_len = scaled_history.shape[:2]
        all_channels_embedded = channel_embeddings.view(batch_size, seq_len, -1)

        return all_channels_embedded

    def _generate_predictions(
        self,
        embedded: torch.Tensor,
        target_pos_embed: torch.Tensor,
        prediction_length: int,
        num_channels: int,
        history_mask: torch.Tensor = None,
    ):
        """
        Generate predictions for all channels using vectorized operations.

        Channels are folded into the batch dimension, the target positional
        embeddings are appended after the history along the sequence axis,
        and the encoder stack is run over the result.

        Args:
            embedded: History embeddings, [B, S, N*E].
            target_pos_embed: Positional embeddings of the horizon, [B, P, N, E].
            prediction_length: Number of forecast steps P.
            num_channels: Number of channels N.
            history_mask: Not used in this method.

        Returns:
            [B, P, N, Q] in quantile mode, otherwise [B, P, N].
        """
        batch_size, seq_len, _ = embedded.shape
        # embedded shape: [B, S, N*E] -> Reshape to [B, S, N, E]
        embedded = embedded.view(batch_size, seq_len, num_channels, self.embed_size)

        # Vectorize across channels by merging the batch and channel dimensions.
        # [B, S, N, E] -> [B*N, S, E]
        channel_embedded = (
            embedded.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size * num_channels, seq_len, self.embed_size)
        )

        # Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
        target_pos_embed = (
            target_pos_embed.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size * num_channels, prediction_length, self.embed_size)
        )
        x = channel_embedded
        target_repr = target_pos_embed
        x = torch.concatenate([x, target_repr], dim=1)
        if self.encoder_config.get("weaving", True):
            # Weaving: the state returned by each layer is carried into the
            # next one, with that layer's learnable initial state added on
            # top. The carried state starts as zeros.
            hidden_state = torch.zeros_like(
                self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1)
            )
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                x, hidden_state = encoder_layer(
                    x,
                    hidden_state
                    + self.initial_hidden_state[layer_idx].repeat(
                        batch_size * num_channels, 1, 1, 1
                    ),
                )
        else:
            # No weaving: each layer starts from its own learnable initial
            # state only; the state a layer returns is discarded.
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(
                    batch_size * num_channels, 1, 1, 1
                )
                x, _ = encoder_layer(x, initial_hidden_state)

        # Use the last prediction_length positions
        prediction_embeddings = x[:, -prediction_length:, :]

        predictions = self.final_output_layer(self.mlp(prediction_embeddings))

        # Reshape output to handle quantiles
        # Original shape: [B*N, P, Q] where Q is num_quantiles or 1
        # Reshape the output back to [B, P, N, Q]
        output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
        predictions = predictions.view(
            batch_size, num_channels, prediction_length, output_dim
        )
        predictions = predictions.permute(0, 2, 1, 3)  # [B, P, N, Q]
        # Squeeze the last dimension if not in quantile mode for backward compatibility
        if self.loss_type != "quantile":
            predictions = predictions.squeeze(-1)  # [B, P, N]
        return predictions

    def forward(
        self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False
    ):
        """Main forward pass.

        Args:
            data_container: Batch with history/future values, history mask,
                start dates and frequencies.
            drop_enc_allow: Whether positional-embedding dropout may be
                applied for this call.

        Returns:
            Dict with ``result`` (predictions in scaled space),
            ``scale_statistics``, ``future_scaled`` (scaled future values or
            None), ``history_length`` and ``future_length``.
        """
        # Preprocess data
        preprocessed = self._preprocess_data(data_container)

        # Compute time features dynamically based on actual lengths
        history_time_features, target_time_features = compute_batch_time_features(
            start=data_container.start,
            history_length=preprocessed["history_length"],
            future_length=preprocessed["future_length"],
            batch_size=preprocessed["batch_size"],
            frequency=data_container.frequency,
            K_max=self.K_max,
            time_feature_config=self.time_feature_config,
        )

        # Compute scaling
        scale_statistics = self._compute_scaling(
            preprocessed["history_values"], preprocessed["history_mask"]
        )

        # Apply scaling
        history_scaled = self._apply_scaling_and_masking(
            preprocessed["history_values"],
            scale_statistics,
            preprocessed["history_mask"],
        )

        # Scale future values if present
        future_scaled = None
        if preprocessed["future_values"] is not None:
            future_scaled = self.scaler.scale(
                preprocessed["future_values"], scale_statistics
            )

        # Get positional embeddings. The dropout decision is drawn separately
        # for the history and the target call.
        history_pos_embed = self._get_positional_embeddings(
            history_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )
        target_pos_embed = self._get_positional_embeddings(
            target_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )

        # Compute embeddings
        history_embed = self._compute_embeddings(
            history_scaled, history_pos_embed, preprocessed["history_mask"]
        )

        # Generate predictions
        predictions = self._generate_predictions(
            history_embed,
            target_pos_embed,
            preprocessed["future_length"],
            preprocessed["num_channels"],
            preprocessed["history_mask"],
        )

        return {
            "result": predictions,
            "scale_statistics": scale_statistics,
            "future_scaled": future_scaled,
            "history_length": preprocessed["history_length"],
            "future_length": preprocessed["future_length"],
        }

    def _quantile_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """
        Compute the quantile (pinball) loss, averaged over all elements.
        y_true: [B, P, N]
        y_pred: [B, P, N, Q]
        """
        # Add a dimension to y_true to match y_pred: [B, P, N] -> [B, P, N, 1]
        y_true = y_true.unsqueeze(-1)

        # Calculate errors
        errors = y_true - y_pred

        # Calculate quantile loss
        # The max operator implements the two cases of the quantile loss formula
        loss = torch.max((self.qt - 1) * errors, self.qt * errors)

        # Average the loss across all dimensions
        return loss.mean()

    def compute_loss(self, y_true: torch.Tensor, y_pred: dict):
        """Compute the loss between predictions and the scaled ground truth.

        ``y_true`` is scaled with the statistics stored in ``y_pred`` before
        comparison, so the loss is measured in scaled space.

        Args:
            y_true: Unscaled future values, or None.
            y_pred: Output dict of ``forward``.

        Returns:
            Scalar loss tensor; a zero tensor when ``y_true`` is None.

        Raises:
            ValueError: On a shape mismatch in Huber mode, or an unknown
                ``loss_type``.
        """
        predictions = y_pred["result"]
        scale_statistics = y_pred["scale_statistics"]

        if y_true is None:
            return torch.tensor(0.0, device=predictions.device)

        future_scaled = self.scaler.scale(y_true, scale_statistics)

        if self.loss_type == "huber":
            if predictions.shape != future_scaled.shape:
                raise ValueError(
                    f"Shape mismatch for Huber loss: predictions {predictions.shape} vs future_scaled {future_scaled.shape}"
                )
            return nn.functional.huber_loss(predictions, future_scaled)
        elif self.loss_type == "quantile":
            return self._quantile_loss(future_scaled, predictions)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
src/optim/lr_scheduler.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/optim/lr_scheduler.py
2
+
3
+ import math
4
+ from enum import Enum
5
+ from functools import partial
6
+ from typing import Optional
7
+
8
+ from torch.optim import Optimizer
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+
12
class SchedulerType(Enum):
    """Enumeration of available learning rate schedulers.

    Each member's value is the lowercase string identifier of the schedule,
    so a member can be looked up from a string with ``SchedulerType(name)``.
    """

    # NOTE(review): how these identifiers are mapped to the schedule factory
    # functions is not visible in this part of the file.
    COSINE = "cosine"
    COSINE_WITH_WARMUP = "cosine_with_warmup"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    WARMUP_STABLE_DECAY = "warmup_stable_decay"
    POLYNOMIAL_WITH_WARMUP = "polynomial_with_warmup"
    LINEAR_WITH_WARMUP = "linear_with_warmup"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
23
+
24
+
25
+ def _get_warmup_stable_decay_lr_lambda(
26
+ current_step: int,
27
+ *,
28
+ num_warmup_steps: int,
29
+ num_stable_steps: int,
30
+ num_training_steps: int,
31
+ min_lr_ratio: float = 0.001,
32
+ decay_type: str = "cosine",
33
+ ):
34
+ """
35
+ Learning rate lambda function for Warmup-Stable-Decay (WSD) schedule.
36
+
37
+ This scheduler implements three phases:
38
+ 1. Warmup: Linear increase from 0 to peak learning rate
39
+ 2. Stable: Constant learning rate for majority of training
40
+ 3. Decay: Gradual decrease using cosine or linear decay
41
+
42
+ Args:
43
+ current_step: Current training step
44
+ num_warmup_steps: Number of warmup steps
45
+ num_stable_steps: Number of stable learning rate steps
46
+ num_training_steps: Total number of training steps
47
+ min_lr_ratio: Minimum learning rate as ratio of peak learning rate
48
+ decay_type: Type of decay schedule ("cosine" or "linear")
49
+ """
50
+ if current_step < num_warmup_steps:
51
+ # Warmup phase: linear increase
52
+ return float(current_step) / float(max(1, num_warmup_steps))
53
+
54
+ elif current_step < num_warmup_steps + num_stable_steps:
55
+ # Stable phase: constant learning rate
56
+ return 1.0
57
+
58
+ else:
59
+ # Decay phase
60
+ decay_steps = num_training_steps - num_warmup_steps - num_stable_steps
61
+ if decay_steps <= 0:
62
+ return max(min_lr_ratio, 1.0)
63
+
64
+ progress = (current_step - num_warmup_steps - num_stable_steps) / decay_steps
65
+ progress = min(progress, 1.0) # Clamp to [0, 1]
66
+
67
+ if decay_type == "cosine":
68
+ # Cosine decay
69
+ decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
70
+ return max(min_lr_ratio, decay_factor)
71
+ elif decay_type == "linear":
72
+ # Linear decay
73
+ decay_factor = 1.0 - progress
74
+ return max(min_lr_ratio, decay_factor)
75
+ else:
76
+ raise ValueError(f"Unknown decay_type: {decay_type}")
77
+
78
+
79
def get_warmup_stable_decay_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.01,
    decay_type: str = "cosine",
    last_epoch: int = -1,
):
    """
    Build a ``LambdaLR`` scheduler that follows the Warmup-Stable-Decay schedule.

    The multiplier rises linearly during warmup, stays at 1.0 for the stable
    phase and then decays (cosine or linear) down to ``min_lr_ratio``.

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for warmup phase
        num_stable_steps: Number of steps for stable learning rate phase
        num_training_steps: Total number of training steps
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        decay_type: Type of decay ("cosine" or "linear")
        last_epoch: The index of the last epoch when resuming training

    Returns:
        torch.optim.lr_scheduler.LambdaLR with the WSD schedule
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_stable_steps": num_stable_steps,
        "num_training_steps": num_training_steps,
        "min_lr_ratio": min_lr_ratio,
        "decay_type": decay_type,
    }
    # functools.partial leaves only `current_step` for LambdaLR to supply.
    wsd_lambda = partial(_get_warmup_stable_decay_lr_lambda, **schedule_kwargs)
    scheduler = LambdaLR(optimizer, wsd_lambda, last_epoch=last_epoch)
    return scheduler
117
+
118
+
119
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.0,
):
    """
    Enhanced cosine schedule with configurable minimum learning rate.

    Args:
        current_step: Step index for which the multiplier is requested
        num_warmup_steps: Number of linear warmup steps
        num_training_steps: Total number of training steps
        num_cycles: Number of cosine cycles (0.5 = half cosine)
        min_lr_ratio: Lower bound applied to the cosine factor

    Returns:
        Multiplicative factor applied to the base learning rate.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(
        max(1, num_training_steps - num_warmup_steps)
    )
    # Clamp so that stepping past num_training_steps holds the final value
    # instead of riding the cosine back up to a higher learning rate.
    progress = min(progress, 1.0)
    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return max(min_lr_ratio, cosine_factor)
136
+
137
+
138
def get_enhanced_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR that warms up and then follows a floored cosine curve.

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for warmup phase
        num_training_steps: Total number of training steps
        num_cycles: Number of cosine cycles (0.5 = half cosine)
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        last_epoch: The index of the last epoch when resuming training
    """
    # Everything except the step index is fixed up front via partial.
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
        "min_lr_ratio": min_lr_ratio,
    }
    return LambdaLR(
        optimizer,
        partial(_get_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs),
        last_epoch=last_epoch,
    )
165
+
166
+
167
def _get_cosine_with_restarts_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
    min_lr_ratio: float = 0.0,
):
    """Learning-rate multiplier for cosine decay with hard restarts and a floor."""
    steps_past_warmup = current_step - num_warmup_steps
    if steps_past_warmup < 0:
        # Linear ramp from 0 towards 1 during warmup.
        return float(current_step) / float(max(1, num_warmup_steps))

    post_warmup_span = max(1, num_training_steps - num_warmup_steps)
    progress = float(steps_past_warmup) / float(post_warmup_span)
    if progress >= 1.0:
        return min_lr_ratio

    # Position inside the current cycle, in [0, 1); wrapping to 0 is the restart.
    cycle_position = (float(num_cycles) * progress) % 1.0
    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * cycle_position))
    return max(min_lr_ratio, cosine_factor)
189
+
190
+
191
def get_cosine_with_restarts_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 4,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR that warms up and then applies cosine decay with hard restarts.

    Args:
        optimizer: The optimizer for which to schedule the learning rate
        num_warmup_steps: Number of steps for warmup phase
        num_training_steps: Total number of training steps
        num_cycles: Number of restart cycles
        min_lr_ratio: Minimum learning rate as fraction of peak learning rate
        last_epoch: The index of the last epoch when resuming training
    """
    # Everything except the step index is fixed up front via partial.
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
        "min_lr_ratio": min_lr_ratio,
    }
    return LambdaLR(
        optimizer,
        partial(_get_cosine_with_restarts_lr_lambda, **schedule_kwargs),
        last_epoch=last_epoch,
    )
218
+
219
+
220
# Scheduler registry for easy lookup: maps each supported SchedulerType member
# to the factory function that builds the corresponding scheduler. get_scheduler
# consults this table, so a new scheduler becomes selectable by adding an entry.
SCHEDULER_REGISTRY = {
    SchedulerType.WARMUP_STABLE_DECAY: get_warmup_stable_decay_schedule,
    SchedulerType.COSINE_WITH_WARMUP: get_enhanced_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_restarts_schedule,
}
226
+
227
+
228
def get_scheduler(
    scheduler_type: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    scheduler_kwargs: Optional[dict] = None,
):
    """
    Build a learning rate scheduler by type through a single entry point.

    Args:
        scheduler_type: Type of scheduler to create (enum member or its string value)
        optimizer: The optimizer to schedule
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total training steps
        scheduler_kwargs: Additional scheduler-specific parameters, forwarded
            as keyword arguments to the selected factory

    Returns:
        Configured learning rate scheduler

    Raises:
        ValueError: If the resolved scheduler type has no registry entry.
    """
    # Strings are converted to the enum; enum members pass through unchanged.
    resolved_type = (
        SchedulerType(scheduler_type)
        if isinstance(scheduler_type, str)
        else scheduler_type
    )
    extra_kwargs = {} if scheduler_kwargs is None else scheduler_kwargs

    if resolved_type not in SCHEDULER_REGISTRY:
        raise ValueError(f"Unsupported scheduler type: {resolved_type}")

    return SCHEDULER_REGISTRY[resolved_type](
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **extra_kwargs,
    )
264
+
265
+
266
class WarmupStableDecayScheduler:
    """
    Alternative implementation as a standalone scheduler class.

    This provides more flexibility and better state management for
    complex training scenarios with checkpointing.

    The constructor only records the optimizer's current learning rates; it
    does not modify them. Learning rates are first written by ``step()``,
    which applies the factor for ``current_step`` and then advances the counter.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_warmup_steps: int,
        num_stable_steps: int,
        total_steps: int,
        min_lr_ratio: float = 0.01,
        decay_type: str = "cosine",
        verbose: bool = False,
    ):
        """
        Args:
            optimizer: Optimizer whose param-group learning rates are scheduled
            num_warmup_steps: Number of linear warmup steps
            num_stable_steps: Number of steps held at the peak learning rate
            total_steps: Total number of training steps
            min_lr_ratio: Lower bound of the decay factor
            decay_type: Type of decay ("cosine" or "linear")
            verbose: If True, print the LR factor and phase every 1000 steps

        Raises:
            ValueError: If decay_type is neither "cosine" nor "linear".
        """
        # Fail fast: get_lr_factor only inspects decay_type inside the decay
        # phase, so a typo would otherwise surface late in a training run.
        if decay_type not in ("cosine", "linear"):
            raise ValueError(f"Unknown decay_type: {decay_type}")

        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.num_stable_steps = num_stable_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.decay_type = decay_type
        self.verbose = verbose

        # Store initial learning rates
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.current_step = 0

    def get_lr_factor(self, step: int) -> float:
        """Calculate the learning rate multiplication factor for given step."""
        if step < self.num_warmup_steps:
            # Warmup phase
            return step / max(1, self.num_warmup_steps)
        elif step < self.num_warmup_steps + self.num_stable_steps:
            # Stable phase
            return 1.0
        else:
            # Decay phase
            decay_steps = (
                self.total_steps - self.num_warmup_steps - self.num_stable_steps
            )
            if decay_steps <= 0:
                # No room left for decay: stay at the peak factor.
                return max(self.min_lr_ratio, 1.0)

            progress = (
                step - self.num_warmup_steps - self.num_stable_steps
            ) / decay_steps
            # Clamp so steps beyond total_steps hold the final value.
            progress = min(progress, 1.0)

            if self.decay_type == "cosine":
                decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            elif self.decay_type == "linear":
                decay_factor = 1.0 - progress
            else:
                # Reachable only if decay_type was reassigned after construction.
                raise ValueError(f"Unknown decay_type: {self.decay_type}")

            return max(self.min_lr_ratio, decay_factor)

    def step(self):
        """Update learning rates for all parameter groups."""
        lr_factor = self.get_lr_factor(self.current_step)

        for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            param_group["lr"] = base_lr * lr_factor

        if self.verbose and self.current_step % 1000 == 0:
            phase = self.get_phase()
            print(
                f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}"
            )

        self.current_step += 1

    def get_phase(self) -> str:
        """Get current training phase."""
        if self.current_step < self.num_warmup_steps:
            return "warmup"
        elif self.current_step < self.num_warmup_steps + self.num_stable_steps:
            return "stable"
        else:
            return "decay"

    def state_dict(self) -> dict:
        """Return scheduler state for checkpointing."""
        return {
            "current_step": self.current_step,
            # Copy so that callers mutating the checkpoint cannot alter live state.
            "base_lrs": list(self.base_lrs),
        }

    def load_state_dict(self, state_dict: dict):
        """Load scheduler state from checkpoint."""
        self.current_step = state_dict["current_step"]
        # Copy so later edits to the checkpoint dict do not leak into the scheduler.
        self.base_lrs = list(state_dict["base_lrs"])
src/plotting/__init__.py ADDED
File without changes
src/plotting/gift_eval_utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from gluonts.model.forecast import QuantileForecast
7
+
8
+ from src.data.frequency import parse_frequency
9
+ from src.plotting.plot_multivariate_timeseries import (
10
+ plot_multivariate_timeseries,
11
+ )
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def _prepare_data_for_plotting(
    input_data: dict, label_data: dict, max_context_length: Optional[int]
):
    """
    Convert one (input, label) pair into time-first arrays plus a start timestamp.

    Args:
        input_data: Mapping with a "target" array (history) and a "start" entry.
        label_data: Mapping with a "target" array (ground-truth future).
        max_context_length: If not None, keep only this many trailing history
            steps.

    Returns:
        Tuple (history_values, future_values, start_timestamp). Both arrays are
        float32 and at least 2-D. start_timestamp marks the first *returned*
        history step when the start is period-like; see the note below.
    """
    history_values = np.asarray(input_data["target"], dtype=np.float32)
    future_values = np.asarray(label_data["target"], dtype=np.float32)
    start_period = input_data["start"]

    def ensure_time_first(arr: np.ndarray) -> np.ndarray:
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        elif arr.ndim == 2:
            # Heuristic: the shorter axis is taken to be the channel axis.
            if arr.shape[0] < arr.shape[1]:
                return arr.T
            return arr
        else:
            return arr.reshape(arr.shape[-1], -1).T

    history_values = ensure_time_first(history_values)
    future_values = ensure_time_first(future_values)

    num_dropped = 0
    if max_context_length is not None and history_values.shape[0] > max_context_length:
        num_dropped = history_values.shape[0] - max_context_length
        # Slice from the front offset rather than with a negative index so a
        # limit of 0 yields an empty history instead of the full array.
        history_values = history_values[num_dropped:]

    # Convert Period to Timestamp if needed
    if hasattr(start_period, "to_timestamp"):
        if num_dropped > 0:
            # Adding an int to a pandas Period advances it by that many
            # periods, keeping the start aligned with the truncated history.
            start_period = start_period + num_dropped
        start_timestamp = start_period.to_timestamp()
    else:
        # NOTE(review): without a frequency the start cannot be shifted here,
        # so for non-Period starts it still refers to the untruncated series.
        start_timestamp = pd.Timestamp(start_period)
    return history_values, future_values, start_timestamp
46
+
47
+
48
def _extract_quantile_predictions(
    forecast,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Pull (median, lower, upper) arrays out of a forecast object.

    QuantileForecast inputs are queried at the 0.5, 0.1 and 0.9 quantiles;
    any other forecast is reduced from its ``samples`` attribute and yields no
    bounds. Every failure is swallowed and reported as ``(None, None, None)``.
    """

    def ensure_2d_time_first(arr):
        if arr is None:
            return None
        values = np.asarray(arr)
        if values.ndim == 1:
            return values.reshape(-1, 1)
        if values.ndim == 2:
            return values
        return values.reshape(values.shape[0], -1)

    no_prediction = (None, None, None)

    if not isinstance(forecast, QuantileForecast):
        # Sample-based forecasts: collapse the leading axis with a median.
        try:
            samples = forecast.samples
            ndim = samples.ndim
            if ndim == 1:
                point = samples
            elif ndim == 2 and samples.shape[0] == 1:
                point = samples[0]
            elif ndim in (2, 3):
                point = np.median(samples, axis=0)
            else:
                point = samples[0] if len(samples) > 0 else samples
            return ensure_2d_time_first(point), None, None
        except Exception:
            return no_prediction

    try:
        median = forecast.quantile(0.5)
        try:
            lower = forecast.quantile(0.1)
            upper = forecast.quantile(0.9)
        except (KeyError, ValueError):
            # Bounds are optional; a missing one drops both.
            lower = upper = None
        return (
            ensure_2d_time_first(median),
            ensure_2d_time_first(lower),
            ensure_2d_time_first(upper),
        )
    except Exception:
        # Second attempt that only asks for the median.
        try:
            return ensure_2d_time_first(forecast.quantile(0.5)), None, None
        except Exception:
            return no_prediction
100
+
101
+
102
def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int,
    title: Optional[str] = None,
):
    """
    Build one forecast figure for a single evaluation window.

    Prepares history/future arrays, extracts the median (and optional bounds)
    from ``forecast``, reshapes the predictions to match the future array, and
    delegates drawing to ``plot_multivariate_timeseries`` with ``show=False``.

    Returns:
        The figure returned by ``plot_multivariate_timeseries``, or None when no
        median prediction could be extracted or any step raised (in both cases
        a warning is logged instead of propagating the error).
    """
    try:
        history_values, future_values, start_timestamp = _prepare_data_for_plotting(
            input_data, label_data, max_context_length
        )
        median_pred, lower_bound, upper_bound = _extract_quantile_predictions(forecast)
        if median_pred is None:
            logger.warning(f"Could not extract predictions for {dataset_full_name}")
            return None

        def ensure_compatible_shape(pred_arr, target_arr):
            # Best-effort coercion of pred_arr towards target_arr's 2-D shape.
            # None passes through so missing bounds stay missing.
            if pred_arr is None:
                return None
            pred_arr = np.asarray(pred_arr)
            target_arr = np.asarray(target_arr)
            if pred_arr.ndim == 1:
                pred_arr = pred_arr.reshape(-1, 1)
            if target_arr.ndim == 1:
                target_arr = target_arr.reshape(-1, 1)
            if pred_arr.shape != target_arr.shape:
                if pred_arr.shape[0] == target_arr.shape[0]:
                    # Same number of rows: reconcile the column count.
                    if pred_arr.shape[1] == 1 and target_arr.shape[1] > 1:
                        pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
                    elif pred_arr.shape[1] > 1 and target_arr.shape[1] == 1:
                        pred_arr = pred_arr[:, :1]
                elif pred_arr.shape[1] == target_arr.shape[1]:
                    # Same number of columns: trim predictions to the shorter length.
                    min_time = min(pred_arr.shape[0], target_arr.shape[0])
                    pred_arr = pred_arr[:min_time]
                else:
                    if pred_arr.T.shape == target_arr.shape:
                        pred_arr = pred_arr.T
                    else:
                        # Last resort: take the first values of the flattened
                        # predictions as a single column, repeated per target column.
                        if pred_arr.size >= target_arr.shape[0]:
                            pred_arr = pred_arr.flatten()[
                                : target_arr.shape[0]
                            ].reshape(-1, 1)
                            if target_arr.shape[1] > 1:
                                pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
            return pred_arr

        median_pred = ensure_compatible_shape(median_pred, future_values)
        lower_bound = ensure_compatible_shape(lower_bound, future_values)
        upper_bound = ensure_compatible_shape(upper_bound, future_values)

        title = title or f"GIFT-Eval: {dataset_full_name}"
        frequency = parse_frequency(dataset_freq)
        fig = plot_multivariate_timeseries(
            history_values=history_values,
            future_values=future_values,
            predicted_values=median_pred,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            start=start_timestamp,
            frequency=frequency,
            title=title,
            show=False,
        )
        return fig
    except Exception as e:
        # Plotting is auxiliary: log and continue rather than abort the caller.
        logger.warning(f"Failed to create plot for {dataset_full_name}: {e}")
        return None
171
+
172
+
173
def create_plots_for_dataset(
    forecasts: List,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int,
) -> List[Tuple[object, str]]:
    """
    Create up to ``max_plots`` forecast figures for one dataset.

    Windows are taken in order from ``forecasts`` paired with the entries of
    ``test_data.input`` and ``test_data.label``. A window that fails is logged
    and skipped.

    Returns:
        List of ``(figure, filename)`` pairs for the windows that produced a figure.
    """
    inputs = list(test_data.input)
    labels = list(test_data.label)
    num_plots = min(len(forecasts), max_plots)
    dataset_label = getattr(dataset_metadata, "full_name", str(dataset_metadata))
    logger.info(f"Creating {num_plots} plots for {dataset_label}")

    figures_with_names: List[Tuple[object, str]] = []
    for index in range(num_plots):
        try:
            window_forecast = forecasts[index]
            window_input = inputs[index]
            window_label = labels[index]
            if hasattr(dataset_metadata, "full_name"):
                title = f"GIFT-Eval: {dataset_metadata.full_name} - Window {index + 1}/{num_plots}"
            else:
                title = f"Window {index + 1}/{num_plots}"
            fig = _create_plot(
                input_data=window_input,
                label_data=window_label,
                forecast=window_forecast,
                dataset_full_name=getattr(dataset_metadata, "full_name", "dataset"),
                dataset_freq=getattr(dataset_metadata, "freq", "D"),
                max_context_length=max_context_length,
                title=title,
            )
            if fig is None:
                continue
            freq_tag = getattr(dataset_metadata, "freq", "D")
            figures_with_names.append((fig, f"{freq_tag}_window_{index + 1:03d}.png"))
        except Exception as e:
            logger.warning(f"Error creating plot for window {index + 1}: {e}")
    return figures_with_names
src/plotting/plot_timeseries.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchmetrics
9
+ from matplotlib.figure import Figure
10
+
11
+ from src.data.containers import BatchTimeSeriesContainer
12
+ from src.data.frequency import Frequency
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the SMAPE between ``y_pred`` and ``y_true`` as a Python float."""
    predictions = torch.from_numpy(y_pred).float()
    targets = torch.from_numpy(y_true).float()
    metric = torchmetrics.SymmetricMeanAbsolutePercentageError()
    # The metric is called as (predictions, targets).
    score = metric(predictions, targets)
    return score.item()
24
+
25
+
26
def _create_date_ranges(
    start: Optional[Union[np.datetime64, pd.Timestamp]],
    frequency: Optional[Union[Frequency, str]],
    history_length: int,
    prediction_length: int,
) -> Tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
    """
    Create date ranges for history and future periods.

    Args:
        start: Timestamp of the first history step, or None.
        frequency: Object exposing ``to_pandas_freq(for_date_range=True)``, or
            a value passed to pandas as a frequency alias (e.g. "D"), or None.
        history_length: Number of history steps.
        prediction_length: Number of future steps; values <= 0 give an empty
            future index.

    Returns:
        ``(history_dates, future_dates)``. When ``start`` or ``frequency`` is
        None the history is a daily range ending at the current time.
    """
    if start is not None and frequency is not None:
        start_timestamp = pd.Timestamp(start)
        # The signature admits plain strings, which have no to_pandas_freq
        # method; treat anything without the method as a pandas alias.
        if hasattr(frequency, "to_pandas_freq"):
            pandas_freq = frequency.to_pandas_freq(for_date_range=True)
        else:
            pandas_freq = frequency

        history_dates = pd.date_range(
            start=start_timestamp, periods=history_length, freq=pandas_freq
        )
        # With an empty history there is no last date to step from, so the
        # future begins at the start timestamp itself.
        if len(history_dates) > 0:
            first_future = history_dates[-1] + pd.tseries.frequencies.to_offset(
                pandas_freq
            )
        else:
            first_future = start_timestamp
    else:
        # Fallback to default daily frequency
        pandas_freq = "D"
        now = pd.Timestamp.now()
        history_dates = pd.date_range(end=now, periods=history_length, freq="D")
        if len(history_dates) > 0:
            first_future = history_dates[-1] + pd.Timedelta(days=1)
        else:
            first_future = now

    if prediction_length > 0:
        future_dates = pd.date_range(
            start=first_future, periods=prediction_length, freq=pandas_freq
        )
    else:
        future_dates = pd.DatetimeIndex([])

    return history_dates, future_dates
66
+
67
+
68
def _plot_single_channel(
    ax: plt.Axes,
    channel_idx: int,
    history_dates: pd.DatetimeIndex,
    future_dates: pd.DatetimeIndex,
    history_values: np.ndarray,
    future_values: Optional[np.ndarray] = None,
    predicted_values: Optional[np.ndarray] = None,
    lower_bound: Optional[np.ndarray] = None,
    upper_bound: Optional[np.ndarray] = None,
) -> None:
    """Draw history, optional truth/prediction lines and band for one channel on ``ax``."""

    def channel(values: np.ndarray) -> np.ndarray:
        # Column of the requested channel from a (time, channel) array.
        return values[:, channel_idx]

    # History is always drawn.
    ax.plot(history_dates, channel(history_values), color="black", label="History")

    # Ground-truth future, when supplied.
    if future_values is not None:
        ax.plot(
            future_dates, channel(future_values), color="blue", label="Ground Truth"
        )

    # Point prediction, when supplied.
    if predicted_values is not None:
        ax.plot(
            future_dates,
            channel(predicted_values),
            color="orange",
            linestyle="--",
            label="Prediction (Median)",
        )

    # The band needs both bounds.
    has_band = lower_bound is not None and upper_bound is not None
    if has_band:
        ax.fill_between(
            future_dates,
            channel(lower_bound),
            channel(upper_bound),
            color="orange",
            alpha=0.2,
            label="Uncertainty Band",
        )

    ax.set_title(f"Channel {channel_idx + 1}")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
117
+
118
+
119
def _setup_figure(num_channels: int) -> Tuple[Figure, List[plt.Axes]]:
    """Create a figure with one stacked subplot per channel sharing the x-axis."""
    figure_size = (15, 3 * num_channels)
    fig, axes = plt.subplots(num_channels, 1, figsize=figure_size, sharex=True)
    # For a single subplot pyplot returns a bare Axes; wrap it so callers can index.
    return fig, ([axes] if num_channels == 1 else axes)
127
+
128
+
129
def _finalize_plot(
    fig: Figure,
    axes: List[plt.Axes],
    title: Optional[str] = None,
    smape_value: Optional[float] = None,
    output_file: Optional[str] = None,
    show: bool = True,
) -> None:
    """
    Add legend, title, and save/show the plot.

    Args:
        fig: Figure to finalize.
        axes: Axes of the figure; the legend entries are taken from the first.
        title: Optional figure title.
        smape_value: If given together with a title, appended to the title.
        output_file: If given, the figure is saved there at 300 dpi.
        show: If True call ``plt.show()``; otherwise close ``fig``.
    """
    # Create legend from first axis
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")

    # Set title with optional SMAPE
    if title:
        if smape_value is not None:
            title = f"{title} | SMAPE: {smape_value:.4f}"
        fig.suptitle(title, fontsize=16)

    # Adjust layout. Operate on `fig` itself rather than pyplot's implicit
    # "current figure", which may be a different figure than the one passed in.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95] if title else None)

    # Save and/or show
    if output_file:
        fig.savefig(output_file, dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)
158
+
159
+
160
def plot_multivariate_timeseries(
    history_values: np.ndarray,
    future_values: Optional[np.ndarray] = None,
    predicted_values: Optional[np.ndarray] = None,
    start: Optional[Union[np.datetime64, pd.Timestamp]] = None,
    frequency: Optional[Union[Frequency, str]] = None,
    title: Optional[str] = None,
    output_file: Optional[str] = None,
    show: bool = True,
    lower_bound: Optional[np.ndarray] = None,
    upper_bound: Optional[np.ndarray] = None,
) -> Figure:
    """
    Plot every channel of a series: history plus optional truth, prediction and band.

    One subplot is drawn per column of ``history_values``. When both
    ``predicted_values`` and ``future_values`` are given, SMAPE is computed and
    passed on for the title; a failure there is only logged.

    Returns:
        The created figure.
    """
    # SMAPE needs both prediction and ground truth.
    smape_value = None
    if not (predicted_values is None or future_values is None):
        try:
            smape_value = calculate_smape(future_values, predicted_values)
        except Exception as e:
            logger.warning(f"Failed to calculate SMAPE: {str(e)}")

    # Dimensions: columns are channels, rows are time steps.
    num_channels = history_values.shape[1]
    history_length = history_values.shape[0]
    if predicted_values is not None:
        prediction_length = predicted_values.shape[0]
    elif future_values is not None:
        prediction_length = future_values.shape[0]
    else:
        prediction_length = 0

    history_dates, future_dates = _create_date_ranges(
        start, frequency, history_length, prediction_length
    )

    fig, axes = _setup_figure(num_channels)

    # _setup_figure returns exactly one axis per channel.
    for channel_idx, channel_ax in enumerate(axes):
        _plot_single_channel(
            ax=channel_ax,
            channel_idx=channel_idx,
            history_dates=history_dates,
            future_dates=future_dates,
            history_values=history_values,
            future_values=future_values,
            predicted_values=predicted_values,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    _finalize_plot(fig, axes, title, smape_value, output_file, show)

    return fig
216
+
217
+
218
def _extract_quantile_predictions(
    predicted_values: np.ndarray,
    model_quantiles: List[float],
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Extract median, lower, and upper bound predictions from quantile output.

    Args:
        predicted_values: Array whose last axis is indexed by quantile.
        model_quantiles: Quantile levels, in the order of that last axis.

    Returns:
        ``(median, lower, upper)``. The bounds are returned only when the 0.5,
        0.1 and 0.9 levels are all present; otherwise both are None. The median
        is the 0.5 slice when that level exists, else the middle slice of the
        last axis.
    """

    def quantile_slice(level: float) -> Optional[np.ndarray]:
        # None when the level is not listed or its index exceeds the last axis.
        try:
            return predicted_values[..., model_quantiles.index(level)]
        except (ValueError, IndexError):
            return None

    median_preds = quantile_slice(0.5)
    lower_bound = quantile_slice(0.1)
    upper_bound = quantile_slice(0.9)

    if median_preds is not None and lower_bound is not None and upper_bound is not None:
        return median_preds, lower_bound, upper_bound

    if median_preds is None:
        logger.warning(
            "Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles."
        )
        median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
    else:
        # The true median is available, so keep it rather than substituting
        # the middle column, which may belong to a different quantile level.
        logger.warning(
            "Could not find 0.1 and 0.9 quantiles for plotting. Using the 0.5 quantile without an uncertainty band."
        )
    return median_preds, None, None
239
+
240
+
241
def plot_from_container(
    batch: BatchTimeSeriesContainer,
    sample_idx: int,
    predicted_values: Optional[np.ndarray] = None,
    model_quantiles: Optional[List[float]] = None,
    title: Optional[str] = None,
    output_file: Optional[str] = None,
    show: bool = True,
) -> Figure:
    """
    Plot one sample of a batch, splitting quantile predictions when levels are given.

    ``predicted_values`` may hold predictions for the whole batch or for the
    single sample; it is indexed by ``sample_idx`` when it has three or more
    dimensions, or two dimensions with more rows than the sample's future.
    """
    history_values = batch.history_values[sample_idx].cpu().numpy()
    future_values = batch.future_values[sample_idx].cpu().numpy()

    # Defaults cover the "no predictions" case.
    median_preds = None
    lower_bound = None
    upper_bound = None

    if predicted_values is not None:
        looks_batched = predicted_values.ndim >= 3 or (
            predicted_values.ndim == 2
            and predicted_values.shape[0] > future_values.shape[0]
        )
        sample_preds = predicted_values[sample_idx] if looks_batched else predicted_values

        if model_quantiles:
            median_preds, lower_bound, upper_bound = _extract_quantile_predictions(
                sample_preds, model_quantiles
            )
        else:
            # Without quantile levels the predictions are used as-is.
            median_preds = sample_preds

    return plot_multivariate_timeseries(
        history_values=history_values,
        future_values=future_values,
        predicted_values=median_preds,
        start=batch.start[sample_idx],
        frequency=batch.frequency[sample_idx],
        title=title,
        output_file=output_file,
        show=show,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )
src/synthetic_generation/__init__.py ADDED
File without changes
src/synthetic_generation/abstract_classes.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from src.data.containers import TimeSeriesContainer
8
+ from src.data.frequency import (
9
+ select_safe_random_frequency,
10
+ select_safe_start_date,
11
+ )
12
+ from src.synthetic_generation.generator_params import GeneratorParams
13
+
14
+
15
class AbstractTimeSeriesGenerator(ABC):
    """
    Interface that every synthetic time series generator implements.
    """

    @abstractmethod
    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """
        Produce one synthetic time series.

        Parameters
        ----------
        random_seed : int, optional
            Seed used to make the generated series reproducible.

        Returns
        -------
        np.ndarray
            Values of shape (length,) for a univariate series or
            (length, num_channels) for a multivariate one.
        """
37
+
38
+
39
class GeneratorWrapper:
    """
    Unified base class for all generator wrappers, using a GeneratorParams dataclass
    for configuration. Provides parameter sampling, validation, and batch formatting utilities.
    """

    def __init__(self, params: GeneratorParams):
        """
        Initialize the GeneratorWrapper with a GeneratorParams dataclass.

        Parameters
        ----------
        params : GeneratorParams
            Dataclass instance containing all generator configuration parameters.
        """
        self.params = params
        self._set_random_seeds(self.params.global_seed)

    def _set_random_seeds(self, seed: int) -> None:
        """
        Seed the wrapper's own RNG and the global numpy/torch RNGs.

        Parameters
        ----------
        seed : int
            Base seed. ``self.rng`` is seeded with it plus a per-class offset;
            the global numpy and torch generators are seeded with it directly.
        """
        import zlib

        # For parameter sampling, we want diversity across batches even with similar seeds.
        # Offset the seed by a checksum of the class name so different generators
        # get different parameter sequences. The built-in hash() is not used
        # because str hashes are salted per interpreter process, which would
        # make the sampled parameters differ from run to run for the same seed.
        class_offset = zlib.crc32(self.__class__.__name__.encode("utf-8")) % 2**31
        param_seed = seed + class_offset
        self.rng = np.random.default_rng(param_seed)

        # Set global numpy and torch seeds for deterministic behavior in underlying generators
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
        """
        Sample a frequency and a start date for every series in a batch.

        Both are drawn with ``self.rng`` based on ``self.params.length``.

        Parameters
        ----------
        batch_size : int
            Number of series to sample parameters for.

        Returns
        -------
        Dict[str, Any]
            Dictionary with two lists of length ``batch_size``:
            ``"frequency"`` and ``"start"``.
        """
        # Select a suitable frequency based on the total length
        frequency = [
            select_safe_random_frequency(self.params.length, self.rng)
            for _ in range(batch_size)
        ]
        # Each start date is drawn for the frequency sampled at the same position.
        start = [
            select_safe_start_date(self.params.length, freq, self.rng)
            for freq in frequency
        ]

        return {
            "frequency": frequency,
            "start": start,
        }

    # NOTE: this class does not derive from ABC, so @abstractmethod is not
    # enforced at instantiation; the NotImplementedError below is what fires.
    @abstractmethod
    def generate_batch(
        self, batch_size: int, seed: Optional[int] = None, **kwargs
    ) -> TimeSeriesContainer:
        """Generate a batch of series; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement generate_batch()")
src/synthetic_generation/anomalies/anomaly_generator.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Set
2
+
3
+ import numpy as np
4
+
5
+ from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
6
+ from src.synthetic_generation.generator_params import (
7
+ AnomalyGeneratorParams,
8
+ AnomalyType,
9
+ MagnitudePattern,
10
+ )
11
+
12
+
13
class AnomalyGenerator(AbstractTimeSeriesGenerator):
    """
    Generator for synthetic time series with spike anomalies.

    Builds a constant baseline and adds spikes on a uniformly spaced
    schedule, optionally augmented with spike clusters or extra random
    spikes. All randomness is drawn from NumPy's global random state.
    """

    def __init__(self, params: AnomalyGeneratorParams):
        """
        Initialize the AnomalyGenerator.

        Parameters
        ----------
        params : AnomalyGeneratorParams
            Configuration parameters for anomaly generation.
        """
        self.params = params

    def _determine_spike_direction(self) -> AnomalyType:
        """
        Decide whether the whole series uses upward or downward spikes.

        Returns
        -------
        AnomalyType
            ``SPIKE_UP`` when a uniform draw falls below
            ``params.spike_direction_probability``, else ``SPIKE_DOWN``.
        """
        if np.random.random() < self.params.spike_direction_probability:
            return AnomalyType.SPIKE_UP
        return AnomalyType.SPIKE_DOWN

    def _generate_spike_positions(self) -> List[List[int]]:
        """
        Generate spike positions.

        - Always create uniformly spaced single spikes (base schedule).
        - With probability ``params.cluster_series_probability``: replace a
          fraction of the base spikes by clusters (base spike plus extra
          spikes at nearby offsets).
        - Otherwise, with probability ``params.random_series_probability``:
          add single spikes at random non-base positions.
        - Otherwise the base schedule is returned unchanged.

        Returns
        -------
        List[List[int]]
            List of spike events, where each event is a list of positions
            (single spike = [pos], cluster = sorted positions).
        """
        # Base uniform schedule (no jitter/variance)
        base_period = np.random.randint(*self.params.base_period_range)
        start_position = base_period // 2
        base_positions = list(
            range(start_position, self.params.length, base_period)
        )

        # Start with single-spike events at base positions
        spike_events: List[List[int]] = [[pos] for pos in base_positions]
        if not base_positions:
            return spike_events

        num_base_events = len(base_positions)

        # One draw decides the series type
        series_draw = np.random.random()
        cluster_threshold = self.params.cluster_series_probability
        random_threshold = (
            cluster_threshold + self.params.random_series_probability
        )

        if series_draw < cluster_threshold:
            # At least one and at most all base events are augmented, so the
            # count is always positive here (base_positions is non-empty).
            num_to_augment = max(
                1,
                int(round(self.params.cluster_event_fraction * num_base_events)),
            )
            num_to_augment = min(num_to_augment, num_base_events)
            chosen_indices = np.random.choice(
                num_base_events, size=num_to_augment, replace=False
            )

            for idx in chosen_indices:
                base_pos = base_positions[int(idx)]
                num_additional = np.random.randint(
                    *self.params.cluster_additional_spikes_range
                )
                if num_additional <= 0:
                    continue

                offsets = np.random.randint(
                    self.params.cluster_offset_range[0],
                    self.params.cluster_offset_range[1],
                    size=num_additional,
                )

                # A set drops zero offsets and repeated offsets, so each
                # position occurs once within the cluster.
                cluster_positions: Set[int] = {base_pos}
                for off in offsets:
                    pos = base_pos + int(off)
                    if 0 <= pos < self.params.length:
                        cluster_positions.add(pos)

                spike_events[int(idx)] = sorted(cluster_positions)

        elif series_draw < random_threshold:
            num_random = int(
                round(self.params.random_spike_fraction_of_base * num_base_events)
            )
            if num_random > 0:
                # Candidates are all positions not already on the base schedule
                candidates = np.setdiff1d(
                    np.arange(self.params.length),
                    np.array(base_positions, dtype=int),
                    assume_unique=False,
                )
                if candidates.size > 0:
                    choose_n = min(num_random, candidates.size)
                    rand_positions = np.random.choice(
                        candidates, size=choose_n, replace=False
                    )
                    spike_events.extend([int(pos)] for pos in rand_positions)

        # Else: clean series (uniform singles only)
        return spike_events

    def _generate_spike_magnitudes(self, total_spikes: int) -> np.ndarray:
        """
        Generate spike magnitudes based on the configured pattern.

        Parameters
        ----------
        total_spikes : int
            Total number of individual spikes to generate magnitudes for.

        Returns
        -------
        np.ndarray
            Array of ``total_spikes`` magnitudes, clipped to
            ``[0.1, 3.0] * base_magnitude``. Empty when ``total_spikes`` is 0.
        """
        base_magnitude = np.random.uniform(*self.params.base_magnitude_range)
        if total_spikes == 0:
            # Nothing to generate; also avoids indexing magnitudes[0] below.
            return np.zeros(0)

        pattern = self.params.magnitude_pattern
        magnitudes = np.zeros(total_spikes)

        if pattern == MagnitudePattern.CONSTANT:
            # Flat level; the shared noise step below supplies the jitter.
            # (Noise used to be added here as well, doubling its variance
            # for this pattern only.)
            magnitudes = np.full(total_spikes, base_magnitude)

        elif pattern == MagnitudePattern.INCREASING:
            # Magnitude increases with spike index
            trend = np.linspace(
                0,
                self.params.magnitude_trend_strength * base_magnitude * total_spikes,
                total_spikes,
            )
            magnitudes = base_magnitude + trend

        elif pattern == MagnitudePattern.DECREASING:
            # Magnitude decreases with spike index
            trend = np.linspace(
                0,
                -self.params.magnitude_trend_strength * base_magnitude * total_spikes,
                total_spikes,
            )
            magnitudes = base_magnitude + trend

        elif pattern == MagnitudePattern.CYCLICAL:
            cycle_length = int(total_spikes * self.params.cyclical_period_ratio)
            if cycle_length == 0:
                cycle_length = max(1, total_spikes // 4)

            phase = np.linspace(
                0, 2 * np.pi * total_spikes / cycle_length, total_spikes
            )
            magnitudes = base_magnitude + 0.3 * base_magnitude * np.sin(phase)

        elif pattern == MagnitudePattern.RANDOM_BOUNDED:
            # Each magnitude blends the previous one with the base level
            # and adds a random step.
            correlation = self.params.magnitude_correlation
            magnitudes[0] = base_magnitude
            for i in range(1, total_spikes):
                random_component = np.random.normal(0, 0.2 * base_magnitude)
                magnitudes[i] = (
                    correlation * magnitudes[i - 1]
                    + (1 - correlation) * base_magnitude
                    + random_component
                )

        # Add noise to all patterns
        magnitudes = magnitudes + np.random.normal(
            0, self.params.magnitude_noise * base_magnitude, total_spikes
        )

        # Keep magnitudes within bounds relative to the base magnitude
        return np.clip(magnitudes, 0.1 * base_magnitude, 3.0 * base_magnitude)

    def _inject_spike_anomalies(
        self, signal: np.ndarray, spike_direction: AnomalyType
    ) -> np.ndarray:
        """
        Inject spike anomalies into a copy of the clean signal.

        Parameters
        ----------
        signal : np.ndarray
            Clean baseline signal to inject spikes into (not modified).
        spike_direction : AnomalyType
            Direction of spikes for this series (all up or all down).

        Returns
        -------
        np.ndarray
            Signal with injected spike anomalies.
        """
        anomalous_signal = signal.copy()
        spike_events = self._generate_spike_positions()

        # Flatten the events into unique positions in time order. Magnitudes
        # are generated per spike index, so time ordering is what makes the
        # increasing/decreasing/cyclical patterns evolve along the series
        # (random extra spikes are appended after the base events), and
        # de-duplication keeps overlapping clusters from hitting one sample
        # twice.
        all_positions = sorted(
            {position for event in spike_events for position in event}
        )
        if not all_positions:
            return anomalous_signal

        magnitudes = self._generate_spike_magnitudes(len(all_positions))
        sign = 1.0 if spike_direction == AnomalyType.SPIKE_UP else -1.0

        for position, magnitude in zip(all_positions, magnitudes):
            if position < len(anomalous_signal):
                anomalous_signal[position] += sign * magnitude

        return anomalous_signal

    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """
        Generate a synthetic time series with spike anomalies.

        Parameters
        ----------
        random_seed : int, optional
            When given, NumPy's global random state is re-seeded with it
            before any value is drawn.

        Returns
        -------
        np.ndarray
            Generated time series of shape (length,): a constant baseline
            with injected spikes.
        """
        if random_seed is not None:
            np.random.seed(random_seed)

        # Clean baseline signal (constant level)
        baseline_level = np.random.uniform(*self.params.base_level_range)
        signal = np.full(self.params.length, baseline_level)

        # One direction for the whole series (all up or all down)
        spike_direction = self._determine_spike_direction()

        return self._inject_spike_anomalies(signal, spike_direction)
src/synthetic_generation/anomalies/anomaly_generator_wrapper.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+
5
+ from src.data.containers import TimeSeriesContainer
6
+ from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
+ from src.synthetic_generation.anomalies.anomaly_generator import AnomalyGenerator
8
+ from src.synthetic_generation.generator_params import AnomalyGeneratorParams
9
+
10
+
11
class AnomalyGeneratorWrapper(GeneratorWrapper):
    """
    Wrapper for AnomalyGenerator that handles batch generation and formatting.
    """

    def __init__(self, params: AnomalyGeneratorParams):
        """
        Initialize the AnomalyGeneratorWrapper.

        Parameters
        ----------
        params : AnomalyGeneratorParams
            Parameters for the anomaly generator.
        """
        super().__init__(params)
        self.generator = AnomalyGenerator(params)

    @staticmethod
    def _derive_series_seed(base_seed: int, index: int) -> int:
        """
        Derive a per-series seed from the batch seed and the series index.

        ``np.random.SeedSequence`` mixes both values into one 32-bit word.
        With the previous ``base_seed + index`` scheme, batches generated
        with adjacent seeds shared all but one of their series. The result
        always lies in the range accepted by ``np.random.seed``.
        """
        mixed = np.random.SeedSequence([base_seed, index]).generate_state(1)
        return int(mixed[0])

    def generate_batch(
        self, batch_size: int, seed: Optional[int] = None
    ) -> TimeSeriesContainer:
        """
        Generate a batch of anomaly time series.

        Parameters
        ----------
        batch_size : int
            Number of time series to generate.
        seed : int, optional
            Random seed for reproducibility. When given, the wrapper RNGs are
            re-seeded and each series gets its own derived seed.

        Returns
        -------
        TimeSeriesContainer
            TimeSeriesContainer containing the generated time series.
        """
        if seed is not None:
            self._set_random_seeds(seed)

        # Sample frequency / start date for every series in the batch
        sampled_params = self._sample_parameters(batch_size)

        values = []
        for i in range(batch_size):
            # Each series gets its own seed, decorrelated from other batches
            series_seed = (
                None if seed is None else self._derive_series_seed(seed, i)
            )
            values.append(self.generator.generate_time_series(series_seed))

        return TimeSeriesContainer(
            values=np.array(values),
            start=sampled_params["start"],
            frequency=sampled_params["frequency"],
        )
src/synthetic_generation/audio_generators/financial_volatility_generator.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ from pyo import LFO, BrownNoise, Follower, Metro, Mix, Sine, TrigExpseg
5
+
6
+ from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
+ from src.synthetic_generation.audio_generators.utils import (
8
+ normalize_waveform,
9
+ run_offline_pyo,
10
+ )
11
+
12
+
13
class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
    """
    Generate synthetic univariate time series that mimics financial market
    behavior with volatility clustering and occasional jumps.

    The signal is a pyo mix of a trend LFO, brown noise whose amplitude is
    driven by an envelope follower, and metro-triggered decaying sine jumps.
    """

    def __init__(
        self,
        length: int,
        server_duration: float,
        sample_rate: int,
        normalize_output: bool,
        # Trend LFO
        trend_lfo_freq_range: tuple[float, float],
        trend_lfo_mul_range: tuple[float, float],
        # Volatility clustering
        volatility_carrier_freq_range: tuple[float, float],
        follower_freq_range: tuple[float, float],
        volatility_range: tuple[float, float],
        # Jumps
        jump_metro_time_range: tuple[float, float],
        jump_env_start_range: tuple[float, float],
        jump_env_decay_time_range: tuple[float, float],
        jump_freq_range: tuple[float, float],
        jump_direction_up_probability: float,
        random_seed: Optional[int] = None,
    ):
        """Store the rendering settings and sampling ranges; seed the RNG."""
        # Rendering settings
        self.length = length
        self.server_duration = server_duration
        self.sample_rate = sample_rate
        self.normalize_output = normalize_output

        # Trend component
        self.trend_lfo_freq_range = trend_lfo_freq_range
        self.trend_lfo_mul_range = trend_lfo_mul_range

        # Volatility component
        self.volatility_carrier_freq_range = volatility_carrier_freq_range
        self.follower_freq_range = follower_freq_range
        self.volatility_range = volatility_range

        # Jump component
        self.jump_metro_time_range = jump_metro_time_range
        self.jump_env_start_range = jump_env_start_range
        self.jump_env_decay_time_range = jump_env_decay_time_range
        self.jump_freq_range = jump_freq_range
        self.jump_direction_up_probability = jump_direction_up_probability

        self.rng = np.random.default_rng(random_seed)

    def _build_synth(self):
        """
        Sample the synth parameters from ``self.rng`` and build the pyo graph.

        Returns
        -------
        Mix
            Single-voice mix of the trend, noise and jump components.
        """
        rng = self.rng

        # Trend
        lfo_freq = rng.uniform(*self.trend_lfo_freq_range)
        lfo_mul = rng.uniform(*self.trend_lfo_mul_range)
        trend = LFO(freq=lfo_freq, type=0, mul=lfo_mul)

        # Volatility clustering
        sine_freq = rng.uniform(*self.volatility_carrier_freq_range)
        envelope_freq = rng.uniform(*self.follower_freq_range)
        vol_low, vol_high = self.volatility_range
        carrier = Sine(freq=sine_freq)
        envelope = Follower(carrier, freq=envelope_freq)
        volatility = envelope.range(vol_low, vol_high)
        market_noise = BrownNoise(mul=volatility)

        # Jumps
        metro_time = rng.uniform(*self.jump_metro_time_range)
        env_start = rng.uniform(*self.jump_env_start_range)
        env_decay = rng.uniform(*self.jump_env_decay_time_range)
        sine_jump_freq = rng.uniform(*self.jump_freq_range)
        jumps_upward = rng.random() < self.jump_direction_up_probability
        direction = 1.0 if jumps_upward else -1.0

        trigger = Metro(time=metro_time).play()
        envelope_points = [(0.0, env_start), (env_decay, 0.0)]
        jump_env = TrigExpseg(trigger, list=envelope_points)
        jumps = Sine(freq=sine_jump_freq, mul=jump_env * direction)

        components = [trend, market_noise, jumps]
        return Mix(components, voices=1)

    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """
        Render one series offline, optionally re-seeding the RNG first.

        Parameters
        ----------
        random_seed : int, optional
            When given, ``self.rng`` is replaced by a generator seeded with it.

        Returns
        -------
        np.ndarray
            Output of ``run_offline_pyo``, passed through
            ``normalize_waveform`` when ``normalize_output`` is set.
        """
        if random_seed is not None:
            self.rng = np.random.default_rng(random_seed)

        rendered = run_offline_pyo(
            synth_builder=self._build_synth,
            server_duration=self.server_duration,
            sample_rate=self.sample_rate,
            length=self.length,
        )
        if not self.normalize_output:
            return rendered
        return normalize_waveform(rendered)
src/synthetic_generation/audio_generators/financial_volatility_wrapper.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import numpy as np
4
+
5
+ from src.data.containers import TimeSeriesContainer
6
+ from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
+ from src.synthetic_generation.audio_generators.financial_volatility_generator import (
8
+ FinancialVolatilityAudioGenerator,
9
+ )
10
+ from src.synthetic_generation.generator_params import FinancialVolatilityAudioParams
11
+
12
+
13
class FinancialVolatilityAudioWrapper(GeneratorWrapper):
    """
    Wrapper that builds a FinancialVolatilityAudioGenerator from the
    configured parameters and renders batches of series with it.
    """

    def __init__(self, params: FinancialVolatilityAudioParams):
        """Store the parameters and seed the RNGs via the base class."""
        super().__init__(params)
        self.params: FinancialVolatilityAudioParams = params

    def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
        """
        Return the base-class sample extended with the generator settings.

        The generator settings are copied unchanged from ``self.params``.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {
                "length": self.params.length,
                "server_duration": self.params.server_duration,
                "sample_rate": self.params.sample_rate,
                "normalize_output": self.params.normalize_output,
                # Trend LFO
                "trend_lfo_freq_range": self.params.trend_lfo_freq_range,
                "trend_lfo_mul_range": self.params.trend_lfo_mul_range,
                # Volatility clustering
                "volatility_carrier_freq_range": self.params.volatility_carrier_freq_range,
                "follower_freq_range": self.params.follower_freq_range,
                "volatility_range": self.params.volatility_range,
                # Jumps
                "jump_metro_time_range": self.params.jump_metro_time_range,
                "jump_env_start_range": self.params.jump_env_start_range,
                "jump_env_decay_time_range": self.params.jump_env_decay_time_range,
                "jump_freq_range": self.params.jump_freq_range,
                "jump_direction_up_probability": self.params.jump_direction_up_probability,
            }
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: Optional[int] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> TimeSeriesContainer:
        """
        Generate a batch of series.

        Parameters
        ----------
        batch_size : int
            Number of series to generate.
        seed : int, optional
            When given, the wrapper RNGs are re-seeded and every series is
            rendered with its own seed derived from ``seed``.
        params : dict, optional
            Pre-sampled parameters in the format of ``_sample_parameters``;
            sampled here when omitted.

        Returns
        -------
        TimeSeriesContainer
            Container with the stacked series and the ``start`` /
            ``frequency`` entries from ``params``.
        """
        # Local import keeps this block self-contained.
        import zlib

        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = FinancialVolatilityAudioGenerator(
            length=params["length"],
            server_duration=params["server_duration"],
            sample_rate=params["sample_rate"],
            normalize_output=params["normalize_output"],
            trend_lfo_freq_range=params["trend_lfo_freq_range"],
            trend_lfo_mul_range=params["trend_lfo_mul_range"],
            volatility_carrier_freq_range=params["volatility_carrier_freq_range"],
            follower_freq_range=params["follower_freq_range"],
            volatility_range=params["volatility_range"],
            jump_metro_time_range=params["jump_metro_time_range"],
            jump_env_start_range=params["jump_env_start_range"],
            jump_env_decay_time_range=params["jump_env_decay_time_range"],
            jump_freq_range=params["jump_freq_range"],
            jump_direction_up_probability=params["jump_direction_up_probability"],
            random_seed=seed,
        )

        # Class-specific salt. The built-in hash() of a str is salted per
        # process (PYTHONHASHSEED), which made seeded batches differ between
        # runs; CRC32 of the class name is stable.
        class_salt = (
            zlib.crc32(self.__class__.__name__.encode("utf-8")) & 0x7FFFFFFF
        )

        def _derive_series_seed(base_seed: int, index: int) -> int:
            # Mix base seed with index and class salt to decorrelate
            # adjacent seeds
            mixed = (
                (base_seed & 0x7FFFFFFF)
                ^ ((index * 0x9E3779B1) & 0x7FFFFFFF)
                ^ class_salt
            )
            return int(mixed)

        batch_values = []
        for i in range(batch_size):
            series_seed = None if seed is None else _derive_series_seed(seed, i)
            batch_values.append(
                generator.generate_time_series(random_seed=series_seed)
            )

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )
src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ from pyo import Biquad, BrownNoise, Mix
5
+
6
+ from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
+ from src.synthetic_generation.audio_generators.utils import (
8
+ normalize_waveform,
9
+ run_offline_pyo,
10
+ )
11
+
12
+
13
class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
    """
    Generate multi-scale fractal-like patterns by filtering noise at
    multiple frequency bands with varying Q and attenuation per scale.
    """

    def __init__(
        self,
        length: int,
        server_duration: float,
        sample_rate: int,
        normalize_output: bool,
        base_noise_mul_range: tuple[float, float],
        num_scales_range: tuple[int, int],
        scale_freq_base_range: tuple[float, float],
        q_factor_range: tuple[float, float],
        per_scale_attenuation_range: tuple[float, float],
        random_seed: Optional[int] = None,
    ):
        """Store the rendering settings and sampling ranges; seed the RNG."""
        # Rendering settings
        self.length = length
        self.server_duration = server_duration
        self.sample_rate = sample_rate
        self.normalize_output = normalize_output

        # Sampling ranges for the synth parameters
        self.base_noise_mul_range = base_noise_mul_range
        self.num_scales_range = num_scales_range
        self.scale_freq_base_range = scale_freq_base_range
        self.q_factor_range = q_factor_range
        self.per_scale_attenuation_range = per_scale_attenuation_range

        self.rng = np.random.default_rng(random_seed)

    def _build_synth(self):
        """
        Sample the synth parameters from ``self.rng`` and build the pyo graph.

        One brown-noise source feeds a sampled number of Biquad filters.
        For scale ``i`` the sampled filter frequency is multiplied by
        ``0.5**i`` and the filter output by ``attenuation**i``.

        Returns
        -------
        Mix
            Single-voice mix of all filtered scales.
        """
        rng = self.rng

        noise_mul = rng.uniform(*self.base_noise_mul_range)
        source = BrownNoise(mul=noise_mul)

        # The upper bound of num_scales_range is inclusive
        lowest, highest = self.num_scales_range[0], self.num_scales_range[1]
        scale_count = int(rng.integers(lowest, highest + 1))

        bands = []
        for level in range(scale_count):
            band_freq = rng.uniform(*self.scale_freq_base_range) * (0.5**level)
            band_q = rng.uniform(*self.q_factor_range)
            attenuation = rng.uniform(*self.per_scale_attenuation_range)
            band = Biquad(source, freq=band_freq, q=band_q, type=0)
            bands.append(band * (attenuation**level))

        return Mix(bands, voices=1)

    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """
        Render one series offline, optionally re-seeding the RNG first.

        Parameters
        ----------
        random_seed : int, optional
            When given, ``self.rng`` is replaced by a generator seeded with it.

        Returns
        -------
        np.ndarray
            Output of ``run_offline_pyo``, passed through
            ``normalize_waveform`` when ``normalize_output`` is set.
        """
        if random_seed is not None:
            self.rng = np.random.default_rng(random_seed)

        rendered = run_offline_pyo(
            synth_builder=self._build_synth,
            server_duration=self.server_duration,
            sample_rate=self.sample_rate,
            length=self.length,
        )
        if not self.normalize_output:
            return rendered
        return normalize_waveform(rendered)
src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import numpy as np
4
+
5
+ from src.data.containers import TimeSeriesContainer
6
+ from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
+ from src.synthetic_generation.audio_generators.multi_scale_fractal_generator import (
8
+ MultiScaleFractalAudioGenerator,
9
+ )
10
+ from src.synthetic_generation.generator_params import MultiScaleFractalAudioParams
11
+
12
+
13
class MultiScaleFractalAudioWrapper(GeneratorWrapper):
    """
    Wrapper that builds a MultiScaleFractalAudioGenerator from the
    configured parameters and renders batches of series with it.
    """

    def __init__(self, params: MultiScaleFractalAudioParams):
        """Store the parameters and seed the RNGs via the base class."""
        super().__init__(params)
        self.params: MultiScaleFractalAudioParams = params

    def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
        """
        Return the base-class sample extended with the generator settings.

        The generator settings are copied unchanged from ``self.params``.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {
                "length": self.params.length,
                "server_duration": self.params.server_duration,
                "sample_rate": self.params.sample_rate,
                "normalize_output": self.params.normalize_output,
                "base_noise_mul_range": self.params.base_noise_mul_range,
                "num_scales_range": self.params.num_scales_range,
                "scale_freq_base_range": self.params.scale_freq_base_range,
                "q_factor_range": self.params.q_factor_range,
                "per_scale_attenuation_range": self.params.per_scale_attenuation_range,
            }
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: Optional[int] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> TimeSeriesContainer:
        """
        Generate a batch of series.

        Parameters
        ----------
        batch_size : int
            Number of series to generate.
        seed : int, optional
            When given, the wrapper RNGs are re-seeded and every series is
            rendered with its own seed derived from ``seed``.
        params : dict, optional
            Pre-sampled parameters in the format of ``_sample_parameters``;
            sampled here when omitted.

        Returns
        -------
        TimeSeriesContainer
            Container with the stacked series and the ``start`` /
            ``frequency`` entries from ``params``.
        """
        # Local import keeps this block self-contained.
        import zlib

        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = MultiScaleFractalAudioGenerator(
            length=params["length"],
            server_duration=params["server_duration"],
            sample_rate=params["sample_rate"],
            normalize_output=params["normalize_output"],
            base_noise_mul_range=params["base_noise_mul_range"],
            num_scales_range=params["num_scales_range"],
            scale_freq_base_range=params["scale_freq_base_range"],
            q_factor_range=params["q_factor_range"],
            per_scale_attenuation_range=params["per_scale_attenuation_range"],
            random_seed=seed,
        )

        # Class-specific salt. The built-in hash() of a str is salted per
        # process (PYTHONHASHSEED), which made seeded batches differ between
        # runs; CRC32 of the class name is stable.
        class_salt = (
            zlib.crc32(self.__class__.__name__.encode("utf-8")) & 0x7FFFFFFF
        )

        def _derive_series_seed(base_seed: int, index: int) -> int:
            # Mix base seed with index and class salt to decorrelate
            # adjacent seeds
            mixed = (
                (base_seed & 0x7FFFFFFF)
                ^ ((index * 0x9E3779B1) & 0x7FFFFFFF)
                ^ class_salt
            )
            return int(mixed)

        batch_values = []
        for i in range(batch_size):
            series_seed = None if seed is None else _derive_series_seed(seed, i)
            batch_values.append(
                generator.generate_time_series(random_seed=series_seed)
            )

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )
src/synthetic_generation/audio_generators/network_topology_generator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ from pyo import LFO, BrownNoise, Metro, Mix, Noise, TrigExpseg
5
+
6
+ from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
+ from src.synthetic_generation.audio_generators.utils import (
8
+ normalize_waveform,
9
+ run_offline_pyo,
10
+ )
11
+
12
+
13
class NetworkTopologyAudioGenerator(AbstractTimeSeriesGenerator):
    """
    Simulate network traffic with base flow, packet bursts, periodic congestion,
    protocol overhead, and DDoS-like attacks. Parameters are sampled per series.
    """

    def __init__(
        self,
        length: int,
        server_duration: float,
        sample_rate: int,
        normalize_output: bool,
        traffic_lfo_freq_range: tuple[float, float],
        traffic_lfo_mul_range: tuple[float, float],
        burst_rate_hz_range: tuple[float, float],
        burst_duration_range: tuple[float, float],
        burst_mul_range: tuple[float, float],
        congestion_period_range: tuple[float, float],
        congestion_depth_range: tuple[float, float],
        congestion_release_time_range: tuple[float, float],
        overhead_lfo_freq_range: tuple[float, float],
        overhead_mul_range: tuple[float, float],
        attack_period_range: tuple[float, float],
        attack_env_points: Tuple[
            Tuple[float, float], Tuple[float, float], Tuple[float, float]
        ],
        attack_mul_range: tuple[float, float],
        random_seed: Optional[int] = None,
    ):
        """Store the rendering settings and sampling ranges; seed the RNG."""
        # Rendering settings
        self.length = length
        self.server_duration = server_duration
        self.sample_rate = sample_rate
        self.normalize_output = normalize_output

        # Base traffic
        self.traffic_lfo_freq_range = traffic_lfo_freq_range
        self.traffic_lfo_mul_range = traffic_lfo_mul_range

        # Packet bursts
        self.burst_rate_hz_range = burst_rate_hz_range
        self.burst_duration_range = burst_duration_range
        self.burst_mul_range = burst_mul_range

        # Congestion
        self.congestion_period_range = congestion_period_range
        self.congestion_depth_range = congestion_depth_range
        self.congestion_release_time_range = congestion_release_time_range

        # Protocol overhead
        self.overhead_lfo_freq_range = overhead_lfo_freq_range
        self.overhead_mul_range = overhead_mul_range

        # Attacks
        self.attack_period_range = attack_period_range
        self.attack_env_points = attack_env_points
        self.attack_mul_range = attack_mul_range

        self.rng = np.random.default_rng(random_seed)

    def _build_synth(self):
        """
        Sample the synth parameters from ``self.rng`` and build the pyo graph.

        Returns
        -------
        Mix
            Single-voice mix of base traffic, bursts, the congestion
            envelope, protocol overhead and attacks.
        """
        rng = self.rng

        # Base traffic flow
        base_freq = rng.uniform(*self.traffic_lfo_freq_range)
        base_mul = rng.uniform(*self.traffic_lfo_mul_range)
        traffic_base = LFO(freq=base_freq, type=0, mul=base_mul)

        # Packet bursts: the metro period is the inverse of the burst rate
        bursts_per_second = rng.uniform(*self.burst_rate_hz_range)
        burst_clock = Metro(time=1.0 / bursts_per_second).play()
        burst_length = rng.uniform(*self.burst_duration_range)
        burst_shape = TrigExpseg(
            burst_clock, list=[(0.0, 0.8), (burst_length, 0.0)]
        )
        burst_gain = rng.uniform(*self.burst_mul_range)
        bursts = Noise(mul=burst_shape * burst_gain)

        # Periodic congestion (negative amplitude dip)
        dip_period = rng.uniform(*self.congestion_period_range)
        dip_clock = Metro(time=dip_period).play()
        dip_depth = rng.uniform(*self.congestion_depth_range)  # negative
        dip_release = rng.uniform(*self.congestion_release_time_range)
        dip_points = [(0.0, dip_depth), (dip_release, 0.0)]
        congestion_env = TrigExpseg(dip_clock, list=dip_points)

        # Protocol overhead
        protocol_freq = rng.uniform(*self.overhead_lfo_freq_range)
        protocol_mul = rng.uniform(*self.overhead_mul_range)
        overhead = LFO(freq=protocol_freq, type=1, mul=protocol_mul)

        # DDoS-like attacks
        ddos_period = rng.uniform(*self.attack_period_range)
        ddos_clock = Metro(time=ddos_period).play()
        ddos_shape = TrigExpseg(ddos_clock, list=list(self.attack_env_points))
        ddos_gain = rng.uniform(*self.attack_mul_range)
        attacks = BrownNoise(mul=ddos_shape * ddos_gain)

        layers = [traffic_base, bursts, congestion_env, overhead, attacks]
        return Mix(layers, voices=1)

    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """
        Render one series offline, optionally re-seeding the RNG first.

        Parameters
        ----------
        random_seed : int, optional
            When given, ``self.rng`` is replaced by a generator seeded with it.

        Returns
        -------
        np.ndarray
            Output of ``run_offline_pyo``, passed through
            ``normalize_waveform`` when ``normalize_output`` is set.
        """
        if random_seed is not None:
            self.rng = np.random.default_rng(random_seed)

        rendered = run_offline_pyo(
            synth_builder=self._build_synth,
            server_duration=self.server_duration,
            sample_rate=self.sample_rate,
            length=self.length,
        )
        if not self.normalize_output:
            return rendered
        return normalize_waveform(rendered)
src/synthetic_generation/audio_generators/network_topology_wrapper.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import numpy as np
4
+
5
+ from src.data.containers import TimeSeriesContainer
6
+ from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
+ from src.synthetic_generation.audio_generators.network_topology_generator import (
8
+ NetworkTopologyAudioGenerator,
9
+ )
10
+ from src.synthetic_generation.generator_params import NetworkTopologyAudioParams
11
+
12
+
13
class NetworkTopologyAudioWrapper(GeneratorWrapper):
    """Batch wrapper around ``NetworkTopologyAudioGenerator``.

    Copies the generator settings from a ``NetworkTopologyAudioParams``
    object into the sampled-parameter dict and produces a
    ``TimeSeriesContainer`` holding ``batch_size`` generated series.
    """

    # Names that are, at the same time, attributes of the params object,
    # keys of the sampled-parameter dict and keyword arguments of the
    # generator constructor. Keeping them in one place prevents the three
    # from drifting apart.
    _GENERATOR_FIELDS = (
        "length",
        "server_duration",
        "sample_rate",
        "normalize_output",
        "traffic_lfo_freq_range",
        "traffic_lfo_mul_range",
        "burst_rate_hz_range",
        "burst_duration_range",
        "burst_mul_range",
        "congestion_period_range",
        "congestion_depth_range",
        "congestion_release_time_range",
        "overhead_lfo_freq_range",
        "overhead_mul_range",
        "attack_period_range",
        "attack_env_points",
        "attack_mul_range",
    )

    def __init__(self, params: NetworkTopologyAudioParams):
        super().__init__(params)
        self.params: NetworkTopologyAudioParams = params

    def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
        """Return the base-class parameters extended with generator settings.

        Args:
            batch_size: Forwarded unchanged to the base implementation.

        Returns:
            The dict produced by ``super()._sample_parameters`` updated with
            one entry per name in ``_GENERATOR_FIELDS``, read from
            ``self.params``.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {name: getattr(self.params, name) for name in self._GENERATOR_FIELDS}
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: Optional[int] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> TimeSeriesContainer:
        """Generate ``batch_size`` series and wrap them in a container.

        Args:
            batch_size: Number of series to generate.
            seed: Optional base seed. When given it is passed to
                ``self._set_random_seeds`` and to the generator, and a
                distinct seed is derived from it for every series.
            params: Optional pre-sampled parameter dict; sampled with
                ``_sample_parameters`` when ``None``. It must contain every
                name in ``_GENERATOR_FIELDS`` plus ``"start"`` and
                ``"frequency"`` (presumably provided by the base class's
                ``_sample_parameters`` - confirm for caller-built dicts).

        Returns:
            A ``TimeSeriesContainer`` whose ``values`` is the stacked array
            of generated series.

        Raises:
            KeyError: If ``params`` lacks one of the required keys.
        """
        # Local import keeps this block self-contained; zlib is stdlib.
        import zlib

        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = NetworkTopologyAudioGenerator(
            **{name: params[name] for name in self._GENERATOR_FIELDS},
            random_seed=seed,
        )

        # Built-in ``hash()`` of a str is salted per interpreter process, so
        # using it here made seeded batches differ from run to run. CRC32 of
        # the class name is stable across processes and platforms.
        class_salt = (
            zlib.crc32(self.__class__.__name__.encode("utf-8")) & 0x7FFFFFFF
        )

        def _derive_series_seed(base_seed: int, index: int) -> int:
            # XOR-mix the base seed, a multiplicative spread of the series
            # index and the class salt; every term is masked to 31 bits so
            # the result is a non-negative int.
            return int(
                (base_seed & 0x7FFFFFFF)
                ^ ((index * 0x9E3779B1) & 0x7FFFFFFF)
                ^ class_salt
            )

        batch_values = [
            generator.generate_time_series(
                random_seed=None if seed is None else _derive_series_seed(seed, i)
            )
            for i in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )
src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ from pyo import Metro, Mix, Sine, TrigExpseg
5
+
6
+ from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
+ from src.synthetic_generation.audio_generators.utils import (
8
+ normalize_waveform,
9
+ run_offline_pyo,
10
+ )
11
+
12
+
13
class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
    """
    Rhythmic pattern generator built from several trigger-driven layers.

    Each layer has its own metronome trigger, envelope and sine tone; all
    layer parameters are drawn from ``self.rng`` every time the synth is
    built, so each generated series differs.
    """

    def __init__(
        self,
        length: int,
        server_duration: float,
        sample_rate: int,
        normalize_output: bool,
        base_tempo_hz_range: tuple[float, float],
        num_layers_range: tuple[int, int],
        subdivisions: tuple[int, ...],
        attack_range: tuple[float, float],
        decay_range: tuple[float, float],
        tone_freq_range: tuple[float, float],
        tone_mul_range: tuple[float, float],
        random_seed: Optional[int] = None,
    ):
        # Rendering settings forwarded to ``run_offline_pyo``.
        self.length = length
        self.server_duration = server_duration
        self.sample_rate = sample_rate
        self.normalize_output = normalize_output

        # Sampling ranges / choices used while building the synth graph.
        self.base_tempo_hz_range = base_tempo_hz_range
        self.num_layers_range = num_layers_range
        self.subdivisions = subdivisions
        self.attack_range = attack_range
        self.decay_range = decay_range
        self.tone_freq_range = tone_freq_range
        self.tone_mul_range = tone_mul_range

        self.rng = np.random.default_rng(random_seed)

    def _build_layer(self, base_tempo: float):
        """Create one enveloped sine layer triggered at a tempo subdivision."""
        # Draw order matters for reproducibility: subdivision index, attack,
        # decay, tone frequency, tone amplitude.
        choice = int(self.rng.integers(0, len(self.subdivisions)))
        layer_rate = base_tempo * self.subdivisions[choice]
        pulse = Metro(time=1.0 / layer_rate).play()

        attack_time = self.rng.uniform(*self.attack_range)
        decay_time = self.rng.uniform(*self.decay_range)
        envelope = TrigExpseg(
            pulse,
            list=[(0.0, 1.0), (attack_time, 0.8), (decay_time, 0.0)],
        )

        frequency = self.rng.uniform(*self.tone_freq_range)
        amplitude = self.rng.uniform(*self.tone_mul_range)
        return Sine(freq=frequency, mul=envelope * amplitude)

    def _build_synth(self):
        """Sample a tempo and layer count, then mix the layers to one voice."""
        base_tempo = self.rng.uniform(*self.base_tempo_hz_range)
        min_layers, max_layers = self.num_layers_range[0], self.num_layers_range[1]
        # Upper bound of ``integers`` is exclusive, hence the +1.
        layer_count = int(self.rng.integers(min_layers, max_layers + 1))

        voices = [self._build_layer(base_tempo) for _ in range(layer_count)]
        return Mix(voices, voices=1)

    def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
        """Render one series, optionally re-seeding ``self.rng`` beforehand.

        Args:
            random_seed: When not ``None``, ``self.rng`` is replaced by
                ``np.random.default_rng(random_seed)`` before rendering.

        Returns:
            The waveform from ``run_offline_pyo``, passed through
            ``normalize_waveform`` if ``self.normalize_output`` is truthy.
        """
        if random_seed is not None:
            self.rng = np.random.default_rng(random_seed)

        rendered = run_offline_pyo(
            synth_builder=self._build_synth,
            server_duration=self.server_duration,
            sample_rate=self.sample_rate,
            length=self.length,
        )
        return normalize_waveform(rendered) if self.normalize_output else rendered
src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import numpy as np
4
+
5
+ from src.data.containers import TimeSeriesContainer
6
+ from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
+ from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
8
+ StochasticRhythmAudioGenerator,
9
+ )
10
+ from src.synthetic_generation.generator_params import StochasticRhythmAudioParams
11
+
12
+
13
class StochasticRhythmAudioWrapper(GeneratorWrapper):
    """Batch wrapper around ``StochasticRhythmAudioGenerator``.

    Copies the generator settings from a ``StochasticRhythmAudioParams``
    object into the sampled-parameter dict and produces a
    ``TimeSeriesContainer`` holding ``batch_size`` generated series.
    """

    # Names that are, at the same time, attributes of the params object,
    # keys of the sampled-parameter dict and keyword arguments of the
    # generator constructor.
    _GENERATOR_FIELDS = (
        "length",
        "server_duration",
        "sample_rate",
        "normalize_output",
        "base_tempo_hz_range",
        "num_layers_range",
        "subdivisions",
        "attack_range",
        "decay_range",
        "tone_freq_range",
        "tone_mul_range",
    )

    def __init__(self, params: StochasticRhythmAudioParams):
        super().__init__(params)
        self.params: StochasticRhythmAudioParams = params

    def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
        """Return the base-class parameters extended with generator settings.

        Args:
            batch_size: Forwarded unchanged to the base implementation.

        Returns:
            The dict produced by ``super()._sample_parameters`` updated with
            one entry per name in ``_GENERATOR_FIELDS``, read from
            ``self.params``.
        """
        params = super()._sample_parameters(batch_size)
        params.update(
            {name: getattr(self.params, name) for name in self._GENERATOR_FIELDS}
        )
        return params

    def generate_batch(
        self,
        batch_size: int,
        seed: Optional[int] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> TimeSeriesContainer:
        """Generate ``batch_size`` series and wrap them in a container.

        Args:
            batch_size: Number of series to generate.
            seed: Optional base seed. When given it is passed to
                ``self._set_random_seeds`` and to the generator, and a
                distinct seed is derived from it for every series.
            params: Optional pre-sampled parameter dict; sampled with
                ``_sample_parameters`` when ``None``. It must contain every
                name in ``_GENERATOR_FIELDS`` plus ``"start"`` and
                ``"frequency"`` (presumably provided by the base class's
                ``_sample_parameters`` - confirm for caller-built dicts).

        Returns:
            A ``TimeSeriesContainer`` whose ``values`` is the stacked array
            of generated series.

        Raises:
            KeyError: If ``params`` lacks one of the required keys.
        """
        # Local import keeps this block self-contained; zlib is stdlib.
        import zlib

        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = StochasticRhythmAudioGenerator(
            **{name: params[name] for name in self._GENERATOR_FIELDS},
            random_seed=seed,
        )

        # Built-in ``hash()`` of a str is salted per interpreter process, so
        # using it here made seeded batches differ from run to run. CRC32 of
        # the class name is stable across processes and platforms.
        class_salt = (
            zlib.crc32(self.__class__.__name__.encode("utf-8")) & 0x7FFFFFFF
        )

        def _derive_series_seed(base_seed: int, index: int) -> int:
            # XOR-mix the base seed, a multiplicative spread of the series
            # index and the class salt; every term is masked to 31 bits so
            # the result is a non-negative int.
            return int(
                (base_seed & 0x7FFFFFFF)
                ^ ((index * 0x9E3779B1) & 0x7FFFFFFF)
                ^ class_salt
            )

        batch_values = [
            generator.generate_time_series(
                random_seed=None if seed is None else _derive_series_seed(seed, i)
            )
            for i in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(batch_values),
            start=params["start"],
            frequency=params["frequency"],
        )