tempoPFN / examples /gift_eval /gift_eval_runner.py
Vladyslav Moroshan
Apply ruff formatting
96e1a32
#!/usr/bin/env python
"""
GIFT-Eval Runner Script
This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
- Uses `src/gift_eval/data.py` for dataset handling.
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
- Loads a model from a local path or downloads from Hugging Face Hub.
- Writes per-dataset CSV metrics to `output_dir` without creating plots.
"""
import argparse
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results, write_results_to_disk
# Logging setup: root logger at INFO with a timestamped format, while the
# matplotlib loggers are raised to WARNING so their INFO/DEBUG records are dropped.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
for _quiet_logger_name in ("matplotlib", "matplotlib.font_manager"):
    logging.getLogger(_quiet_logger_name).setLevel(logging.WARNING)

# Logger used by every function in this script.
logger = logging.getLogger("gift_eval_runner")
def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
    """Resolve the datasets argument into a concrete list of dataset names.

    A bare string is treated as a one-element list. When the first entry is
    ``"all"``, a new list built from ``ALL_DATASETS`` is returned. Otherwise
    every entry must be a member of ``ALL_DATASETS`` and the (normalised)
    input list is returned unchanged.

    Args:
        datasets_arg: A single dataset name, or a list of dataset names.

    Returns:
        The dataset names to evaluate.

    Raises:
        ValueError: If an entry is not contained in ``ALL_DATASETS``.
    """
    requested = [datasets_arg] if isinstance(datasets_arg, str) else datasets_arg

    # "all" is only honoured in the leading position.
    if requested and requested[0] == "all":
        return list(ALL_DATASETS)

    for name in requested:
        if name in ALL_DATASETS:
            continue
        raise ValueError(f"Invalid dataset: {name}. Use one of {ALL_DATASETS}")
    return requested
def run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets_arg: list[str] | str,
    terms_arg: list[str],
    dataset_storage_path: str,
    max_windows_arg: int | None,
    batch_size_arg: int,
    max_context_length_arg: int | None,
    output_dir_arg: str,
    model_name_arg: str,
    after_each_dataset_flush: bool = True,
) -> None:
    """Evaluate the predictor on each requested dataset and write its results.

    Datasets are processed one at a time: each is passed to
    ``evaluate_datasets`` and the returned items are handed straight to
    ``write_results_to_disk``. Plot creation is disabled in both calls.

    Args:
        predictor: Predictor forwarded to ``evaluate_datasets``.
        datasets_arg: Dataset name(s); expanded via ``_expand_datasets_arg``.
        terms_arg: Prediction terms forwarded as ``terms``.
        dataset_storage_path: Forwarded unchanged to ``evaluate_datasets``.
        max_windows_arg: Forwarded as ``max_windows``.
        batch_size_arg: Forwarded as ``batch_size``.
        max_context_length_arg: Forwarded as ``max_context_length``.
        output_dir_arg: Directory passed (as a ``Path``) to ``write_results_to_disk``.
        model_name_arg: Forwarded as ``model_name`` when writing results.
        after_each_dataset_flush: Only controls whether a log line is emitted
            after each dataset; results are written either way.

    Raises:
        ValueError: Propagated from ``_expand_datasets_arg`` for unknown names.
    """
    dataset_names = _expand_datasets_arg(datasets_arg)
    out_root = Path(output_dir_arg)

    # Settings that are identical for every dataset; plotting stays off.
    shared_eval_kwargs = dict(
        dataset_storage_path=dataset_storage_path,
        terms=terms_arg,
        max_windows=max_windows_arg,
        batch_size=batch_size_arg,
        max_context_length=max_context_length_arg,
        create_plots=False,
        max_plots_per_dataset=0,
    )

    for dataset_name in dataset_names:
        evaluated_items = evaluate_datasets(
            predictor=predictor,
            dataset=dataset_name,
            **shared_eval_kwargs,
        )
        write_results_to_disk(
            items=evaluated_items,
            dataset_name=dataset_name,
            output_dir=out_root,
            model_name=model_name_arg,
            create_plots=False,
        )
        if not after_each_dataset_flush:
            continue
        logger.info("Flushed results for %s", dataset_name)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the GIFT-Eval runner."""
    parser = argparse.ArgumentParser(description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets")

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to a local model checkpoint. If not provided, will download from Hugging Face Hub.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--hf_repo_id",
        type=str,
        default="AutoML-org/TempoPFN",
        help="Hugging Face Hub repo ID to download from (default: AutoML-org/TempoPFN)",
    )
    parser.add_argument(
        "--hf_filename",
        type=str,
        default="models/checkpoint_38M.pth",
        help="Filename of the checkpoint within the HF repo (default: models/checkpoint_38M.pth)",
    )

    # Dataset configuration
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["all"],
        help='List of dataset names or ["all"] (default: all)',
    )
    parser.add_argument(
        "--terms",
        type=str,
        nargs="+",
        default=["short", "medium", "long"],
        help="Prediction terms to evaluate (default: short medium long)",
    )
    parser.add_argument(
        "--dataset_storage_path",
        type=str,
        # NOTE(review): the default is a hard-coded absolute path; pass
        # --dataset_storage_path explicitly when the data lives elsewhere.
        default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        help="Path to the root of the gift eval datasets storage directory",
    )
    parser.add_argument(
        "--max_windows",
        type=int,
        default=20,
        help="Maximum number of windows to use for evaluation (default: 20)",
    )

    # Inference configuration
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        # Help text previously advertised 128, which did not match the default.
        help="Batch size for inference (default: 64)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=3072,
        help="Maximum context length (default: 3072)",
    )

    # Output configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gift_eval_results",
        help="Output directory for results (default: gift_eval_results)",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="TempoPFN",
        help="Model name identifier for results (default: TempoPFN)",
    )
    parser.add_argument(
        "--no_flush",
        action="store_true",
        help="Disable flushing results after each dataset",
    )
    return parser


def _resolve_model_path(args: argparse.Namespace) -> str:
    """Return the checkpoint path, downloading from Hugging Face Hub if needed.

    Uses ``args.model_path`` when it is set; otherwise downloads
    ``args.hf_filename`` from ``args.hf_repo_id`` via ``hf_hub_download``.

    Raises:
        FileNotFoundError: If the resolved path is empty or does not exist.
        Exception: Any error raised by ``hf_hub_download`` is logged and re-raised.
    """
    if args.model_path:
        logger.info("Using local model checkpoint from --model_path: %s", args.model_path)
        resolved_model_path = args.model_path
    else:
        logger.info("Downloading model from Hugging Face Hub...")
        logger.info("  Repo: %s", args.hf_repo_id)
        logger.info("  File: %s", args.hf_filename)
        try:
            resolved_model_path = hf_hub_download(
                repo_id=args.hf_repo_id,
                filename=args.hf_filename,
            )
        except Exception as e:
            logger.error("Failed to download model from Hugging Face Hub: %s", e)
            # Bare raise re-raises the active exception without adding a traceback frame.
            raise
        logger.info("Download complete. Model path: %s", resolved_model_path)

    if not resolved_model_path or not Path(resolved_model_path).exists():
        raise FileNotFoundError(
            f"Could not resolve model checkpoint. Checked local path: {args.model_path} and HF download."
        )
    return resolved_model_path


def main(argv: list[str] | None = None) -> None:
    """Parse arguments, load the predictor, run the evaluation and aggregate results.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Raises:
        FileNotFoundError: If the config file or the model checkpoint cannot be found.
    """
    args = _build_arg_parser().parse_args(argv)

    # Resolve paths
    config_path = Path(args.config_path)
    output_dir = Path(args.output_dir)

    # Check the config first so a bad --config_path fails before any model download.
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    resolved_model_path = _resolve_model_path(args)

    logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
    predictor = TimeSeriesPredictor.from_paths(
        model_path=resolved_model_path,
        config_path=str(config_path),
        ds_prediction_length=1,  # placeholder; set per dataset
        ds_freq="D",  # placeholder; set per dataset
        batch_size=args.batch_size,
        max_context_length=args.max_context_length,
    )

    logger.info("Starting evaluation...")
    logger.info("  Datasets: %s", args.datasets)
    logger.info("  Terms: %s", args.terms)
    logger.info("  Output directory: %s", output_dir)

    # Run evaluation
    run_evaluation(
        predictor=predictor,
        datasets_arg=args.datasets,
        terms_arg=args.terms,
        dataset_storage_path=args.dataset_storage_path,
        max_windows_arg=args.max_windows,
        batch_size_arg=args.batch_size,
        max_context_length_arg=args.max_context_length,
        output_dir_arg=str(output_dir),
        model_name_arg=args.model_name,
        after_each_dataset_flush=not args.no_flush,
    )
    logger.info("Evaluation complete. See results under: %s", output_dir)

    # Aggregate all results into a single CSV file
    logger.info("Aggregating results from all datasets...")
    combined_df = aggregate_results(result_root_dir=output_dir)
    if combined_df is not None:
        logger.info(
            "Successfully created aggregated results file: %s/all_results.csv",
            output_dir,
        )
    else:
        logger.warning("No results to aggregate. Check that evaluation completed successfully.")


if __name__ == "__main__":
    main()