# tempoPFN/src/data/time_features.py
# Vladyslav Moroshan
# Apply ruff formatting
# 96e1a32
import logging
from typing import Any
import numpy as np
import pandas as pd
import scipy.fft as fft
import torch
from gluonts.time_feature import time_features_from_frequency_str
from gluonts.time_feature._base import (
day_of_month,
day_of_month_index,
day_of_week,
day_of_week_index,
day_of_year,
hour_of_day,
hour_of_day_index,
minute_of_hour,
minute_of_hour_index,
month_of_year,
month_of_year_index,
second_of_minute,
second_of_minute_index,
week_of_year,
week_of_year_index,
)
from gluonts.time_feature.holiday import (
BLACK_FRIDAY,
CHRISTMAS_DAY,
CHRISTMAS_EVE,
CYBER_MONDAY,
EASTER_MONDAY,
EASTER_SUNDAY,
GOOD_FRIDAY,
INDEPENDENCE_DAY,
LABOR_DAY,
MEMORIAL_DAY,
NEW_YEARS_DAY,
NEW_YEARS_EVE,
THANKSGIVING,
SpecialDateFeatureSet,
exponential_kernel,
squared_exponential_kernel,
)
from gluonts.time_feature.seasonality import get_seasonality
from scipy.signal import find_peaks
from src.data.constants import BASE_END_DATE, BASE_START_DATE
from src.data.frequency import (
Frequency,
validate_frequency_safety,
)
from src.utils.utils import device
# Configure logging.
# NOTE(review): ``logging.basicConfig`` runs as a side effect of importing this module and sets up
# the root logger at DEBUG level (when it has no handlers yet). Confirm that this module, rather
# than the application entry point, is meant to own that configuration.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Enhanced feature sets for different frequencies.
#
# Layout: ``{category: {"normalized": [...], "index": [...]}}`` where
#   * ``category`` is one of the strings returned by ``TimeFeatureGenerator._get_feature_category``
#     ("high_freq", "medium_freq", "low_freq");
#   * each list holds GluonTS time-feature callables that are invoked with a ``pd.PeriodIndex``;
#   * the "index" callables are only used when ``use_index_features`` is enabled, and their output
#     is rescaled by its own maximum in ``TimeFeatureGenerator._compute_enhanced_features``.
# The order inside each list is the column order of the resulting feature matrix.
ENHANCED_TIME_FEATURES = {
    # High-frequency features (seconds, minutes)
    "high_freq": {
        "normalized": [
            second_of_minute,
            minute_of_hour,
            hour_of_day,
            day_of_week,
            day_of_month,
        ],
        "index": [
            second_of_minute_index,
            minute_of_hour_index,
            hour_of_day_index,
            day_of_week_index,
        ],
    },
    # Medium-frequency features (hourly, daily)
    "medium_freq": {
        "normalized": [
            hour_of_day,
            day_of_week,
            day_of_month,
            day_of_year,
            month_of_year,
        ],
        "index": [
            hour_of_day_index,
            day_of_week_index,
            day_of_month_index,
            week_of_year_index,
        ],
    },
    # Low-frequency features (weekly, monthly); also the fallback category for every frequency
    # string that is not explicitly listed in ``_get_feature_category``.
    "low_freq": {
        "normalized": [day_of_week, day_of_month, month_of_year, week_of_year],
        "index": [day_of_week_index, month_of_year_index, week_of_year_index],
    },
}
# Holiday features for different markets/regions.
#
# Maps the ``holiday_set`` name accepted by ``TimeFeatureGenerator`` to the list of GluonTS holiday
# identifiers handed to ``SpecialDateFeatureSet``. A name that is not a key of this dict results
# in no holiday features being generated.
HOLIDAY_FEATURE_SETS = {
    # Public holidays plus Christmas Eve / New Year's Eve.
    "us_business": [
        NEW_YEARS_DAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    # "us_business" plus Easter Sunday and the shopping days around Thanksgiving.
    "us_retail": [
        NEW_YEARS_DAY,
        EASTER_SUNDAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        BLACK_FRIDAY,
        CYBER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    # Easter- and Christmas-centred holidays, plus New Year's Day / Eve.
    "christian": [
        NEW_YEARS_DAY,
        GOOD_FRIDAY,
        EASTER_SUNDAY,
        EASTER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
}
class TimeFeatureGenerator:
"""
Enhanced time feature generator that leverages full GluonTS capabilities.
"""
def __init__(
self,
use_enhanced_features: bool = True,
use_holiday_features: bool = True,
holiday_set: str = "us_business",
holiday_kernel: str = "exponential",
holiday_kernel_alpha: float = 1.0,
use_index_features: bool = True,
k_max: int = 15,
include_seasonality_info: bool = True,
use_auto_seasonality: bool = False, # New parameter
max_seasonal_periods: int = 3, # New parameter
):
"""
Initialize enhanced time feature generator.
Parameters
----------
use_enhanced_features : bool
Whether to use frequency-specific enhanced features
use_holiday_features : bool
Whether to include holiday features
holiday_set : str
Which holiday set to use ('us_business', 'us_retail', 'christian')
holiday_kernel : str
Holiday kernel type ('indicator', 'exponential', 'squared_exponential')
holiday_kernel_alpha : float
Kernel parameter for exponential kernels
use_index_features : bool
Whether to include index-based features alongside normalized ones
k_max : int
Maximum number of time features to pad to
include_seasonality_info : bool
Whether to include seasonality information as features
use_auto_seasonality : bool
Whether to use automatic FFT-based seasonality detection
max_seasonal_periods : int
Maximum number of seasonal periods to detect automatically
"""
self.use_enhanced_features = use_enhanced_features
self.use_holiday_features = use_holiday_features
self.holiday_set = holiday_set
self.use_index_features = use_index_features
self.k_max = k_max
self.include_seasonality_info = include_seasonality_info
self.use_auto_seasonality = use_auto_seasonality
self.max_seasonal_periods = max_seasonal_periods
# Initialize holiday feature set
self.holiday_feature_set = None
if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func)
def _get_holiday_kernel(self, kernel_type: str, alpha: float):
"""Get holiday kernel function."""
if kernel_type == "exponential":
return exponential_kernel(alpha)
elif kernel_type == "squared_exponential":
return squared_exponential_kernel(alpha)
else:
# Default indicator kernel
return lambda x: float(x == 0)
def _get_feature_category(self, freq_str: str) -> str:
"""Determine feature category based on frequency."""
if freq_str in ["s", "1min", "5min", "10min", "15min"]:
return "high_freq"
elif freq_str in ["h", "D"]:
return "medium_freq"
else:
return "low_freq"
def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray:
"""Compute enhanced time features based on frequency."""
if not self.use_enhanced_features:
return np.array([]).reshape(len(period_index), 0)
category = self._get_feature_category(freq_str)
feature_config = ENHANCED_TIME_FEATURES[category]
features = []
# Add normalized features
for feat_func in feature_config["normalized"]:
try:
feat_values = feat_func(period_index)
features.append(feat_values)
except Exception:
continue
# Add index features if enabled
if self.use_index_features:
for feat_func in feature_config["index"]:
try:
feat_values = feat_func(period_index)
# Normalize index features to [0, 1] range
if feat_values.max() > 0:
feat_values = feat_values / feat_values.max()
features.append(feat_values)
except Exception:
continue
if features:
return np.stack(features, axis=-1)
else:
return np.array([]).reshape(len(period_index), 0)
def _compute_holiday_features(self, date_range: pd.DatetimeIndex) -> np.ndarray:
"""Compute holiday features."""
if not self.use_holiday_features or self.holiday_feature_set is None:
return np.array([]).reshape(len(date_range), 0)
try:
holiday_features = self.holiday_feature_set(date_range)
return holiday_features.T # Transpose to get [time, features] shape
except Exception:
return np.array([]).reshape(len(date_range), 0)
def _detect_auto_seasonality(self, time_series_values: np.ndarray) -> list:
"""
Detect seasonal periods automatically using FFT analysis.
Parameters
----------
time_series_values : np.ndarray
Time series values for seasonality detection
Returns
-------
list
List of detected seasonal periods
"""
if not self.use_auto_seasonality or len(time_series_values) < 10:
return []
try:
# Remove NaN values
values = time_series_values[~np.isnan(time_series_values)]
if len(values) < 10:
return []
# Simple linear detrending
x = np.arange(len(values))
coeffs = np.polyfit(x, values, 1)
trend = np.polyval(coeffs, x)
detrended = values - trend
# Apply Hann window to reduce spectral leakage
window = np.hanning(len(detrended))
windowed = detrended * window
# Zero padding for better frequency resolution
padded_length = len(windowed) * 2
padded_values = np.zeros(padded_length)
padded_values[: len(windowed)] = windowed
# Compute FFT
fft_values = fft.rfft(padded_values)
fft_magnitudes = np.abs(fft_values)
freqs = np.fft.rfftfreq(padded_length)
# Exclude DC component
fft_magnitudes[0] = 0.0
# Find peaks with threshold (5% of max magnitude)
threshold = 0.05 * np.max(fft_magnitudes)
peak_indices, _ = find_peaks(fft_magnitudes, height=threshold)
if len(peak_indices) == 0:
return []
# Sort by magnitude and take top periods
sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]]
top_indices = sorted_indices[: self.max_seasonal_periods]
# Convert frequencies to periods
periods = []
for idx in top_indices:
if freqs[idx] > 0:
period = 1.0 / freqs[idx]
# Scale back to original length and round
period = round(period / 2) # Account for zero padding
if 2 <= period <= len(values) // 2: # Reasonable period range
periods.append(period)
return list(set(periods)) # Remove duplicates
except Exception:
return []
def _compute_seasonality_features(
self,
period_index: pd.PeriodIndex,
freq_str: str,
time_series_values: np.ndarray = None,
) -> np.ndarray:
"""Compute seasonality-aware features."""
if not self.include_seasonality_info:
return np.array([]).reshape(len(period_index), 0)
all_seasonal_features = []
# Original frequency-based seasonality
try:
seasonality = get_seasonality(freq_str)
if seasonality > 1:
positions = np.arange(len(period_index))
sin_feat = np.sin(2 * np.pi * positions / seasonality)
cos_feat = np.cos(2 * np.pi * positions / seasonality)
all_seasonal_features.extend([sin_feat, cos_feat])
except Exception:
pass
# Automatic seasonality detection
if self.use_auto_seasonality and time_series_values is not None:
auto_periods = self._detect_auto_seasonality(time_series_values)
for period in auto_periods:
try:
positions = np.arange(len(period_index))
sin_feat = np.sin(2 * np.pi * positions / period)
cos_feat = np.cos(2 * np.pi * positions / period)
all_seasonal_features.extend([sin_feat, cos_feat])
except Exception:
continue
if all_seasonal_features:
return np.stack(all_seasonal_features, axis=-1)
else:
return np.array([]).reshape(len(period_index), 0)
def compute_features(
self,
period_index: pd.PeriodIndex,
date_range: pd.DatetimeIndex,
freq_str: str,
time_series_values: np.ndarray = None,
) -> np.ndarray:
"""
Compute all time features for given period index.
Parameters
----------
period_index : pd.PeriodIndex
Period index for computing features
date_range : pd.DatetimeIndex
Corresponding datetime index for holiday features
freq_str : str
Frequency string
time_series_values : np.ndarray, optional
Time series values for automatic seasonality detection
Returns
-------
np.ndarray
Time features array of shape [time_steps, num_features]
"""
all_features = []
# Standard GluonTS features
try:
standard_features = time_features_from_frequency_str(freq_str)
if standard_features:
std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1)
all_features.append(std_feat)
except Exception:
pass
# Enhanced features
enhanced_feat = self._compute_enhanced_features(period_index, freq_str)
if enhanced_feat.shape[1] > 0:
all_features.append(enhanced_feat)
# Holiday features
holiday_feat = self._compute_holiday_features(date_range)
if holiday_feat.shape[1] > 0:
all_features.append(holiday_feat)
# Seasonality features (including auto-detected)
seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values)
if seasonality_feat.shape[1] > 0:
all_features.append(seasonality_feat)
if all_features:
combined_features = np.concatenate(all_features, axis=-1)
else:
combined_features = np.zeros((len(period_index), 1))
return combined_features
def compute_batch_time_features(
    start: list[np.datetime64],
    history_length: int,
    future_length: int,
    batch_size: int,
    frequency: list[Frequency],
    K_max: int = 6,
    time_feature_config: dict[str, Any] | None = None,
):
    """
    Build history and future time-feature tensors for a batch of series.

    For every batch item a contiguous date range is laid out (history first, future directly
    after it), features are computed for both parts with a ``TimeFeatureGenerator`` and the
    feature axis is padded or truncated to ``K_max`` columns.

    Parameters
    ----------
    start : array-like, shape (batch_size,)
        Start timestamp of each series. A start that ``validate_frequency_safety`` rejects is
        replaced by ``BASE_START_DATE``.
    history_length : int
        Number of history time steps.
    future_length : int
        Number of future (target) time steps.
    batch_size : int
        Number of items read from ``start`` and ``frequency``.
    frequency : array-like, shape (batch_size,)
        Frequency of each series; every item must provide ``to_pandas_freq``.
    K_max : int, optional
        Number of feature columns in the output (default: 6).
    time_feature_config : dict, optional
        Keyword arguments forwarded to ``TimeFeatureGenerator``.

    Returns
    -------
    tuple
        ``(history_time_features, target_time_features)``: float torch tensors on ``device``
        with shapes ``(batch_size, history_length, K_max)`` and
        ``(batch_size, future_length, K_max)``.
    """
    generator = TimeFeatureGenerator(**(time_feature_config or {}))
    total_length = history_length + future_length

    per_item_history = []
    per_item_future = []

    for item in range(batch_size):
        frequency_i = frequency[item]
        freq_str = frequency_i.to_pandas_freq(for_date_range=True)
        period_freq_str = frequency_i.to_pandas_freq(for_date_range=False)

        # Fall back to the base start date when the requested start is not safe.
        start_ts = pd.Timestamp(start[item])
        if not validate_frequency_safety(start_ts, total_length, frequency_i):
            logger.debug(
                f"Start date {start_ts} not safe for total_length={total_length}, frequency={frequency_i}. "
                f"Using BASE_START_DATE instead."
            )
            start_ts = BASE_START_DATE

        history_range = pd.date_range(start=start_ts, periods=history_length, freq=freq_str)
        step = pd.tseries.frequencies.to_offset(freq_str)

        # If the history runs past the upper bound, shift the window back so that history plus
        # future end at the bound, but never start before the base start date.
        if history_range[-1] > BASE_END_DATE:
            safe_start = BASE_END_DATE - step * (history_length + future_length)
            if safe_start < BASE_START_DATE:
                safe_start = BASE_START_DATE
            history_range = pd.date_range(start=safe_start, periods=history_length, freq=freq_str)

        # The future window starts one step after the last history timestamp.
        future_range = pd.date_range(start=history_range[-1] + step, periods=future_length, freq=freq_str)

        # Features are computed per window on the period view of each range.
        history_features = generator.compute_features(
            history_range.to_period(period_freq_str), history_range, freq_str
        )
        future_features = generator.compute_features(
            future_range.to_period(period_freq_str), future_range, freq_str
        )

        per_item_history.append(_pad_or_truncate_features(history_features, K_max))
        per_item_future.append(_pad_or_truncate_features(future_features, K_max))

    # Stack along a new leading batch axis and move to the target device.
    history_batch = np.stack(per_item_history, axis=0)
    future_batch = np.stack(per_item_future, axis=0)
    return (
        torch.from_numpy(history_batch).float().to(device),
        torch.from_numpy(future_batch).float().to(device),
    )
def _pad_or_truncate_features(features: np.ndarray, K_max: int) -> np.ndarray:
"""Pad with zeros or truncate features to K_max dimensions."""
seq_len, num_features = features.shape
if num_features < K_max:
# Pad with zeros
padding = np.zeros((seq_len, K_max - num_features))
features = np.concatenate([features, padding], axis=-1)
elif num_features > K_max:
# Truncate to K_max (keep most important features first)
features = features[:, :K_max]
return features