import functools
import traceback

import gradio as gr
import numpy as np
import pandas as pd
import pandas_ta as ta
import plotly.graph_objects as go
import spaces
import torch
import yaml
import yfinance as yf
from huggingface_hub import hf_hub_download
from plotly.subplots import make_subplots
from scipy import stats

from examples.utils import load_model
from src.plotting.plot_timeseries import plot_multivariate_timeseries
from src.data.containers import BatchTimeSeriesContainer, Frequency
from src.synthetic_generation.generator_params import (
    SineWaveGeneratorParams, GPGeneratorParams, AnomalyGeneratorParams,
    MultiScaleFractalAudioParams, FinancialVolatilityAudioParams,
    SawToothGeneratorParams, SpikesGeneratorParams, StepGeneratorParams,
    OrnsteinUhlenbeckProcessGeneratorParams, NetworkTopologyAudioParams,
    StochasticRhythmAudioParams, CauKerGeneratorParams, ForecastPFNGeneratorParams,
    KernelGeneratorParams,
)
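# Fallback constants, used when the optional GIFT-Eval package is missing;
# the guarded import below overrides them with the canonical definitions.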
ALL_DATASETS = ["ETTm1", "ETTm2", "ETTh1", "ETTh2", "Weather", "Electricity", "Traffic"]
TERMS = ["short", "medium", "long"]

try:
    from src.gift_eval.evaluate import evaluate_datasets
    from src.gift_eval.predictor import TimeSeriesPredictor
    from src.gift_eval.results import aggregate_results
    from src.gift_eval.constants import ALL_DATASETS, TERMS

    GIFT_EVAL_AVAILABLE = True
except ImportError:
    GIFT_EVAL_AVAILABLE = False
    print("Warning: GIFT evaluation dependencies not available. GIFT evaluation tab will be disabled.")

from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import SineWaveGeneratorWrapper
from src.synthetic_generation.anomalies.anomaly_generator_wrapper import AnomalyGeneratorWrapper
from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import SawToothGeneratorWrapper
from src.synthetic_generation.spikes.spikes_generator_wrapper import SpikesGeneratorWrapper
from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import OrnsteinUhlenbeckProcessGeneratorWrapper
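
# Optional generator families. Each import is guarded so the app still runs
# (with a reduced generator menu) when an extra dependency is missing.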
try:
    from src.synthetic_generation.audio_generators.network_topology_wrapper import NetworkTopologyAudioWrapper
    NETWORK_AVAILABLE = True
except ImportError:
    NETWORK_AVAILABLE = False

try:
    from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import StochasticRhythmAudioWrapper
    RHYTHM_AVAILABLE = True
except ImportError:
    RHYTHM_AVAILABLE = False

try:
    from src.synthetic_generation.cauker.cauker_generator_wrapper import CauKerGeneratorWrapper
    CAUKER_AVAILABLE = True
except ImportError:
    CAUKER_AVAILABLE = False

try:
    from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import ForecastPFNGeneratorWrapper
    FORECAST_PFN_AVAILABLE = True
except ImportError:
    FORECAST_PFN_AVAILABLE = False

try:
    from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import KernelGeneratorWrapper
    KERNEL_AVAILABLE = True
except ImportError:
    KERNEL_AVAILABLE = False

try:
    from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
    GP_AVAILABLE = True
except ImportError:
    GP_AVAILABLE = False

try:
    from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import MultiScaleFractalAudioWrapper
    from src.synthetic_generation.audio_generators.financial_volatility_wrapper import FinancialVolatilityAudioWrapper
    AUDIO_AVAILABLE = True
except ImportError:
    AUDIO_AVAILABLE = False
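
# Module-level state shared across Gradio callbacks: the lazily-loaded model
# and the most recent forecast/metrics/analysis results used by the CSV exports.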
model = None

last_forecast_results = None
last_metrics_results = None
last_analysis_results = None


def create_gradio_app():
    """Create and configure the Gradio app for TempoPFN."""

    @functools.lru_cache(maxsize=None)
    def load_oil_price_data():
        """Downloads and caches daily WTI oil price data."""
        print("--- Downloading WTI Oil Price data for the first time ---")
        url = "https://datahub.io/core/oil-prices/r/wti-daily.csv"
        try:
            df = pd.read_csv(url)
            df['Date'] = pd.to_datetime(df['Date'])
            df = df.sort_values('Date')
            df = df.set_index('Date').asfreq('D').ffill().reset_index()
            values = df['Price'].values.astype(np.float32)
            start_date = df['Date'].min()
            print(f"--- Oil price data loaded. {len(values)} points ---")
            return values, start_date, "D"
        except Exception as e:
            print(f"Error loading oil price data: {e}")
            raise

    def generate_synthetic_data(length=2048, seed=42):
        """Generate synthetic sine wave data for demonstration."""
        sine_params = SineWaveGeneratorParams(global_seed=seed, length=length)
        sine_generator = SineWaveGeneratorWrapper(sine_params)
        batch = sine_generator.generate_batch(batch_size=1, seed=seed)
        values = torch.from_numpy(batch.values).to(torch.float32)
        if values.ndim == 2:
            values = values.unsqueeze(-1)

        return values.squeeze().numpy(), batch.start[0], batch.frequency[0]
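    # Expected CSV layout: first column = timestamps, second column = values;
    # an optional "Volume" column is picked up for the financial preview.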
    def process_uploaded_data(file):
        """Process uploaded CSV file with time series data."""
        if file is None:
            return None, None, None, None, "No file uploaded"
        try:
            df = pd.read_csv(file.name)
            if len(df.columns) < 2:
                return None, None, None, None, "CSV must have at least 2 columns"
            time_col, value_col = df.columns[0], df.columns[1]

            try:
                df[time_col] = pd.to_datetime(df[time_col])
                df = df.sort_values(time_col)
                start_date = df[time_col].min()
                freq = pd.infer_freq(df[time_col]) or "D"
            except Exception:
                start_date = np.datetime64("2020-01-01")
                freq = "D"

            values = df[value_col].values.astype(np.float32)
            volumes = df['Volume'].values.astype(np.float32) if 'Volume' in df.columns else None
            return values, volumes, start_date, freq, f"Loaded {len(values)} data points"
        except Exception as e:
            return None, None, None, None, f"Error processing file: {str(e)}"
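    # The ACF panel below uses a per-lag np.corrcoef estimate; the dashed
    # bands are the usual +/-1.96/sqrt(N) white-noise confidence interval.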
    def create_advanced_visualizations(history_values, predictions, future_values=None):
        """Create advanced statistical visualizations."""
        try:
            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Residual Analysis', 'ACF Plot',
                                'Distribution Comparison', 'Forecast Error Distribution'),
                specs=[[{"type": "scatter"}, {"type": "bar"}],
                       [{"type": "histogram"}, {"type": "histogram"}]]
            )

            history_flat = history_values.flatten()
            pred_flat = predictions.flatten()

            if future_values is not None:
                future_flat = future_values.flatten()[:len(pred_flat)]
                residuals = future_flat - pred_flat

                fig.add_trace(
                    go.Scatter(x=list(range(len(residuals))), y=residuals,
                               mode='lines+markers', name='Residuals'),
                    row=1, col=1
                )
                fig.add_hline(y=0, line_dash="dash", line_color="red", row=1, col=1)
            else:
                fig.add_trace(
                    go.Scatter(x=list(range(len(pred_flat))), y=pred_flat,
                               mode='lines', name='Predictions'),
                    row=1, col=1
                )

            max_lags = min(40, len(history_flat) // 2)
            acf_values = []
            for lag in range(max_lags):
                if lag == 0:
                    acf_values.append(1.0)
                else:
                    acf = np.corrcoef(history_flat[:-lag], history_flat[lag:])[0, 1]
                    acf_values.append(acf)

            fig.add_trace(
                go.Bar(x=list(range(max_lags)), y=acf_values, name='ACF'),
                row=1, col=2
            )

            ci = 1.96 / np.sqrt(len(history_flat))
            fig.add_hline(y=ci, line_dash="dash", line_color="blue", row=1, col=2)
            fig.add_hline(y=-ci, line_dash="dash", line_color="blue", row=1, col=2)

            fig.add_trace(
                go.Histogram(x=history_flat, name='Historical', opacity=0.7, nbinsx=30),
                row=2, col=1
            )
            fig.add_trace(
                go.Histogram(x=pred_flat, name='Predictions', opacity=0.7, nbinsx=30),
                row=2, col=1
            )

            if future_values is not None:
                future_flat = future_values.flatten()[:len(pred_flat)]
                errors = future_flat - pred_flat
                fig.add_trace(
                    go.Histogram(x=errors, name='Forecast Errors', nbinsx=30),
                    row=2, col=2
                )
            else:
                fig.add_trace(
                    go.Histogram(x=pred_flat, name='Pred Distribution', nbinsx=30),
                    row=2, col=2
                )

            fig.update_layout(
                height=800,
                title_text="Advanced Statistical Analysis",
                showlegend=True
            )

            fig.update_xaxes(title_text="Time Index", row=1, col=1)
            fig.update_yaxes(title_text="Value", row=1, col=1)
            fig.update_xaxes(title_text="Lag", row=1, col=2)
            fig.update_yaxes(title_text="Correlation", row=1, col=2)
            fig.update_xaxes(title_text="Value", row=2, col=1)
            fig.update_yaxes(title_text="Frequency", row=2, col=1)
            fig.update_xaxes(title_text="Error", row=2, col=2)
            fig.update_yaxes(title_text="Frequency", row=2, col=2)

            return fig

        except Exception as e:
            print(f"Error creating advanced visualizations: {e}")
            fig = go.Figure()
            fig.add_annotation(
                text=f"Error creating visualizations: {str(e)}",
                xref="paper", yref="paper", x=0.5, y=0.5,
                showarrow=False, font=dict(size=14, color="red")
            )
            return fig
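    # CSV export helpers. Each writes to /tmp and returns (filepath, status)
    # so the UI wrappers can feed a gr.File download component.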
    def export_forecast_csv():
        """Export forecast data to CSV."""
        global last_forecast_results
        if last_forecast_results is None:
            return None, "No forecast data available. Please run a forecast first."

        try:
            history = last_forecast_results['history'].flatten()
            predictions = last_forecast_results['predictions'].flatten()
            future = last_forecast_results['future'].flatten()

            # Lay the forecast out after the history on a shared time index,
            # padding each column with NaN so all lengths match.
            n_hist, n_pred = len(history), len(predictions)
            total_len = n_hist + n_pred
            df_data = {
                'Time_Index': list(range(total_len)),
                'Historical_Value': list(history) + [np.nan] * n_pred,
                'Predicted_Value': [np.nan] * n_hist + list(predictions),
                'True_Future_Value': [np.nan] * n_hist + list(future[:n_pred])
                                     + [np.nan] * max(0, n_pred - len(future)),
            }

            df = pd.DataFrame(df_data)
            filepath = "/tmp/forecast_data.csv"
            df.to_csv(filepath, index=False)

            return filepath, "Forecast data exported successfully!"
        except Exception as e:
            return None, f"Error exporting forecast data: {str(e)}"
    def export_metrics_csv():
        """Export metrics summary to CSV."""
        global last_metrics_results
        if last_metrics_results is None:
            return None, "No metrics available. Please run a forecast first."

        try:
            df = pd.DataFrame([last_metrics_results])
            filepath = "/tmp/metrics_summary.csv"
            df.to_csv(filepath, index=False)

            return filepath, "Metrics summary exported successfully!"
        except Exception as e:
            return None, f"Error exporting metrics: {str(e)}"

    def export_analysis_csv():
        """Export full analysis including forecast, metrics, and metadata."""
        global last_forecast_results, last_metrics_results, last_analysis_results
        if last_forecast_results is None:
            return None, "No analysis data available. Please run a forecast first."

        try:
            analysis_data = {
                **last_analysis_results,
                **last_metrics_results,
                'num_history_points': len(last_forecast_results['history']),
                'num_forecast_points': len(last_forecast_results['predictions']),
            }

            df = pd.DataFrame([analysis_data])
            filepath = "/tmp/full_analysis.csv"
            df.to_csv(filepath, index=False)

            return filepath, "Full analysis exported successfully!"
        except Exception as e:
            return None, f"Error exporting analysis: {str(e)}"
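    # Note: the coverage/calibration, cross-validation, and sensitivity fields
    # below are placeholders (always 0.0), and the entropy/fractal values are
    # cheap proxies rather than textbook estimators.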
    def calculate_metrics(history_values, predictions, future_values=None, data_source=""):
        """Calculate comprehensive metrics for display in the UI."""
        metrics = {}

        metrics['data_mean'] = float(np.mean(history_values))
        metrics['data_std'] = float(np.std(history_values))
        metrics['data_skewness'] = float(stats.skew(history_values.flatten()))
        metrics['data_kurtosis'] = float(stats.kurtosis(history_values.flatten()))

        metrics['latest_price'] = float(history_values[-1, 0] if history_values.ndim > 1 else history_values[-1])
        metrics['forecast_next'] = float(predictions[0, 0] if predictions.ndim > 1 else predictions[0])

        if len(history_values) >= 30:
            recent_30 = history_values[-30:].flatten()
            volatility = (np.std(recent_30) / np.mean(recent_30)) * 100 if np.mean(recent_30) != 0 else 0
            metrics['vol_30d'] = float(volatility)
        else:
            metrics['vol_30d'] = 0.0

        lookback = min(252, len(history_values))
        recent_data = history_values[-lookback:].flatten()
        metrics['high_52wk'] = float(np.max(recent_data))
        metrics['low_52wk'] = float(np.min(recent_data))

        if len(history_values) > 1:
            flat_history = history_values.flatten()
            metrics['data_autocorr'] = float(np.corrcoef(flat_history[:-1], flat_history[1:])[0, 1])
        else:
            metrics['data_autocorr'] = 0.0

        if len(history_values) >= 20:
            first_half = history_values[:len(history_values) // 2].flatten()
            second_half = history_values[len(history_values) // 2:].flatten()
            var_ratio = np.var(second_half) / np.var(first_half) if np.var(first_half) > 0 else 1.0
            metrics['data_stationary'] = "Likely" if 0.5 < var_ratio < 2.0 else "Unlikely"
        else:
            metrics['data_stationary'] = "Unknown"

        if metrics['data_autocorr'] > 0.7:
            metrics['pattern_type'] = "Trending"
        elif abs(metrics['data_autocorr']) < 0.3:
            metrics['pattern_type'] = "Random Walk"
        else:
            metrics['pattern_type'] = "Mean Reverting"

        if future_values is not None:
            pred_flat = predictions.flatten()[:len(future_values.flatten())]
            true_flat = future_values.flatten()[:len(pred_flat)]

            metrics['mse'] = float(np.mean((pred_flat - true_flat) ** 2))
            metrics['mae'] = float(np.mean(np.abs(pred_flat - true_flat)))

            mape_values = np.abs((true_flat - pred_flat) / (true_flat + 1e-8)) * 100
            metrics['mape'] = float(np.mean(mape_values))
        else:
            metrics['mse'] = 0.0
            metrics['mae'] = 0.0
            metrics['mape'] = 0.0

        metrics['coverage_80'] = 0.0
        metrics['coverage_95'] = 0.0
        metrics['calibration'] = 0.0

        try:
            hist_normalized = (history_values.flatten() - np.mean(history_values)) / (np.std(history_values) + 1e-8)
            metrics['sample_entropy'] = float(-np.mean(np.log(np.abs(hist_normalized) + 1e-8)))
        except Exception:
            metrics['sample_entropy'] = 0.0

        metrics['approx_entropy'] = metrics['sample_entropy'] * 0.8
        metrics['perm_entropy'] = metrics['sample_entropy'] * 0.9

        try:
            metrics['fractal_dim'] = float(1.0 + 0.5 * metrics['data_std'] / (np.mean(np.abs(np.diff(history_values.flatten()))) + 1e-8))
        except Exception:
            metrics['fractal_dim'] = 1.5

        try:
            fft_vals = np.fft.fft(history_values.flatten())
            power_spectrum = np.abs(fft_vals[:len(fft_vals) // 2]) ** 2
            freqs = np.fft.fftfreq(len(history_values.flatten()))[:len(fft_vals) // 2]

            dominant_idx = np.argmax(power_spectrum[1:]) + 1
            metrics['dominant_freq'] = float(abs(freqs[dominant_idx]))

            metrics['spectral_centroid'] = float(np.sum(freqs * power_spectrum) / (np.sum(power_spectrum) + 1e-8))

            power_normalized = power_spectrum / (np.sum(power_spectrum) + 1e-8)
            metrics['spectral_entropy'] = float(-np.sum(power_normalized * np.log(power_normalized + 1e-8)))
        except Exception:
            metrics['dominant_freq'] = 0.0
            metrics['spectral_centroid'] = 0.0
            metrics['spectral_entropy'] = 0.0

        metrics['cv_mse'] = 0.0
        metrics['cv_mae'] = 0.0
        metrics['cv_windows'] = 0

        metrics['horizon_sensitivity'] = 0.0
        metrics['history_sensitivity'] = 0.0
        metrics['stability_score'] = 0.0

        return metrics
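    # ZeroGPU pattern: the model is cached on CPU between requests and only
    # moved to CUDA inside the @spaces.GPU-decorated function below.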
    @spaces.GPU(duration=120)
    def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
        """
        GPU-only inference function for ZeroGPU Spaces.
        ALL CUDA operations must happen inside this decorated function.
        Extended timeout for Triton kernel compilation on first run.
        """
        global model

        if model is None:
            print("--- Loading TempoPFN model for the first time ---")
            print("Downloading model...")
            model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")

            print(f"Loading model from {model_path} to CPU first...")
            model = load_model(config_path="configs/example.yaml", model_path=model_path, device=torch.device("cpu"))
            print("--- Model loaded successfully on CPU ---")

        device = torch.device("cuda:0")
        print(f"Moving model to {device}...")
        model.to(device)

        container = BatchTimeSeriesContainer(
            history_values=history_values_tensor.to(device),
            future_values=future_values_tensor.to(device),
            start=[start],
            frequency=[freq_object],
        )

        print("Running inference...")
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            model_output = model(container)

        model.to(torch.device("cpu"))
        print("Inference complete, model moved back to CPU")

        return model_output

    def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
        """
        Runs the TempoPFN forecast.

        Returns: history_price, history_volume, predictions, quantiles, plot, status, metrics, data_preview
        """
        try:
            all_volumes = None

            if data_source == "Stock Ticker":
                if not stock_ticker:
                    # Raised errors are caught at the bottom of this function
                    # and rendered as a status message in the full 8-tuple the
                    # callers expect (the old short returns crashed them).
                    raise ValueError("Please enter a stock ticker (e.g., SPY, AAPL)")
                print(f"--- Downloading '{stock_ticker}' data from yfinance ---")
                hist = yf.download(stock_ticker, period="max", auto_adjust=True)
                if hist.empty:
                    raise ValueError(f"Could not find data for ticker '{stock_ticker}'")

                hist = hist[['Close', 'Volume']].asfreq('D').ffill()

                all_values = hist['Close'].values.astype(np.float32).squeeze()
                all_volumes = hist['Volume'].values.astype(np.float32).squeeze()
                data_start_date = hist.index.min()
                frequency = "D"

            elif data_source == "VIX Volatility Index":
                print("--- Downloading VIX data from yfinance ---")
                vix_data = yf.download("^VIX", period="max", auto_adjust=True)
                if vix_data.empty:
                    raise ValueError("Could not download VIX data")
                vix_data = vix_data.asfreq('D').ffill()
                all_values = vix_data['Close'].values.astype(np.float32).squeeze()
                data_start_date = vix_data.index.min()
                frequency = "D"
                print(f"--- VIX data loaded: {len(all_values)} points ---")

            elif data_source == "Default (WTI Oil Prices)":
                all_values, data_start_date, frequency = load_oil_price_data()

            elif data_source == "Upload Custom CSV":
                all_values, all_volumes, data_start_date, frequency, message = process_uploaded_data(uploaded_file)
                if all_values is None:
                    raise ValueError(message)
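            # Map the UI generator name to its wrapper class; anything unknown
            # or unavailable falls back to plain sine waves.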

            elif data_source in ("Synthetic Playground", "Advanced Synthetic"):
                print(f"--- Generating {synth_generator} synthetic data (complexity: {synth_complexity}) ---")

                total_length = history_length + forecast_horizon

                if synth_generator == "Sine Waves":
                    params = SineWaveGeneratorParams(global_seed=seed, length=total_length)
                    generator = SineWaveGeneratorWrapper(params)
                elif synth_generator == "Sawtooth Waves":
                    params = SawToothGeneratorParams(global_seed=seed, length=total_length)
                    generator = SawToothGeneratorWrapper(params)
                elif synth_generator == "Spikes":
                    params = SpikesGeneratorParams(global_seed=seed, length=total_length)
                    generator = SpikesGeneratorWrapper(params)
                elif synth_generator == "Steps":
                    params = StepGeneratorParams(global_seed=seed, length=total_length)
                    generator = StepGeneratorWrapper(params)
                elif synth_generator == "Ornstein-Uhlenbeck":
                    params = OrnsteinUhlenbeckProcessGeneratorParams(global_seed=seed, length=total_length)
                    generator = OrnsteinUhlenbeckProcessGeneratorWrapper(params)
                elif synth_generator == "Gaussian Processes" and GP_AVAILABLE:
                    params = GPGeneratorParams(global_seed=seed, length=total_length)
                    generator = GPGeneratorWrapper(params)
                elif synth_generator == "Anomaly Patterns":
                    params = AnomalyGeneratorParams(global_seed=seed, length=total_length)
                    generator = AnomalyGeneratorWrapper(params)
                elif synth_generator == "Financial Volatility" and AUDIO_AVAILABLE:
                    params = FinancialVolatilityAudioParams(global_seed=seed, length=total_length)
                    generator = FinancialVolatilityAudioWrapper(params)
                elif synth_generator == "Fractal Patterns" and AUDIO_AVAILABLE:
                    params = MultiScaleFractalAudioParams(global_seed=seed, length=total_length)
                    generator = MultiScaleFractalAudioWrapper(params)
                elif synth_generator == "Network Topology" and NETWORK_AVAILABLE:
                    params = NetworkTopologyAudioParams(global_seed=seed, length=total_length)
                    generator = NetworkTopologyAudioWrapper(params)
                elif synth_generator == "Stochastic Rhythm" and RHYTHM_AVAILABLE:
                    params = StochasticRhythmAudioParams(global_seed=seed, length=total_length)
                    generator = StochasticRhythmAudioWrapper(params)
                elif synth_generator == "CauKer" and CAUKER_AVAILABLE:
                    params = CauKerGeneratorParams(global_seed=seed, length=total_length)
                    generator = CauKerGeneratorWrapper(params)
                elif synth_generator == "Forecast PFN Prior" and FORECAST_PFN_AVAILABLE:
                    params = ForecastPFNGeneratorParams(global_seed=seed, length=total_length)
                    generator = ForecastPFNGeneratorWrapper(params)
                elif synth_generator == "Kernel Synth" and KERNEL_AVAILABLE:
                    params = KernelGeneratorParams(global_seed=seed, length=total_length)
                    generator = KernelGeneratorWrapper(params)
                else:
                    params = SineWaveGeneratorParams(global_seed=seed, length=total_length)
                    generator = SineWaveGeneratorWrapper(params)

                batch = generator.generate_batch(batch_size=1, seed=seed)
                values = torch.from_numpy(batch.values).to(torch.float32)
                if values.ndim == 2:
                    values = values.unsqueeze(-1)

                all_values = values.squeeze().numpy()
                data_start_date = batch.start[0] if hasattr(batch, 'start') and batch.start else np.datetime64("2020-01-01")
                frequency = batch.frequency[0] if hasattr(batch, 'frequency') and batch.frequency else "D"

                print(f"--- {synth_generator} data generated: {len(all_values)} points ---")

            else:
                values, start, frequency = generate_synthetic_data(length=history_length + forecast_horizon, seed=seed)
                all_values, data_start_date = values, start

            # Every source above yields `all_values` with at least
            # history_length + forecast_horizon points; keep the tail.
            # (The old guard compared against a "Synthetic Data" label no
            # UI path ever produces, so slicing effectively always ran.)
            total_needed = history_length + forecast_horizon
            if len(all_values) < total_needed:
                raise ValueError(f"Data has {len(all_values)} points, but {total_needed} are needed.")

            values = all_values[-total_needed:]
            start_offset_days = len(all_values) - total_needed
            start = np.datetime64(data_start_date) + np.timedelta64(start_offset_days, 'D')

            if all_volumes is not None:
                history_volumes = all_volumes[-total_needed:-forecast_horizon]
            else:
                history_volumes = np.array([np.nan] * history_length)
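            # Shape the series into the (batch, time, channel) layout the model
            # expects and map the pandas-style frequency string onto the
            # Frequency enum used by BatchTimeSeriesContainer.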
            values_tensor = torch.from_numpy(values).unsqueeze(0).unsqueeze(-1)
            future_length = forecast_horizon

            if isinstance(frequency, str):
                if frequency.startswith("D"):
                    freq_object = Frequency.D
                elif frequency.startswith("W"):
                    freq_object = Frequency.W
                elif frequency.startswith("M"):
                    freq_object = Frequency.M
                elif frequency.startswith("Q"):
                    freq_object = Frequency.Q
                elif frequency.startswith("A") or frequency.startswith("Y"):
                    freq_object = Frequency.A
                else:
                    print(f"Warning: Unknown frequency string '{frequency}'. Defaulting to Daily.")
                    freq_object = Frequency.D
            else:
                freq_object = frequency

            history_values_tensor = values_tensor[:, :-future_length, :]
            future_values_tensor = values_tensor[:, -future_length:, :]

            if not isinstance(start, np.datetime64):
                start = np.datetime64(start)
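            # Run inference, then undo the model's internal scaling so the
            # forecasts come back in the original units.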
            model_output = run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object)

            preds_full = model_output["result"].to(torch.float32)
            if model is not None and hasattr(model, "scaler") and "scale_statistics" in model_output:
                preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])

            preds_np = preds_full.detach().cpu().numpy()
            history_np = history_values_tensor.cpu().numpy().squeeze(0)
            future_np = future_values_tensor.cpu().numpy().squeeze(0)
            preds_squeezed = preds_np.squeeze(0)

            model_quantiles = None
            if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
                model_quantiles = model.quantiles

            try:
                forecast_plot = plot_multivariate_timeseries(
                    history_values=history_np,
                    future_values=future_np,
                    predicted_values=preds_squeezed,
                    start=start,
                    frequency=freq_object,
                    title=f"TempoPFN Forecast - {data_source}",
                    show=False
                )
            except Exception as plot_error:
                print(f"Warning: Failed to generate plot: {plot_error}")
                forecast_plot = go.Figure()
                forecast_plot.add_annotation(
                    text="Plot generation failed",
                    xref="paper", yref="paper", x=0.5, y=0.5,
                    showarrow=False, font=dict(size=14, color="red")
                )

            metrics = calculate_metrics(
                history_values=history_np,
                predictions=preds_squeezed,
                future_values=future_np,
                data_source=data_source
            )

            global last_forecast_results, last_metrics_results, last_analysis_results
            last_forecast_results = {
                'history': history_np,
                'predictions': preds_squeezed,
                'future': future_np,
                'start': start,
                'frequency': freq_object
            }
            last_metrics_results = metrics
            last_analysis_results = {
                'data_source': data_source,
                'forecast_horizon': forecast_horizon,
                'history_length': history_length,
                'seed': seed
            }

            # Preview at most the first 100 rows; every column must share one
            # length or pd.DataFrame raises.
            n_preview = min(len(history_np), 100)
            preview_data = {
                'Index': list(range(n_preview)),
                'Historical Value': history_np.flatten()[:n_preview]
            }
            if history_volumes is not None and not np.all(np.isnan(history_volumes)):
                preview_data['Volume'] = history_volumes[:n_preview]
            data_preview_df = pd.DataFrame(preview_data)

            return (
                history_np, history_volumes, preds_squeezed, model_quantiles,
                forecast_plot, "Forecasting completed successfully!",
                metrics, data_preview_df
            )

        except Exception as e:
            traceback.print_exc()
            error_msg = f"Error during forecasting: {str(e)}"
            empty_metrics = {k: 0.0 if isinstance(v, float) else "" for k, v in
                             calculate_metrics(np.array([0.0]), np.array([0.0])).items()}
            return None, None, None, None, None, error_msg, empty_metrics, pd.DataFrame()
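    # --- UI -----------------------------------------------------------------
    # Two tabs share a hidden `data_source` textbox; each tab's radio mirrors
    # its value so forecast_time_series sees a single source string.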
    with gr.Blocks(title="TempoPFN") as app:
        gr.Markdown("# TempoPFN\n### Zero-Shot Forecasting & Analysis Terminal\n*Powered by synthetic pre-training • Forecast anything, anywhere*")
        gr.Markdown("⚠️ **First Run Note**: Initial inference may take 60-90 seconds due to Triton kernel compilation. Subsequent runs will be much faster!")

        with gr.Tabs() as tabs:

            with gr.TabItem("Financial Markets", id="financial"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=380):
                        gr.Markdown("### Financial Data Sources")

                        financial_source = gr.Radio(
                            choices=["Default (WTI Oil Prices)", "Stock Ticker", "VIX Volatility Index", "Upload Custom CSV"],
                            value="Default (WTI Oil Prices)",
                            label="",
                            info="Choose financial market data or upload your own"
                        )

                        data_source = gr.Textbox(visible=False)

                        with gr.Row():
                            stock_ticker = gr.Textbox(
                                label="Stock Ticker",
                                value="SPY",
                                placeholder="e.g., SPY, AAPL, TSLA",
                                visible=False
                            )
                            uploaded_file = gr.File(
                                label="CSV File",
                                file_types=[".csv"],
                                visible=False
                            )

                        def toggle_financial_input(choice):
                            show_ticker = (choice == "Stock Ticker")
                            show_upload = (choice == "Upload Custom CSV")
                            return (
                                gr.update(visible=show_ticker),
                                gr.update(visible=show_upload)
                            )

                        financial_source.change(
                            fn=lambda x: x,
                            inputs=financial_source,
                            outputs=data_source
                        ).then(
                            fn=toggle_financial_input,
                            inputs=financial_source,
                            outputs=[stock_ticker, uploaded_file]
                        )

                        gr.Markdown("### Forecasting Parameters")

                        forecast_horizon = gr.Slider(
                            minimum=30, maximum=512, value=90, step=1,
                            label="Forecast Horizon",
                            info="Number of periods to forecast ahead"
                        )

                        history_length = gr.Slider(
                            minimum=256, maximum=2048, value=1024, step=8,
                            label="History Length",
                            info="Historical data points to analyze"
                        )

                        financial_forecast_btn = gr.Button("Run Forecast & Analysis")

                    with gr.Column(scale=3):
                        gr.Markdown("### Analysis Results")
                        status_text = gr.Textbox(
                            label="",
                            interactive=False,
                            lines=3,
                            info="Forecasting progress and results"
                        )

                        gr.Markdown("### Key Metrics")

                        with gr.Row(visible=True) as financial_metrics:
                            with gr.Column():
                                gr.Markdown("**Latest Level:** $<span id='latest-price'>0.00</span>")
                                latest_price_out = gr.Number(visible=False)

                                gr.Markdown("**Forecast (Next Period):** $<span id='forecast-next'>0.00</span>")
                                forecast_next_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**30-Day Volatility:** <span id='vol-30d'>0.00</span>%")
                                vol_30d_out = gr.Number(visible=False)

                                with gr.Row():
                                    gr.Markdown("**52-Week High:** $<span id='high-52wk'>0.00</span>")
                                    high_52wk_out = gr.Number(visible=False)

                                    gr.Markdown("**52-Week Low:** $<span id='low-52wk'>0.00</span>")
                                    low_52wk_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as synthetic_metrics:
                            with gr.Column():
                                gr.Markdown("**Statistical Properties:**")
                                gr.Markdown("• **Mean:** <span id='data-mean'>0.000</span>")
                                data_mean_out = gr.Number(visible=False)

                                gr.Markdown("• **Std Dev:** <span id='data-std'>0.000</span>")
                                data_std_out = gr.Number(visible=False)

                                gr.Markdown("• **Skewness:** <span id='data-skewness'>0.000</span>")
                                data_skewness_out = gr.Number(visible=False)

                                gr.Markdown("• **Kurtosis:** <span id='data-kurtosis'>0.000</span>")
                                data_kurtosis_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Time Series Analysis:**")
                                gr.Markdown("• **Autocorr (lag-1):** <span id='data-autocorr'>0.000</span>")
                                data_autocorr_out = gr.Number(visible=False)

                                gr.Markdown("• **Stationary:** <span id='data-stationary'>Unknown</span>")
                                data_stationary_out = gr.Textbox(visible=False)

                                gr.Markdown("• **Pattern Type:** <span id='pattern-type'>None</span>")
                                pattern_type_out = gr.Textbox(visible=False)

                        with gr.Row(visible=False) as performance_metrics:
                            with gr.Column():
                                gr.Markdown("**Forecast Performance:**")
                                gr.Markdown("• **MSE:** <span id='mse'>0.000</span>")
                                mse_out = gr.Number(visible=False)

                                gr.Markdown("• **MAE:** <span id='mae'>0.000</span>")
                                mae_out = gr.Number(visible=False)

                                gr.Markdown("• **MAPE:** <span id='mape'>0.000</span>%")
                                mape_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Uncertainty Quantification:**")
                                gr.Markdown("• **80% Coverage:** <span id='coverage-80'>0.000</span>")
                                coverage_80_out = gr.Number(visible=False)

                                gr.Markdown("• **95% Coverage:** <span id='coverage-95'>0.000</span>")
                                coverage_95_out = gr.Number(visible=False)

                                gr.Markdown("• **Calibration:** <span id='calibration'>0.000</span>")
                                calibration_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as complexity_metrics:
                            with gr.Column():
                                gr.Markdown("**Information Theory:**")
                                gr.Markdown("• **Sample Entropy:** <span id='sample-entropy'>0.000</span>")
                                sample_entropy_out = gr.Number(visible=False)

                                gr.Markdown("• **Approx Entropy:** <span id='approx-entropy'>0.000</span>")
                                approx_entropy_out = gr.Number(visible=False)

                                gr.Markdown("• **Perm Entropy:** <span id='perm-entropy'>0.000</span>")
                                perm_entropy_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Complexity Measures:**")
                                gr.Markdown("• **Fractal Dim:** <span id='fractal-dim'>0.000</span>")
                                fractal_dim_out = gr.Number(visible=False)

                                gr.Markdown("• **Dominant Freq:** <span id='dominant-freq'>0.000</span>")
                                dominant_freq_out = gr.Number(visible=False)

                                gr.Markdown("• **Spectral Centroid:** <span id='spectral-centroid'>0.000</span>")
                                spectral_centroid_out = gr.Number(visible=False)

                                gr.Markdown("• **Spectral Entropy:** <span id='spectral-entropy'>0.000</span>")
                                spectral_entropy_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as research_tools:
                            with gr.Column():
                                gr.Markdown("**Cross-Validation Results:**")
                                gr.Markdown("• **Rolling Window MSE:** <span id='cv-mse'>0.000</span>")
                                cv_mse_out = gr.Number(visible=False)

                                gr.Markdown("• **Rolling Window MAE:** <span id='cv-mae'>0.000</span>")
                                cv_mae_out = gr.Number(visible=False)

                                gr.Markdown("• **Validation Windows:** <span id='cv-windows'>0</span>")
                                cv_windows_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Parameter Sensitivity:**")
                                gr.Markdown("• **Horizon Sensitivity:** <span id='horizon-sensitivity'>0.000</span>")
                                horizon_sensitivity_out = gr.Number(visible=False)

                                gr.Markdown("• **History Sensitivity:** <span id='history-sensitivity'>0.000</span>")
                                history_sensitivity_out = gr.Number(visible=False)

                                gr.Markdown("• **Stability Score:** <span id='stability-score'>0.000</span>")
                                stability_score_out = gr.Number(visible=False)

                        gr.Markdown("### Forecast & Technical Analysis")
                        plot_output = gr.Plot(
                            label="",
                            show_label=False
                        )

                        with gr.Accordion("Advanced Statistical Visualizations", open=False):
                            advanced_plots = gr.Plot(label="", show_label=False)

                        with gr.Accordion("Export & Analysis Tools", open=False):
                            with gr.Row():
                                # Named *_btn so the buttons do not shadow the
                                # export_*_csv helper functions they trigger.
                                export_forecast_btn = gr.Button("📊 Export Forecast Data (CSV)")
                                export_metrics_btn = gr.Button("📈 Export Metrics Summary (CSV)")
                                export_analysis_btn = gr.Button("🔬 Export Full Analysis (CSV)")

                            export_status = gr.Textbox(
                                label="Export Status",
                                interactive=False,
                                lines=2,
                                info="Export operation results"
                            )

                            export_file = gr.File(
                                label="Download Exported Data",
                                visible=False
                            )

                        with gr.Accordion("Raw Data Preview", open=False):
                            data_preview = gr.Dataframe(
                                label="",
                                show_label=False,
                                wrap=True
                            )
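            # The research tab mirrors the financial tab's layout but drives
            # the synthetic generators instead of market data.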
            with gr.TabItem("Research & Analysis", id="research"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=380):
                        gr.Markdown("### Synthetic Data Testing")

                        research_source = gr.Radio(
                            choices=["Basic Synthetic", "Advanced Synthetic"],
                            value="Basic Synthetic",
                            label="",
                            info="Test TempoPFN with synthetic data patterns"
                        )

                        seed = gr.Number(
                            value=42,
                            label="Random Seed",
                            minimum=0,
                            maximum=9999,
                            step=1,
                            visible=False
                        )

                        available_generators = [
                            "Sine Waves", "Sawtooth Waves", "Spikes", "Steps",
                            "Ornstein-Uhlenbeck", "Anomaly Patterns"
                        ]
                        if GP_AVAILABLE:
                            available_generators.append("Gaussian Processes")
                        if AUDIO_AVAILABLE:
                            available_generators.extend(["Financial Volatility", "Fractal Patterns"])
                        if NETWORK_AVAILABLE:
                            available_generators.append("Network Topology")
                        if RHYTHM_AVAILABLE:
                            available_generators.append("Stochastic Rhythm")
                        if CAUKER_AVAILABLE:
                            available_generators.append("CauKer")
                        if FORECAST_PFN_AVAILABLE:
                            available_generators.append("Forecast PFN Prior")
                        if KERNEL_AVAILABLE:
                            available_generators.append("Kernel Synth")

                        with gr.Row():
                            synth_generator = gr.Dropdown(
                                choices=available_generators,
                                value="Sine Waves",
                                label="Generator Type",
                                visible=False,
                                info="Select synthetic pattern generator"
                            )
                            synth_complexity = gr.Slider(
                                minimum=1, maximum=10, value=5, step=1,
                                label="Complexity",
                                visible=False,
                                info="Pattern complexity level"
                            )

                        def toggle_research_input(choice):
                            show_seed = (choice == "Basic Synthetic")
                            show_synth = (choice == "Advanced Synthetic")
                            return (
                                gr.update(visible=show_seed),
                                gr.update(visible=show_synth),
                                gr.update(visible=show_synth)
                            )

                        research_source.change(
                            fn=lambda x: x,
                            inputs=research_source,
                            outputs=data_source
                        ).then(
                            fn=toggle_research_input,
                            inputs=research_source,
                            outputs=[seed, synth_generator, synth_complexity]
                        )

                        gr.Markdown("### Forecasting Parameters")

                        # Named research_* so these sliders do not shadow the
                        # financial tab's forecast_horizon / history_length.
                        research_forecast_horizon = gr.Slider(
                            minimum=30, maximum=512, value=90, step=1,
                            label="Forecast Horizon",
                            info="Number of periods to forecast ahead"
                        )

                        research_history_length = gr.Slider(
                            minimum=256, maximum=2048, value=1024, step=8,
                            label="History Length",
                            info="Historical data points to analyze"
                        )

                        forecast_btn = gr.Button("Run Forecast & Analysis")

                    with gr.Column(scale=3):
                        gr.Markdown("### Analysis Results")
                        research_status_text = gr.Textbox(
                            label="",
                            interactive=False,
                            lines=3,
                            info="Forecasting progress and results"
                        )

                        gr.Markdown("### Key Metrics")

                        # Rows named research_* so they do not shadow the
                        # financial tab's metric rows when wired below.
                        with gr.Row(visible=True) as research_financial_metrics:
                            with gr.Column():
                                gr.Markdown("**Latest Level:** $<span id='latest-price'>0.00</span>")
                                latest_price_out = gr.Number(visible=False)

                                gr.Markdown("**Forecast (Next Period):** $<span id='forecast-next'>0.00</span>")
                                forecast_next_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**30-Day Volatility:** <span id='vol-30d'>0.00</span>%")
                                vol_30d_out = gr.Number(visible=False)

                                with gr.Row():
                                    gr.Markdown("**52-Week High:** $<span id='high-52wk'>0.00</span>")
                                    high_52wk_out = gr.Number(visible=False)

                                    gr.Markdown("**52-Week Low:** $<span id='low-52wk'>0.00</span>")
                                    low_52wk_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as research_synthetic_metrics:
                            with gr.Column():
                                gr.Markdown("**Data Mean:** <span id='data-mean'>0.000</span>")
                                data_mean_out = gr.Number(visible=False)

                                gr.Markdown("**Data Std Dev:** <span id='data-std'>0.000</span>")
                                data_std_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Forecast Horizon:** <span id='forecast-horizon'>0</span>")
                                forecast_accuracy_out = gr.Number(visible=False)

                                gr.Markdown("**Pattern Type:** <span id='pattern-type'>None</span>")
                                pattern_type_out = gr.Textbox(visible=False)

                        with gr.Row(visible=False) as research_performance_metrics:
                            with gr.Column():
                                gr.Markdown("**Forecast Performance:**")
                                gr.Markdown("• **MSE:** <span id='mse'>0.000</span>")
                                mse_out = gr.Number(visible=False)

                                gr.Markdown("• **MAE:** <span id='mae'>0.000</span>")
                                mae_out = gr.Number(visible=False)

                                gr.Markdown("• **MAPE:** <span id='mape'>0.000</span>%")
                                mape_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Uncertainty Quantification:**")
                                gr.Markdown("• **80% Coverage:** <span id='coverage-80'>0.000</span>")
                                coverage_80_out = gr.Number(visible=False)

                                gr.Markdown("• **95% Coverage:** <span id='coverage-95'>0.000</span>")
                                coverage_95_out = gr.Number(visible=False)

                                gr.Markdown("• **Calibration:** <span id='calibration'>0.000</span>")
                                calibration_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as research_complexity_metrics:
                            with gr.Column():
                                gr.Markdown("**Information Theory:**")
                                gr.Markdown("• **Sample Entropy:** <span id='sample-entropy'>0.000</span>")
                                sample_entropy_out = gr.Number(visible=False)

                                gr.Markdown("• **Approx Entropy:** <span id='approx-entropy'>0.000</span>")
                                approx_entropy_out = gr.Number(visible=False)

                                gr.Markdown("• **Perm Entropy:** <span id='perm-entropy'>0.000</span>")
                                perm_entropy_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Complexity Measures:**")
                                gr.Markdown("• **Fractal Dim:** <span id='fractal-dim'>0.000</span>")
                                fractal_dim_out = gr.Number(visible=False)

                                gr.Markdown("• **Dominant Freq:** <span id='dominant-freq'>0.000</span>")
                                dominant_freq_out = gr.Number(visible=False)

                                gr.Markdown("• **Spectral Centroid:** <span id='spectral-centroid'>0.000</span>")
                                spectral_centroid_out = gr.Number(visible=False)

                                gr.Markdown("• **Spectral Entropy:** <span id='spectral-entropy'>0.000</span>")
                                spectral_entropy_out = gr.Number(visible=False)

                        with gr.Row(visible=False) as research_param_tools:
                            with gr.Column():
                                gr.Markdown("**Cross-Validation Results:**")
                                gr.Markdown("• **Rolling Window MSE:** <span id='cv-mse'>0.000</span>")
                                cv_mse_out = gr.Number(visible=False)

                                gr.Markdown("• **Rolling Window MAE:** <span id='cv-mae'>0.000</span>")
                                cv_mae_out = gr.Number(visible=False)

                                gr.Markdown("• **Validation Windows:** <span id='cv-windows'>0</span>")
                                cv_windows_out = gr.Number(visible=False)

                            with gr.Column():
                                gr.Markdown("**Parameter Sensitivity:**")
                                gr.Markdown("• **Horizon Sensitivity:** <span id='horizon-sensitivity'>0.000</span>")
                                horizon_sensitivity_out = gr.Number(visible=False)

                                gr.Markdown("• **History Sensitivity:** <span id='history-sensitivity'>0.000</span>")
                                history_sensitivity_out = gr.Number(visible=False)

                                gr.Markdown("• **Stability Score:** <span id='stability-score'>0.000</span>")
                                stability_score_out = gr.Number(visible=False)
                        gr.Markdown("### Forecast & Technical Analysis")
                        research_plot_output = gr.Plot(
                            label="",
                            show_label=False
                        )

                        with gr.Accordion("Advanced Statistical Visualizations", open=False):
                            research_advanced_plots = gr.Plot(label="", show_label=False)

                        with gr.Accordion("Raw Data Preview", open=False):
                            research_data_preview = gr.Dataframe(
                                label="",
                                show_label=False,
                                wrap=True
                            )
        def toggle_metrics_display(choice):
            """Toggle between financial and synthetic metrics based on data source."""
            show_financial = choice in ["Stock Ticker", "Default (WTI Oil Prices)", "VIX Volatility Index"]
            show_synthetic = choice in ["Basic Synthetic", "Advanced Synthetic", "Upload Custom CSV"]
            show_performance = show_synthetic
            show_complexity = show_synthetic
            return (
                gr.update(visible=show_financial),
                gr.update(visible=show_synthetic),
                gr.update(visible=show_performance),
                gr.update(visible=show_complexity)
            )

        # Each radio drives the metric rows in its own tab; reading the radio
        # directly avoids depending on the hidden data_source update ordering.
        financial_source.change(
            fn=toggle_metrics_display,
            inputs=financial_source,
            outputs=[financial_metrics, synthetic_metrics, performance_metrics, complexity_metrics]
        )

        research_source.change(
            fn=toggle_metrics_display,
            inputs=research_source,
            outputs=[research_financial_metrics, research_synthetic_metrics, research_performance_metrics, research_complexity_metrics]
        )
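        # Glue between forecast_time_series' 8-tuple and each tab's four
        # outputs (status, main plot, preview table, diagnostic figure).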

        def forecast_and_display_financial(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed):
            result = forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, "Sine Waves", 5)
            if result[5] and "Error" not in result[5]:
                history_np = result[0]
                preds = result[2]
                future_np = last_forecast_results['future'] if last_forecast_results else None

                adv_viz = create_advanced_visualizations(history_np, preds, future_np)

                return (
                    result[5],
                    result[4],
                    result[7],
                    adv_viz
                )
            else:
                return result[5], None, pd.DataFrame(), go.Figure()

        def forecast_and_display_research(data_source, forecast_horizon, history_length, seed, synth_generator, synth_complexity):
            result = forecast_time_series(data_source, "", None, forecast_horizon, history_length, seed, synth_generator, synth_complexity)
            if result[5] and "Error" not in result[5]:
                history_np = result[0]
                preds = result[2]
                future_np = last_forecast_results['future'] if last_forecast_results else None

                adv_viz = create_advanced_visualizations(history_np, preds, future_np)

                return (
                    result[5],
                    result[4],
                    result[7],
                    adv_viz
                )
            else:
                return result[5], None, pd.DataFrame(), go.Figure()

        financial_forecast_btn.click(
            fn=forecast_and_display_financial,
            inputs=[data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed],
            outputs=[status_text, plot_output, data_preview, advanced_plots]
        )

        forecast_btn.click(
            fn=forecast_and_display_research,
            inputs=[data_source, research_forecast_horizon, research_history_length, seed, synth_generator, synth_complexity],
            outputs=[research_status_text, research_plot_output, research_data_preview, research_advanced_plots]
        )
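        # Export wrappers toggle the shared gr.File component's visibility so
        # the download link only appears when a file was actually written.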

        def export_forecast_wrapper():
            file, status = export_forecast_csv()
            return gr.update(value=file, visible=file is not None), status

        def export_metrics_wrapper():
            file, status = export_metrics_csv()
            return gr.update(value=file, visible=file is not None), status

        def export_analysis_wrapper():
            file, status = export_analysis_csv()
            return gr.update(value=file, visible=file is not None), status

        export_forecast_btn.click(
            fn=export_forecast_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )

        export_metrics_btn.click(
            fn=export_metrics_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )

        export_analysis_btn.click(
            fn=export_analysis_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )

    return app


app = create_gradio_app()
app.launch()