# tempoPFN/app.py
# Commit 6cc66f0 (altpuppet): Fix ZeroGPU timeout issue - extend duration and optimize model loading
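"""Gradio demo for TempoPFN zero-shot time-series forecasting.

Two tabs (Financial Markets, Research & Analysis) fetch or synthesize a
univariate series, run the pretrained TempoPFN checkpoint on a ZeroGPU worker,
and display the forecast plot, summary metrics, and CSV exports.
"""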
import gradio as gr
import numpy as np
import pandas as pd
import torch
import yaml
from huggingface_hub import hf_hub_download
import spaces
import traceback
import functools
import yfinance as yf
import pandas_ta as ta
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
# --- All your src imports ---
from examples.utils import load_model
from src.plotting.plot_timeseries import plot_multivariate_timeseries
from src.data.containers import BatchTimeSeriesContainer, Frequency
from src.synthetic_generation.generator_params import (
SineWaveGeneratorParams, GPGeneratorParams, AnomalyGeneratorParams,
MultiScaleFractalAudioParams, FinancialVolatilityAudioParams,
SawToothGeneratorParams, SpikesGeneratorParams, StepGeneratorParams,
OrnsteinUhlenbeckProcessGeneratorParams, NetworkTopologyAudioParams,
StochasticRhythmAudioParams, CauKerGeneratorParams, ForecastPFNGeneratorParams,
KernelGeneratorParams
)
# Fallback constants for GIFT evaluation; overridden below when the optional imports succeed
ALL_DATASETS = ["ETTm1", "ETTm2", "ETTh1", "ETTh2", "Weather", "Electricity", "Traffic"]
TERMS = ["short", "medium", "long"]
# GIFT Evaluation imports (optional)
try:
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results
from src.gift_eval.constants import ALL_DATASETS, TERMS
GIFT_EVAL_AVAILABLE = True
except ImportError:
GIFT_EVAL_AVAILABLE = False
print("Warning: GIFT evaluation dependencies not available. GIFT evaluation tab will be disabled.")
from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import SineWaveGeneratorWrapper
from src.synthetic_generation.anomalies.anomaly_generator_wrapper import AnomalyGeneratorWrapper
from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import SawToothGeneratorWrapper
from src.synthetic_generation.spikes.spikes_generator_wrapper import SpikesGeneratorWrapper
from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import OrnsteinUhlenbeckProcessGeneratorWrapper
# Try to import additional optional generators
try:
from src.synthetic_generation.audio_generators.network_topology_wrapper import NetworkTopologyAudioWrapper
NETWORK_AVAILABLE = True
except ImportError:
NETWORK_AVAILABLE = False
try:
from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import StochasticRhythmAudioWrapper
RHYTHM_AVAILABLE = True
except ImportError:
RHYTHM_AVAILABLE = False
try:
from src.synthetic_generation.cauker.cauker_generator_wrapper import CauKerGeneratorWrapper
CAUKER_AVAILABLE = True
except ImportError:
CAUKER_AVAILABLE = False
try:
from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import ForecastPFNGeneratorWrapper
FORECAST_PFN_AVAILABLE = True
except ImportError:
FORECAST_PFN_AVAILABLE = False
try:
from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import KernelGeneratorWrapper
KERNEL_AVAILABLE = True
except ImportError:
KERNEL_AVAILABLE = False
# Try to import optional generators
try:
from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
GP_AVAILABLE = True
except ImportError:
GP_AVAILABLE = False
try:
from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import MultiScaleFractalAudioWrapper
from src.synthetic_generation.audio_generators.financial_volatility_wrapper import FinancialVolatilityAudioWrapper
AUDIO_AVAILABLE = True
except ImportError:
AUDIO_AVAILABLE = False
# Global model placeholder; the CUDA device is selected inside the GPU-decorated function
model = None
# Global variables to store forecast results for export
last_forecast_results = None
last_metrics_results = None
last_analysis_results = None
def create_gradio_app():
"""Create and configure the Gradio app for TempoPFN."""
@functools.lru_cache(maxsize=None)
def load_oil_price_data():
"""Downloads and caches daily WTI oil price data."""
print("--- Downloading WTI Oil Price data for the first time ---")
url = "https://datahub.io/core/oil-prices/r/wti-daily.csv"
try:
df = pd.read_csv(url)
df['Date'] = pd.to_datetime(df['Date'])
df = df.sort_values('Date')
df = df.set_index('Date').asfreq('D').ffill().reset_index()
values = df['Price'].values.astype(np.float32)
start_date = df['Date'].min()
print(f"--- Oil price data loaded. {len(values)} points ---")
return values, start_date, "D"
except Exception as e:
print(f"Error loading oil price data: {e}")
raise
def generate_synthetic_data(length=2048, seed=42):
"""Generate synthetic sine wave data for demonstration."""
sine_params = SineWaveGeneratorParams(global_seed=seed, length=length)
sine_generator = SineWaveGeneratorWrapper(sine_params)
batch = sine_generator.generate_batch(batch_size=1, seed=seed)
values = torch.from_numpy(batch.values).to(torch.float32)
if values.ndim == 2:
values = values.unsqueeze(-1)
        # Squeeze to a 1-D array of shape (length,), the layout downstream code expects
        return values.squeeze().numpy(), batch.start[0], batch.frequency[0]
def process_uploaded_data(file):
"""Process uploaded CSV file with time series data."""
        if file is None:
            return None, None, None, None, "No file uploaded"
        try:
            df = pd.read_csv(file.name)
            if len(df.columns) < 2:
                return None, None, None, None, "CSV must have at least 2 columns"
time_col, value_col = df.columns[0], df.columns[1]
try:
df[time_col] = pd.to_datetime(df[time_col])
df = df.sort_values(time_col)
start_date = df[time_col].min()
freq = pd.infer_freq(df[time_col]) or "D"
except Exception:
start_date = np.datetime64("2020-01-01")
freq = "D"
values = df[value_col].values.astype(np.float32)
volumes = df['Volume'].values.astype(np.float32) if 'Volume' in df.columns else None
return values, volumes, start_date, freq, f"Loaded {len(values)} data points"
except Exception as e:
return None, None, None, None, f"Error processing file: {str(e)}"
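    # Expected upload layout (illustrative): the first column is a timestamp and
    # the second the value; an optional 'Volume' column is used when present.
    #
    #     Date,Close,Volume
    #     2024-01-02,468.79,81964000
    #     2024-01-03,472.65,73450000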
def create_advanced_visualizations(history_values, predictions, future_values=None):
"""Create advanced statistical visualizations."""
try:
# Create subplots with multiple analyses
fig = make_subplots(
rows=2, cols=2,
subplot_titles=('Residual Analysis', 'ACF Plot',
'Distribution Comparison', 'Forecast Error Distribution'),
specs=[[{"type": "scatter"}, {"type": "bar"}],
[{"type": "histogram"}, {"type": "histogram"}]]
)
history_flat = history_values.flatten()
pred_flat = predictions.flatten()
# 1. Residual Analysis (if ground truth available)
if future_values is not None:
future_flat = future_values.flatten()[:len(pred_flat)]
residuals = future_flat - pred_flat
fig.add_trace(
go.Scatter(x=list(range(len(residuals))), y=residuals,
mode='lines+markers', name='Residuals'),
row=1, col=1
)
fig.add_hline(y=0, line_dash="dash", line_color="red", row=1, col=1)
else:
# Just show predictions
fig.add_trace(
go.Scatter(x=list(range(len(pred_flat))), y=pred_flat,
mode='lines', name='Predictions'),
row=1, col=1
)
# 2. Autocorrelation Function (ACF)
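            # acf(k) = corr(x_t, x_{t+k}), estimated with np.corrcoef on lagged slices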
max_lags = min(40, len(history_flat) // 2)
acf_values = []
for lag in range(max_lags):
if lag == 0:
acf_values.append(1.0)
else:
acf = np.corrcoef(history_flat[:-lag], history_flat[lag:])[0, 1]
acf_values.append(acf)
fig.add_trace(
go.Bar(x=list(range(max_lags)), y=acf_values, name='ACF'),
row=1, col=2
)
            # Approximate 95% confidence band for the ACF of white noise: +/- 1.96 / sqrt(N)
            ci = 1.96 / np.sqrt(len(history_flat))
fig.add_hline(y=ci, line_dash="dash", line_color="blue", row=1, col=2)
fig.add_hline(y=-ci, line_dash="dash", line_color="blue", row=1, col=2)
# 3. Distribution Comparison
fig.add_trace(
go.Histogram(x=history_flat, name='Historical', opacity=0.7, nbinsx=30),
row=2, col=1
)
fig.add_trace(
go.Histogram(x=pred_flat, name='Predictions', opacity=0.7, nbinsx=30),
row=2, col=1
)
# 4. Forecast Error Distribution (if ground truth available)
if future_values is not None:
future_flat = future_values.flatten()[:len(pred_flat)]
errors = future_flat - pred_flat
fig.add_trace(
go.Histogram(x=errors, name='Forecast Errors', nbinsx=30),
row=2, col=2
)
else:
# Show prediction distribution
fig.add_trace(
go.Histogram(x=pred_flat, name='Pred Distribution', nbinsx=30),
row=2, col=2
)
# Update layout
fig.update_layout(
height=800,
title_text="Advanced Statistical Analysis",
showlegend=True
)
fig.update_xaxes(title_text="Time Index", row=1, col=1)
fig.update_yaxes(title_text="Value", row=1, col=1)
fig.update_xaxes(title_text="Lag", row=1, col=2)
fig.update_yaxes(title_text="Correlation", row=1, col=2)
fig.update_xaxes(title_text="Value", row=2, col=1)
fig.update_yaxes(title_text="Frequency", row=2, col=1)
fig.update_xaxes(title_text="Error", row=2, col=2)
fig.update_yaxes(title_text="Frequency", row=2, col=2)
return fig
except Exception as e:
print(f"Error creating advanced visualizations: {e}")
# Return simple error figure
fig = go.Figure()
fig.add_annotation(
text=f"Error creating visualizations: {str(e)}",
xref="paper", yref="paper", x=0.5, y=0.5,
showarrow=False, font=dict(size=14, color="red")
)
return fig
def export_forecast_csv():
"""Export forecast data to CSV."""
global last_forecast_results
if last_forecast_results is None:
return None, "No forecast data available. Please run a forecast first."
try:
# Create DataFrame with forecast data
history = last_forecast_results['history'].flatten()
predictions = last_forecast_results['predictions'].flatten()
future = last_forecast_results['future'].flatten()
            # Lay history and forecast on one shared time axis: history first,
            # then the forecast horizon, padding each column with NaNs elsewhere
            total_len = len(history) + len(predictions)
            future_padded = list(future[:len(predictions)]) + [np.nan] * max(0, len(predictions) - len(future))
            df_data = {
                'Time_Index': list(range(total_len)),
                'Historical_Value': list(history) + [np.nan] * len(predictions),
                'Predicted_Value': [np.nan] * len(history) + list(predictions),
                'True_Future_Value': [np.nan] * len(history) + future_padded
            }
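            # Resulting layout (illustrative values):
            #   Time_Index,Historical_Value,Predicted_Value,True_Future_Value
            #   0,101.25,,
            #   ...
            #   1024,,103.90,104.10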
df = pd.DataFrame(df_data)
filepath = "/tmp/forecast_data.csv"
df.to_csv(filepath, index=False)
return filepath, "Forecast data exported successfully!"
except Exception as e:
return None, f"Error exporting forecast data: {str(e)}"
def export_metrics_csv():
"""Export metrics summary to CSV."""
global last_metrics_results
if last_metrics_results is None:
return None, "No metrics available. Please run a forecast first."
try:
df = pd.DataFrame([last_metrics_results])
filepath = "/tmp/metrics_summary.csv"
df.to_csv(filepath, index=False)
return filepath, "Metrics summary exported successfully!"
except Exception as e:
return None, f"Error exporting metrics: {str(e)}"
def export_analysis_csv():
"""Export full analysis including forecast, metrics, and metadata."""
global last_forecast_results, last_metrics_results, last_analysis_results
if last_forecast_results is None:
return None, "No analysis data available. Please run a forecast first."
try:
# Combine all data
analysis_data = {
**last_analysis_results,
**last_metrics_results,
'num_history_points': len(last_forecast_results['history']),
'num_forecast_points': len(last_forecast_results['predictions']),
}
df = pd.DataFrame([analysis_data])
filepath = "/tmp/full_analysis.csv"
df.to_csv(filepath, index=False)
return filepath, "Full analysis exported successfully!"
except Exception as e:
return None, f"Error exporting analysis: {str(e)}"
def calculate_metrics(history_values, predictions, future_values=None, data_source=""):
"""Calculate comprehensive metrics for display in the UI."""
metrics = {}
# Basic statistics
metrics['data_mean'] = float(np.mean(history_values))
metrics['data_std'] = float(np.std(history_values))
metrics['data_skewness'] = float(stats.skew(history_values.flatten()))
metrics['data_kurtosis'] = float(stats.kurtosis(history_values.flatten()))
# Latest values and forecasts
metrics['latest_price'] = float(history_values[-1, 0] if history_values.ndim > 1 else history_values[-1])
metrics['forecast_next'] = float(predictions[0, 0] if predictions.ndim > 1 else predictions[0])
        # Volatility: std of the last 30 points as a percentage of their mean
if len(history_values) >= 30:
recent_30 = history_values[-30:].flatten()
volatility = (np.std(recent_30) / np.mean(recent_30)) * 100 if np.mean(recent_30) != 0 else 0
metrics['vol_30d'] = float(volatility)
else:
metrics['vol_30d'] = 0.0
# 52-week high/low (or max/min of available data)
lookback = min(252, len(history_values)) # 252 trading days ≈ 1 year
recent_data = history_values[-lookback:].flatten()
metrics['high_52wk'] = float(np.max(recent_data))
metrics['low_52wk'] = float(np.min(recent_data))
# Time series properties
# Autocorrelation at lag 1
if len(history_values) > 1:
flat_history = history_values.flatten()
metrics['data_autocorr'] = float(np.corrcoef(flat_history[:-1], flat_history[1:])[0, 1])
else:
metrics['data_autocorr'] = 0.0
        # Stationarity heuristic: compare the variance of the two halves of the
        # series; a ratio far from 1 suggests the variance is drifting
if len(history_values) >= 20:
first_half = history_values[:len(history_values)//2].flatten()
second_half = history_values[len(history_values)//2:].flatten()
var_ratio = np.var(second_half) / np.var(first_half) if np.var(first_half) > 0 else 1.0
metrics['data_stationary'] = "Likely" if 0.5 < var_ratio < 2.0 else "Unlikely"
else:
metrics['data_stationary'] = "Unknown"
# Pattern detection (simple heuristic)
if metrics['data_autocorr'] > 0.7:
metrics['pattern_type'] = "Trending"
elif abs(metrics['data_autocorr']) < 0.3:
metrics['pattern_type'] = "Random Walk"
else:
metrics['pattern_type'] = "Mean Reverting"
# Performance metrics (if ground truth available)
if future_values is not None:
pred_flat = predictions.flatten()[:len(future_values.flatten())]
true_flat = future_values.flatten()[:len(pred_flat)]
# MSE, MAE
metrics['mse'] = float(np.mean((pred_flat - true_flat) ** 2))
metrics['mae'] = float(np.mean(np.abs(pred_flat - true_flat)))
            # MAPE = 100 * mean(|y - yhat| / |y|); epsilon guards division by zero
            mape_values = np.abs((true_flat - pred_flat) / (np.abs(true_flat) + 1e-8)) * 100
            metrics['mape'] = float(np.mean(mape_values))
else:
metrics['mse'] = 0.0
metrics['mae'] = 0.0
metrics['mape'] = 0.0
# Uncertainty quantification placeholders (would need quantile predictions)
metrics['coverage_80'] = 0.0
metrics['coverage_95'] = 0.0
metrics['calibration'] = 0.0
        # Information-theory metrics (simplified)
        # Crude entropy proxy from normalized log-magnitudes; not true sample entropy
try:
hist_normalized = (history_values.flatten() - np.mean(history_values)) / (np.std(history_values) + 1e-8)
metrics['sample_entropy'] = float(-np.mean(np.log(np.abs(hist_normalized) + 1e-8)))
        except Exception:
metrics['sample_entropy'] = 0.0
metrics['approx_entropy'] = metrics['sample_entropy'] * 0.8 # Placeholder
metrics['perm_entropy'] = metrics['sample_entropy'] * 0.9 # Placeholder
        # Complexity measures
        # Roughness-based proxy for fractal dimension (not a true box-counting estimate)
try:
metrics['fractal_dim'] = float(1.0 + 0.5 * metrics['data_std'] / (np.mean(np.abs(np.diff(history_values.flatten()))) + 1e-8))
        except Exception:
metrics['fractal_dim'] = 1.5
# Spectral features
try:
# FFT-based features
fft_vals = np.fft.fft(history_values.flatten())
power_spectrum = np.abs(fft_vals[:len(fft_vals)//2]) ** 2
freqs = np.fft.fftfreq(len(history_values.flatten()))[:len(fft_vals)//2]
# Dominant frequency
dominant_idx = np.argmax(power_spectrum[1:]) + 1 # Skip DC component
metrics['dominant_freq'] = float(abs(freqs[dominant_idx]))
# Spectral centroid
metrics['spectral_centroid'] = float(np.sum(freqs * power_spectrum) / (np.sum(power_spectrum) + 1e-8))
# Spectral entropy
power_normalized = power_spectrum / (np.sum(power_spectrum) + 1e-8)
metrics['spectral_entropy'] = float(-np.sum(power_normalized * np.log(power_normalized + 1e-8)))
        except Exception:
metrics['dominant_freq'] = 0.0
metrics['spectral_centroid'] = 0.0
metrics['spectral_entropy'] = 0.0
# Cross-validation placeholders
metrics['cv_mse'] = 0.0
metrics['cv_mae'] = 0.0
metrics['cv_windows'] = 0
# Sensitivity placeholders
metrics['horizon_sensitivity'] = 0.0
metrics['history_sensitivity'] = 0.0
metrics['stability_score'] = 0.0
return metrics
@spaces.GPU(duration=120) # Extend timeout to 120 seconds for first-run compilation
def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
"""
GPU-only inference function for ZeroGPU Spaces.
ALL CUDA operations must happen inside this decorated function.
Extended timeout for Triton kernel compilation on first run.
"""
global model
# Load model once on first call (on CPU first to save GPU time)
if model is None:
print("--- Loading TempoPFN model for the first time ---")
print(f"Downloading model...")
model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
# Load on CPU first to save GPU allocation time
print(f"Loading model from {model_path} to CPU first...")
model = load_model(config_path="configs/example.yaml", model_path=model_path, device=torch.device("cpu"))
print("--- Model loaded successfully on CPU ---")
# Move model to GPU inside the decorated function
device = torch.device("cuda:0")
print(f"Moving model to {device}...")
model.to(device)
# Prepare container with GPU tensors
container = BatchTimeSeriesContainer(
history_values=history_values_tensor.to(device),
future_values=future_values_tensor.to(device),
start=[start],
frequency=[freq_object],
)
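        # Container tensors are [batch, seq_len, num_series]; this app always sends [1, S, 1]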
# Run inference with bfloat16 autocast
print("Running inference...")
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
model_output = model(container)
# Move model back to CPU to free GPU memory
model.to(torch.device("cpu"))
print("Inference complete, model moved back to CPU")
return model_output
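    # NOTE: under ZeroGPU the decorated function may execute in a separate worker
    # process, so the module-level `model` cache is best-effort; the `model is not
    # None` guards further below tolerate it being unset in the main process.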
def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
"""
Runs the TempoPFN forecast.
Returns: history_price, history_volume, predictions, quantiles, plot, status, metrics, data_preview
"""
try:
all_volumes = None
if data_source == "Stock Ticker":
                if not stock_ticker:
                    return None, None, None, None, None, "Please enter a stock ticker (e.g., SPY, AAPL)", {}, pd.DataFrame()
print(f"--- Downloading '{stock_ticker}' data from yfinance ---")
hist = yf.download(stock_ticker, period="max", auto_adjust=True)
                if hist.empty:
                    return None, None, None, None, None, f"Could not find data for ticker '{stock_ticker}'", {}, pd.DataFrame()
hist = hist[['Close', 'Volume']].asfreq('D').ffill()
                # Squeeze to 1-D arrays; yfinance can return (N, 1)-shaped columns
all_values = hist['Close'].values.astype(np.float32).squeeze()
all_volumes = hist['Volume'].values.astype(np.float32).squeeze()
data_start_date = hist.index.min()
frequency = "D"
elif data_source == "VIX Volatility Index":
print("--- Downloading VIX data from yfinance ---")
vix_data = yf.download("^VIX", period="max", auto_adjust=True)
                if vix_data.empty:
                    return None, None, None, None, None, "Could not download VIX data", {}, pd.DataFrame()
vix_data = vix_data.asfreq('D').ffill()
all_values = vix_data['Close'].values.astype(np.float32).squeeze()
data_start_date = vix_data.index.min()
frequency = "D"
print(f"--- VIX data loaded: {len(all_values)} points ---")
elif data_source == "Default (WTI Oil Prices)":
all_values, data_start_date, frequency = load_oil_price_data()
elif data_source == "Upload Custom CSV":
all_values, all_volumes, data_start_date, frequency, message = process_uploaded_data(uploaded_file)
                if all_values is None:
                    return None, None, None, None, None, message, {}, pd.DataFrame()
elif data_source == "Synthetic Playground":
print(f"--- Generating {synth_generator} synthetic data (complexity: {synth_complexity}) ---")
# Generate synthetic data based on selected generator
total_length = history_length + forecast_horizon
if synth_generator == "Sine Waves":
params = SineWaveGeneratorParams(global_seed=seed, length=total_length)
generator = SineWaveGeneratorWrapper(params)
elif synth_generator == "Sawtooth Waves":
params = SawToothGeneratorParams(global_seed=seed, length=total_length)
generator = SawToothGeneratorWrapper(params)
elif synth_generator == "Spikes":
params = SpikesGeneratorParams(global_seed=seed, length=total_length)
generator = SpikesGeneratorWrapper(params)
elif synth_generator == "Steps":
params = StepGeneratorParams(global_seed=seed, length=total_length)
generator = StepGeneratorWrapper(params)
elif synth_generator == "Ornstein-Uhlenbeck":
params = OrnsteinUhlenbeckProcessGeneratorParams(global_seed=seed, length=total_length)
generator = OrnsteinUhlenbeckProcessGeneratorWrapper(params)
elif synth_generator == "Gaussian Processes" and GP_AVAILABLE:
params = GPGeneratorParams(global_seed=seed, length=total_length)
generator = GPGeneratorWrapper(params)
elif synth_generator == "Anomaly Patterns":
params = AnomalyGeneratorParams(global_seed=seed, length=total_length)
generator = AnomalyGeneratorWrapper(params)
elif synth_generator == "Financial Volatility" and AUDIO_AVAILABLE:
params = FinancialVolatilityAudioParams(global_seed=seed, length=total_length)
generator = FinancialVolatilityAudioWrapper(params)
elif synth_generator == "Fractal Patterns" and AUDIO_AVAILABLE:
params = MultiScaleFractalAudioParams(global_seed=seed, length=total_length)
generator = MultiScaleFractalAudioWrapper(params)
elif synth_generator == "Network Topology" and NETWORK_AVAILABLE:
params = NetworkTopologyAudioParams(global_seed=seed, length=total_length)
generator = NetworkTopologyAudioWrapper(params)
elif synth_generator == "Stochastic Rhythm" and RHYTHM_AVAILABLE:
params = StochasticRhythmAudioParams(global_seed=seed, length=total_length)
generator = StochasticRhythmAudioWrapper(params)
elif synth_generator == "CauKer" and CAUKER_AVAILABLE:
params = CauKerGeneratorParams(global_seed=seed, length=total_length)
generator = CauKerGeneratorWrapper(params)
elif synth_generator == "Forecast PFN Prior" and FORECAST_PFN_AVAILABLE:
params = ForecastPFNGeneratorParams(global_seed=seed, length=total_length)
generator = ForecastPFNGeneratorWrapper(params)
elif synth_generator == "Kernel Synth" and KERNEL_AVAILABLE:
params = KernelGeneratorParams(global_seed=seed, length=total_length)
generator = KernelGeneratorWrapper(params)
else:
# Fallback to sine waves if generator not available
params = SineWaveGeneratorParams(global_seed=seed, length=total_length)
generator = SineWaveGeneratorWrapper(params)
# Generate the batch
batch = generator.generate_batch(batch_size=1, seed=seed)
values = torch.from_numpy(batch.values).to(torch.float32)
if values.ndim == 2:
values = values.unsqueeze(-1)
all_values = values.squeeze().numpy()
data_start_date = batch.start[0] if hasattr(batch, 'start') and batch.start else np.datetime64("2020-01-01")
frequency = batch.frequency[0] if hasattr(batch, 'frequency') and batch.frequency else "D"
print(f"--- {synth_generator} data generated: {len(all_values)} points ---")
            else:  # "Basic Synthetic"
values, start, frequency = generate_synthetic_data(length=history_length + forecast_horizon, seed=seed)
all_values, data_start_date = values, start
# --- Common Logic for Slicing Data ---
            if data_source != "Basic Synthetic":
total_needed = history_length + forecast_horizon
                if len(all_values) < total_needed:
                    return None, None, None, None, None, f"Data has {len(all_values)} points, but {total_needed} are needed.", {}, pd.DataFrame()
values = all_values[-total_needed:]
start_offset_days = len(all_values) - total_needed
start = np.datetime64(data_start_date) + np.timedelta64(start_offset_days, 'D')
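                # NOTE: the start-date shift assumes daily spacing; for non-daily
                # frequencies the plotted start date is approximate.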
if all_volumes is not None:
history_volumes = all_volumes[-(total_needed) : -forecast_horizon]
else:
history_volumes = np.array([np.nan] * history_length)
else:
start = data_start_date
history_volumes = np.array([np.nan] * history_length)
# --- Prepare data for model ---
# Unsqueeze calls convert the 1D array into the required [B, S, N] shape: [1, S, 1]
values_tensor = torch.from_numpy(values).unsqueeze(0).unsqueeze(-1)
future_length = forecast_horizon
# --- Convert string to the correct Frequency enum ---
if isinstance(frequency, str):
if frequency.startswith("D"):
freq_object = Frequency.D
elif frequency.startswith("W"):
freq_object = Frequency.W
elif frequency.startswith("M"):
freq_object = Frequency.M
elif frequency.startswith("Q"):
freq_object = Frequency.Q
elif frequency.startswith("A") or frequency.startswith("Y"):
freq_object = Frequency.A
else:
print(f"Warning: Unknown frequency string '{frequency}'. Defaulting to Daily.")
freq_object = Frequency.D
else:
freq_object = frequency
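            # Prefix matching above also covers pandas aliases such as "W-SUN" -> Frequency.W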
# Prepare container for GPU inference
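            # All but the last `forecast_horizon` points are history; the final
            # `forecast_horizon` points are the held-out future used for evaluation.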
history_values_tensor = values_tensor[:, :-future_length, :]
future_values_tensor = values_tensor[:, -future_length:, :]
# Ensure start is np.datetime64
if not isinstance(start, np.datetime64):
start = np.datetime64(start)
# Run GPU inference (all CUDA ops happen inside the decorated function)
# Pass CPU tensors - they will be moved to GPU inside the function
model_output = run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object)
# Post-process predictions (exactly like examples/utils.py lines 65-69)
preds_full = model_output["result"].to(torch.float32)
if model is not None and hasattr(model, "scaler") and "scale_statistics" in model_output:
preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
# Convert to numpy for plotting
preds_np = preds_full.detach().cpu().numpy()
history_np = history_values_tensor.cpu().numpy().squeeze(0)
future_np = future_values_tensor.cpu().numpy().squeeze(0)
preds_squeezed = preds_np.squeeze(0)
# Get model quantiles if available
model_quantiles = None
if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
model_quantiles = model.quantiles
try:
forecast_plot = plot_multivariate_timeseries(
history_values=history_np,
future_values=future_np,
predicted_values=preds_squeezed,
start=start,
frequency=freq_object,
title=f"TempoPFN Forecast - {data_source}",
show=False # Don't show the plot, we'll display in Gradio
)
except Exception as plot_error:
print(f"Warning: Failed to generate plot: {plot_error}")
            # Create a simple error placeholder figure (go is already imported at module level)
forecast_plot = go.Figure()
forecast_plot.add_annotation(
text="Plot generation failed",
xref="paper", yref="paper", x=0.5, y=0.5,
showarrow=False, font=dict(size=14, color="red")
)
# Calculate comprehensive metrics
metrics = calculate_metrics(
history_values=history_np,
predictions=preds_squeezed,
future_values=future_np,
data_source=data_source
)
# Store results globally for export functionality
global last_forecast_results, last_metrics_results, last_analysis_results
last_forecast_results = {
'history': history_np,
'predictions': preds_squeezed,
'future': future_np,
'start': start,
'frequency': freq_object
}
last_metrics_results = metrics
last_analysis_results = {
'data_source': data_source,
'forecast_horizon': forecast_horizon,
'history_length': history_length,
'seed': seed
}
            # Create data preview DataFrame (first 100 points; all columns must have equal length)
            n_preview = min(100, len(history_np))
            preview_data = {
                'Index': list(range(n_preview)),
                'Historical Value': history_np.flatten()[:n_preview]
            }
            if history_volumes is not None and not np.all(np.isnan(history_volumes)):
                preview_data['Volume'] = history_volumes[:n_preview]
data_preview_df = pd.DataFrame(preview_data)
return (
history_np, history_volumes, preds_squeezed, model_quantiles,
forecast_plot, "Forecasting completed successfully!",
metrics, data_preview_df
)
except Exception as e:
traceback.print_exc()
error_msg = f"Error during forecasting: {str(e)}"
empty_metrics = {k: 0.0 if isinstance(v, float) else "" for k, v in
calculate_metrics(np.array([0.0]), np.array([0.0])).items()}
return None, None, None, None, None, error_msg, empty_metrics, pd.DataFrame()
# --- [GRADIO UI - Simplified with Default Styling] ---
with gr.Blocks(title="TempoPFN") as app:
gr.Markdown("# TempoPFN\n### Zero-Shot Forecasting & Analysis Terminal\n*Powered by synthetic pre-training • Forecast anything, anywhere*")
gr.Markdown("⚠️ **First Run Note**: Initial inference may take 60-90 seconds due to Triton kernel compilation. Subsequent runs will be much faster!")
with gr.Tabs() as tabs:
# ===== FINANCIAL MARKETS TAB =====
with gr.TabItem("Financial Markets", id="financial"):
with gr.Row():
with gr.Column(scale=1, min_width=380):
# Data Source Section
gr.Markdown("### Financial Data Sources")
financial_source = gr.Radio(
choices=["Default (WTI Oil Prices)", "Stock Ticker", "VIX Volatility Index", "Upload Custom CSV"],
value="Default (WTI Oil Prices)",
label="",
info="Choose financial market data or upload your own"
)
                        # Hidden textbox holding the active data-source label; shared by both tabs
                        data_source = gr.Textbox(visible=False)
# Dynamic inputs
with gr.Row():
stock_ticker = gr.Textbox(
label="Stock Ticker",
value="SPY",
placeholder="e.g., SPY, AAPL, TSLA",
visible=False
)
uploaded_file = gr.File(
label="CSV File",
file_types=[".csv"],
visible=False
)
def toggle_financial_input(choice):
show_ticker = (choice == "Stock Ticker")
show_upload = (choice == "Upload Custom CSV")
return (
gr.update(visible=show_ticker),
gr.update(visible=show_upload)
)
# Handle selection changes
financial_source.change(
fn=lambda x: x, # Just pass through the selection
inputs=financial_source,
outputs=data_source
).then(
fn=toggle_financial_input,
inputs=financial_source,
outputs=[stock_ticker, uploaded_file]
)
# Forecasting Parameters Section
gr.Markdown("### Forecasting Parameters")
forecast_horizon = gr.Slider(
minimum=30, maximum=512, value=90, step=1,
label="Forecast Horizon",
info="Number of periods to forecast ahead"
)
history_length = gr.Slider(
minimum=256, maximum=2048, value=1024, step=8,
label="History Length",
info="Historical data points to analyze"
)
financial_forecast_btn = gr.Button("Run Forecast & Analysis")
with gr.Column(scale=3):
# Status Section
gr.Markdown("### Analysis Results")
status_text = gr.Textbox(
label="",
interactive=False,
lines=3,
info="Forecasting progress and results"
)
# Key Metrics Section (Adaptive based on data source)
gr.Markdown("### Key Metrics")
# Financial metrics (shown for financial data)
with gr.Row(visible=True) as financial_metrics:
with gr.Column():
gr.Markdown("**Latest Level:** $<span id='latest-price'>0.00</span>")
latest_price_out = gr.Number(visible=False)
gr.Markdown("**Forecast (Next Period):** $<span id='forecast-next'>0.00</span>")
forecast_next_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**30-Day Volatility:** <span id='vol-30d'>0.00</span>%")
vol_30d_out = gr.Number(visible=False)
with gr.Row():
gr.Markdown("**52-Week High:** $<span id='high-52wk'>0.00</span>")
high_52wk_out = gr.Number(visible=False)
gr.Markdown("**52-Week Low:** $<span id='low-52wk'>0.00</span>")
low_52wk_out = gr.Number(visible=False)
# Comprehensive Research Metrics (shown for synthetic data)
with gr.Row(visible=False) as synthetic_metrics:
with gr.Column():
gr.Markdown("**Statistical Properties:**")
gr.Markdown("• **Mean:** <span id='data-mean'>0.000</span>")
data_mean_out = gr.Number(visible=False)
gr.Markdown("• **Std Dev:** <span id='data-std'>0.000</span>")
data_std_out = gr.Number(visible=False)
gr.Markdown("• **Skewness:** <span id='data-skewness'>0.000</span>")
data_skewness_out = gr.Number(visible=False)
gr.Markdown("• **Kurtosis:** <span id='data-kurtosis'>0.000</span>")
data_kurtosis_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Time Series Analysis:**")
gr.Markdown("• **Autocorr (lag-1):** <span id='data-autocorr'>0.000</span>")
data_autocorr_out = gr.Number(visible=False)
gr.Markdown("• **Stationary:** <span id='data-stationary'>Unknown</span>")
data_stationary_out = gr.Textbox(visible=False)
gr.Markdown("• **Pattern Type:** <span id='pattern-type'>None</span>")
pattern_type_out = gr.Textbox(visible=False)
# Model Performance Metrics
with gr.Row(visible=False) as performance_metrics:
with gr.Column():
gr.Markdown("**Forecast Performance:**")
gr.Markdown("• **MSE:** <span id='mse'>0.000</span>")
mse_out = gr.Number(visible=False)
gr.Markdown("• **MAE:** <span id='mae'>0.000</span>")
mae_out = gr.Number(visible=False)
gr.Markdown("• **MAPE:** <span id='mape'>0.000</span>%")
mape_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Uncertainty Quantification:**")
gr.Markdown("• **80% Coverage:** <span id='coverage-80'>0.000</span>")
coverage_80_out = gr.Number(visible=False)
gr.Markdown("• **95% Coverage:** <span id='coverage-95'>0.000</span>")
coverage_95_out = gr.Number(visible=False)
gr.Markdown("• **Calibration:** <span id='calibration'>0.000</span>")
calibration_out = gr.Number(visible=False)
# Data Complexity Metrics
with gr.Row(visible=False) as complexity_metrics:
with gr.Column():
gr.Markdown("**Information Theory:**")
gr.Markdown("• **Sample Entropy:** <span id='sample-entropy'>0.000</span>")
sample_entropy_out = gr.Number(visible=False)
gr.Markdown("• **Approx Entropy:** <span id='approx-entropy'>0.000</span>")
approx_entropy_out = gr.Number(visible=False)
gr.Markdown("• **Perm Entropy:** <span id='perm-entropy'>0.000</span>")
perm_entropy_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Complexity Measures:**")
gr.Markdown("• **Fractal Dim:** <span id='fractal-dim'>0.000</span>")
fractal_dim_out = gr.Number(visible=False)
gr.Markdown("• **Dominant Freq:** <span id='dominant-freq'>0.000</span>")
dominant_freq_out = gr.Number(visible=False)
gr.Markdown("• **Spectral Centroid:** <span id='spectral-centroid'>0.000</span>")
spectral_centroid_out = gr.Number(visible=False)
gr.Markdown("• **Spectral Entropy:** <span id='spectral-entropy'>0.000</span>")
spectral_entropy_out = gr.Number(visible=False)
# Research Tools Section
with gr.Row(visible=False) as research_tools:
with gr.Column():
gr.Markdown("**Cross-Validation Results:**")
gr.Markdown("• **Rolling Window MSE:** <span id='cv-mse'>0.000</span>")
cv_mse_out = gr.Number(visible=False)
gr.Markdown("• **Rolling Window MAE:** <span id='cv-mae'>0.000</span>")
cv_mae_out = gr.Number(visible=False)
gr.Markdown("• **Validation Windows:** <span id='cv-windows'>0</span>")
cv_windows_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Parameter Sensitivity:**")
gr.Markdown("• **Horizon Sensitivity:** <span id='horizon-sensitivity'>0.000</span>")
horizon_sensitivity_out = gr.Number(visible=False)
gr.Markdown("• **History Sensitivity:** <span id='history-sensitivity'>0.000</span>")
history_sensitivity_out = gr.Number(visible=False)
gr.Markdown("• **Stability Score:** <span id='stability-score'>0.000</span>")
stability_score_out = gr.Number(visible=False)
# Forecast Visualization Section
gr.Markdown("### Forecast & Technical Analysis")
plot_output = gr.Plot(
label="",
show_label=False
)
# Advanced Visualizations Section
with gr.Accordion("Advanced Statistical Visualizations", open=False):
advanced_plots = gr.Plot(label="", show_label=False)
# Export & Analysis Tools Section
with gr.Accordion("Export & Analysis Tools", open=False):
with gr.Row():
                                # Renamed with a _btn suffix so the buttons don't shadow the export functions above
                                export_forecast_btn = gr.Button("📊 Export Forecast Data (CSV)")
                                export_metrics_btn = gr.Button("📈 Export Metrics Summary (CSV)")
                                export_analysis_btn = gr.Button("🔬 Export Full Analysis (CSV)")
export_status = gr.Textbox(
label="Export Status",
interactive=False,
lines=2,
info="Export operation results"
)
export_file = gr.File(
label="Download Exported Data",
visible=False
)
# Data Preview Section
with gr.Accordion("Raw Data Preview", open=False):
data_preview = gr.Dataframe(
label="",
show_label=False,
wrap=True
)
# ===== RESEARCH & ANALYSIS TAB =====
with gr.TabItem("Research & Analysis", id="research"):
with gr.Row():
with gr.Column(scale=1, min_width=380):
# Data Source Section
gr.Markdown("### Synthetic Data Testing")
research_source = gr.Radio(
choices=["Basic Synthetic", "Advanced Synthetic"],
value="Basic Synthetic",
label="",
info="Test TempoPFN with synthetic data patterns"
)
# Dynamic inputs for research tab
seed = gr.Number(
value=42,
label="Random Seed",
minimum=0,
maximum=9999,
step=1,
visible=False
)
# Build available generator choices
available_generators = [
"Sine Waves", "Sawtooth Waves", "Spikes", "Steps",
"Ornstein-Uhlenbeck", "Anomaly Patterns"
]
if GP_AVAILABLE:
available_generators.append("Gaussian Processes")
if AUDIO_AVAILABLE:
available_generators.extend(["Financial Volatility", "Fractal Patterns"])
if NETWORK_AVAILABLE:
available_generators.append("Network Topology")
if RHYTHM_AVAILABLE:
available_generators.append("Stochastic Rhythm")
if CAUKER_AVAILABLE:
available_generators.append("CauKer")
if FORECAST_PFN_AVAILABLE:
available_generators.append("Forecast PFN Prior")
if KERNEL_AVAILABLE:
available_generators.append("Kernel Synth")
# Synthetic Playground controls
with gr.Row():
synth_generator = gr.Dropdown(
choices=available_generators,
value="Sine Waves",
label="Generator Type",
visible=False,
info="Select synthetic pattern generator"
)
synth_complexity = gr.Slider(
minimum=1, maximum=10, value=5, step=1,
label="Complexity",
visible=False,
info="Pattern complexity level"
)
def toggle_research_input(choice):
show_seed = (choice == "Basic Synthetic")
show_synth = (choice == "Advanced Synthetic")
return (
gr.update(visible=show_seed),
gr.update(visible=show_synth),
gr.update(visible=show_synth)
)
# Handle selection changes
research_source.change(
fn=lambda x: x, # Just pass through the selection
inputs=research_source,
outputs=data_source
).then(
fn=toggle_research_input,
inputs=research_source,
outputs=[seed, synth_generator, synth_complexity]
)
# Forecasting Parameters Section
gr.Markdown("### Forecasting Parameters")
                        # Renamed so these sliders don't shadow the Financial tab's controls
                        research_forecast_horizon = gr.Slider(
                            minimum=30, maximum=512, value=90, step=1,
                            label="Forecast Horizon",
                            info="Number of periods to forecast ahead"
                        )
                        research_history_length = gr.Slider(
                            minimum=256, maximum=2048, value=1024, step=8,
                            label="History Length",
                            info="Historical data points to analyze"
                        )
forecast_btn = gr.Button("Run Forecast & Analysis")
with gr.Column(scale=3):
# Status Section
gr.Markdown("### Analysis Results")
research_status_text = gr.Textbox(
label="",
interactive=False,
lines=3,
info="Forecasting progress and results"
)
# Key Metrics Section (Adaptive based on data source)
gr.Markdown("### Key Metrics")
# Financial metrics (shown for financial data)
                        with gr.Row(visible=True) as research_financial_metrics:
with gr.Column():
gr.Markdown("**Latest Level:** $<span id='latest-price'>0.00</span>")
latest_price_out = gr.Number(visible=False)
gr.Markdown("**Forecast (Next Period):** $<span id='forecast-next'>0.00</span>")
forecast_next_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**30-Day Volatility:** <span id='vol-30d'>0.00</span>%")
vol_30d_out = gr.Number(visible=False)
with gr.Row():
gr.Markdown("**52-Week High:** $<span id='high-52wk'>0.00</span>")
high_52wk_out = gr.Number(visible=False)
gr.Markdown("**52-Week Low:** $<span id='low-52wk'>0.00</span>")
low_52wk_out = gr.Number(visible=False)
# Synthetic/Research metrics (shown for synthetic data)
                        with gr.Row(visible=False) as research_synthetic_metrics:
with gr.Column():
gr.Markdown("**Data Mean:** <span id='data-mean'>0.000</span>")
data_mean_out = gr.Number(visible=False)
gr.Markdown("**Data Std Dev:** <span id='data-std'>0.000</span>")
data_std_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Forecast Horizon:** <span id='forecast-horizon'>0</span>")
forecast_accuracy_out = gr.Number(visible=False)
gr.Markdown("**Pattern Type:** <span id='pattern-type'>None</span>")
pattern_type_out = gr.Textbox(visible=False)
# Model Performance Metrics
                        with gr.Row(visible=False) as research_performance_metrics:
with gr.Column():
gr.Markdown("**Forecast Performance:**")
gr.Markdown("• **MSE:** <span id='mse'>0.000</span>")
mse_out = gr.Number(visible=False)
gr.Markdown("• **MAE:** <span id='mae'>0.000</span>")
mae_out = gr.Number(visible=False)
gr.Markdown("• **MAPE:** <span id='mape'>0.000</span>%")
mape_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Uncertainty Quantification:**")
gr.Markdown("• **80% Coverage:** <span id='coverage-80'>0.000</span>")
coverage_80_out = gr.Number(visible=False)
gr.Markdown("• **95% Coverage:** <span id='coverage-95'>0.000</span>")
coverage_95_out = gr.Number(visible=False)
gr.Markdown("• **Calibration:** <span id='calibration'>0.000</span>")
calibration_out = gr.Number(visible=False)
# Data Complexity Metrics
                        with gr.Row(visible=False) as research_complexity_metrics:
with gr.Column():
gr.Markdown("**Information Theory:**")
gr.Markdown("• **Sample Entropy:** <span id='sample-entropy'>0.000</span>")
sample_entropy_out = gr.Number(visible=False)
gr.Markdown("• **Approx Entropy:** <span id='approx-entropy'>0.000</span>")
approx_entropy_out = gr.Number(visible=False)
gr.Markdown("• **Perm Entropy:** <span id='perm-entropy'>0.000</span>")
perm_entropy_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Complexity Measures:**")
gr.Markdown("• **Fractal Dim:** <span id='fractal-dim'>0.000</span>")
fractal_dim_out = gr.Number(visible=False)
gr.Markdown("• **Dominant Freq:** <span id='dominant-freq'>0.000</span>")
dominant_freq_out = gr.Number(visible=False)
gr.Markdown("• **Spectral Centroid:** <span id='spectral-centroid'>0.000</span>")
spectral_centroid_out = gr.Number(visible=False)
gr.Markdown("• **Spectral Entropy:** <span id='spectral-entropy'>0.000</span>")
spectral_entropy_out = gr.Number(visible=False)
# Research Tools Section
                        with gr.Row(visible=False) as research_tools_row:
with gr.Column():
gr.Markdown("**Cross-Validation Results:**")
gr.Markdown("• **Rolling Window MSE:** <span id='cv-mse'>0.000</span>")
cv_mse_out = gr.Number(visible=False)
gr.Markdown("• **Rolling Window MAE:** <span id='cv-mae'>0.000</span>")
cv_mae_out = gr.Number(visible=False)
gr.Markdown("• **Validation Windows:** <span id='cv-windows'>0</span>")
cv_windows_out = gr.Number(visible=False)
with gr.Column():
gr.Markdown("**Parameter Sensitivity:**")
gr.Markdown("• **Horizon Sensitivity:** <span id='horizon-sensitivity'>0.000</span>")
horizon_sensitivity_out = gr.Number(visible=False)
gr.Markdown("• **History Sensitivity:** <span id='history-sensitivity'>0.000</span>")
history_sensitivity_out = gr.Number(visible=False)
gr.Markdown("• **Stability Score:** <span id='stability-score'>0.000</span>")
stability_score_out = gr.Number(visible=False)
# Forecast Visualization Section
gr.Markdown("### Forecast & Technical Analysis")
research_plot_output = gr.Plot(
label="",
show_label=False
)
                        # Advanced Visualizations Section (mirrors the Financial tab)
with gr.Accordion("Advanced Statistical Visualizations", open=False):
research_advanced_plots = gr.Plot(label="", show_label=False)
# Data Preview Section
with gr.Accordion("Raw Data Preview", open=False):
research_data_preview = gr.Dataframe(
label="",
show_label=False,
wrap=True
)
# Now add the metrics toggle function after components are defined
def toggle_metrics_display(choice):
"""Toggle between financial and synthetic metrics based on data source"""
show_financial = choice in ["Stock Ticker", "Default (WTI Oil Prices)", "VIX Volatility Index"]
show_synthetic = choice in ["Basic Synthetic", "Advanced Synthetic", "Upload Custom CSV"]
show_performance = show_synthetic # Show performance metrics for synthetic data
show_complexity = show_synthetic # Show complexity metrics for synthetic data
return (
gr.update(visible=show_financial),
gr.update(visible=show_synthetic),
gr.update(visible=show_performance),
gr.update(visible=show_complexity)
)
        # Toggle metric panels when the data source changes; each tab drives its own panels
        financial_source.change(
            fn=toggle_metrics_display,
            inputs=financial_source,
            outputs=[financial_metrics, synthetic_metrics, performance_metrics, complexity_metrics]
        )
        research_source.change(
            fn=toggle_metrics_display,
            inputs=research_source,
            outputs=[research_financial_metrics, research_synthetic_metrics, research_performance_metrics, research_complexity_metrics]
        )
# Wrapper function to unpack forecast results for UI
def forecast_and_display_financial(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed):
result = forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, "Sine Waves", 5)
            if result[0] is not None:  # history present => forecast succeeded
history_np = result[0]
preds = result[2]
future_np = last_forecast_results['future'] if last_forecast_results else None
# Generate advanced visualizations
adv_viz = create_advanced_visualizations(history_np, preds, future_np)
return (
result[5], # status_text
result[4], # plot_output
result[7], # data_preview
adv_viz # advanced_plots
)
else:
return result[5], None, pd.DataFrame(), go.Figure()
def forecast_and_display_research(data_source, forecast_horizon, history_length, seed, synth_generator, synth_complexity):
result = forecast_time_series(data_source, "", None, forecast_horizon, history_length, seed, synth_generator, synth_complexity)
            if result[0] is not None:  # history present => forecast succeeded
history_np = result[0]
preds = result[2]
future_np = last_forecast_results['future'] if last_forecast_results else None
# Generate advanced visualizations
adv_viz = create_advanced_visualizations(history_np, preds, future_np)
return (
result[5], # status_text
result[4], # plot_output
result[7], # data_preview
adv_viz # advanced_plots
)
else:
return result[5], None, pd.DataFrame(), go.Figure()
        # Connect button click handlers (the financial tab uses its own sliders;
        # the hidden seed control lives on the research tab)
        financial_forecast_btn.click(
            fn=forecast_and_display_financial,
            inputs=[data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed],
            outputs=[status_text, plot_output, data_preview, advanced_plots]
        )
        forecast_btn.click(
            fn=forecast_and_display_research,
            inputs=[data_source, research_forecast_horizon, research_history_length, seed, synth_generator, synth_complexity],
            outputs=[research_status_text, research_plot_output, research_data_preview, research_advanced_plots]
        )
# Wrapper for export functions to show file
def export_forecast_wrapper():
file, status = export_forecast_csv()
return gr.update(value=file, visible=file is not None), status
def export_metrics_wrapper():
file, status = export_metrics_csv()
return gr.update(value=file, visible=file is not None), status
def export_analysis_wrapper():
file, status = export_analysis_csv()
return gr.update(value=file, visible=file is not None), status
        # Connect export button handlers (renamed buttons, see above)
        export_forecast_btn.click(
            fn=export_forecast_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )
        export_metrics_btn.click(
            fn=export_metrics_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )
        export_analysis_btn.click(
            fn=export_analysis_wrapper,
            inputs=[],
            outputs=[export_file, export_status]
        )
return app # Return the Gradio app object
# --- GRADIO APP LAUNCH ---
app = create_gradio_app()
app.launch()