""" |
|
|
GIFT-Eval Runner Script |
|
|
|
|
|
This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline. |
|
|
|
|
|
- Uses `src/gift_eval/data.py` for dataset handling. |
|
|
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference. |
|
|
- Loads a model from a local path or downloads from Hugging Face Hub. |
|
|
- Writes per-dataset CSV metrics to `output_dir` without creating plots. |
|
|
""" |
import argparse
import logging
from pathlib import Path

from huggingface_hub import hf_hub_download

from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.evaluate import evaluate_datasets
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import aggregate_results, write_results_to_disk

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logger = logging.getLogger("gift_eval_runner")


def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
    """Expand the dataset argument into a list of dataset names."""
    if isinstance(datasets_arg, str):
        if datasets_arg == "all":
            return list(ALL_DATASETS)
        datasets_list = [datasets_arg]
    else:
        datasets_list = datasets_arg
        if datasets_list and datasets_list[0] == "all":
            return list(ALL_DATASETS)

    for ds in datasets_list:
        if ds not in ALL_DATASETS:
            raise ValueError(f"Invalid dataset: {ds}. Use one of {ALL_DATASETS}")
    return datasets_list
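
# Illustrative behavior of _expand_datasets_arg; "m4_weekly" is an example
# dataset name, the real options live in src.gift_eval.constants.ALL_DATASETS:
#
#   _expand_datasets_arg("all")         -> list(ALL_DATASETS)
#   _expand_datasets_arg(["all"])       -> list(ALL_DATASETS)
#   _expand_datasets_arg(["m4_weekly"]) -> ["m4_weekly"]  (if in ALL_DATASETS)
#   _expand_datasets_arg("bogus_name")  -> raises ValueError

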
def run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets_arg: list[str] | str,
    terms_arg: list[str],
    dataset_storage_path: str,
    max_windows_arg: int | None,
    batch_size_arg: int,
    max_context_length_arg: int | None,
    output_dir_arg: str,
    model_name_arg: str,
    after_each_dataset_flush: bool = True,
) -> None:
    """Run evaluation on the specified datasets and write per-dataset results."""
    datasets_to_run = _expand_datasets_arg(datasets_arg)
    results_root = Path(output_dir_arg)

    for ds_name in datasets_to_run:
        items = evaluate_datasets(
            predictor=predictor,
            dataset=ds_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms_arg,
            max_windows=max_windows_arg,
            batch_size=batch_size_arg,
            max_context_length=max_context_length_arg,
            create_plots=False,
            max_plots_per_dataset=0,
        )
        write_results_to_disk(
            items=items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name_arg,
            create_plots=False,
        )
        if after_each_dataset_flush:
            logger.info("Flushed results for %s", ds_name)
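
# A minimal programmatic sketch of the same flow (a hypothetical usage, not the
# script's entry point; a predictor is constructed as in main() below, and the
# elided from_paths arguments follow the call there):
#
#   predictor = TimeSeriesPredictor.from_paths(model_path=..., config_path=...)
#   run_evaluation(
#       predictor=predictor,
#       datasets_arg=["all"],
#       terms_arg=["short"],
#       dataset_storage_path="/path/to/gift_eval",
#       max_windows_arg=20,
#       batch_size_arg=64,
#       max_context_length_arg=3072,
#       output_dir_arg="gift_eval_results",
#       model_name_arg="TempoPFN",
#   )

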
def main():
    """Main execution function."""
    parser = argparse.ArgumentParser(description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets")

    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to a local model checkpoint. If not provided, the checkpoint is downloaded from the Hugging Face Hub.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/example.yaml",
        help="Path to the model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--hf_repo_id",
        type=str,
        default="AutoML-org/TempoPFN",
        help="Hugging Face Hub repo ID to download from (default: AutoML-org/TempoPFN)",
    )
    parser.add_argument(
        "--hf_filename",
        type=str,
        default="models/checkpoint_38M.pth",
        help="Filename of the checkpoint within the HF repo (default: models/checkpoint_38M.pth)",
    )

    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["all"],
        help='List of dataset names, or "all" to evaluate every dataset (default: all)',
    )
    parser.add_argument(
        "--terms",
        type=str,
        nargs="+",
        default=["short", "medium", "long"],
        help="Prediction terms to evaluate (default: short medium long)",
    )
    parser.add_argument(
        "--dataset_storage_path",
        type=str,
        default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
        help="Path to the root of the GIFT-Eval datasets storage directory",
    )
    parser.add_argument(
        "--max_windows",
        type=int,
        default=20,
        help="Maximum number of windows to use for evaluation (default: 20)",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for inference (default: 64)",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=3072,
        help="Maximum context length (default: 3072)",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="gift_eval_results",
        help="Output directory for results (default: gift_eval_results)",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="TempoPFN",
        help="Model name identifier for results (default: TempoPFN)",
    )
    parser.add_argument(
        "--no_flush",
        action="store_true",
        help="Disable flushing results after each dataset",
    )

    args = parser.parse_args()

    config_path = Path(args.config_path)
    output_dir = Path(args.output_dir)

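    # Resolve the checkpoint: an explicit --model_path takes precedence;
    # otherwise the file is fetched from the Hugging Face Hub
    # (hf_hub_download returns the path of the locally cached copy).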
    resolved_model_path = None
    if args.model_path:
        logger.info("Using local model checkpoint from --model_path: %s", args.model_path)
        resolved_model_path = args.model_path
    else:
        logger.info("Downloading model from Hugging Face Hub...")
        logger.info(" Repo: %s", args.hf_repo_id)
        logger.info(" File: %s", args.hf_filename)
        try:
            resolved_model_path = hf_hub_download(
                repo_id=args.hf_repo_id,
                filename=args.hf_filename,
            )
            logger.info("Download complete. Model path: %s", resolved_model_path)
        except Exception as e:
            logger.error("Failed to download model from Hugging Face Hub: %s", e)
            raise

    if not resolved_model_path or not Path(resolved_model_path).exists():
        raise FileNotFoundError(
            f"Could not resolve model checkpoint. Checked local path: {args.model_path} and HF download."
        )

    if not config_path.exists():
        raise FileNotFoundError(f"Config not found: {config_path}")

    logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
    predictor = TimeSeriesPredictor.from_paths(
        model_path=resolved_model_path,
        config_path=str(config_path),
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=args.batch_size,
        max_context_length=args.max_context_length,
    )

logger.info("Starting evaluation...") |
|
|
logger.info(" Datasets: %s", args.datasets) |
|
|
logger.info(" Terms: %s", args.terms) |
|
|
logger.info(" Output directory: %s", output_dir) |
|
|
|
|
|
|
|
|
run_evaluation( |
|
|
predictor=predictor, |
|
|
datasets_arg=args.datasets, |
|
|
terms_arg=args.terms, |
|
|
dataset_storage_path=args.dataset_storage_path, |
|
|
max_windows_arg=args.max_windows, |
|
|
batch_size_arg=args.batch_size, |
|
|
max_context_length_arg=args.max_context_length, |
|
|
output_dir_arg=str(output_dir), |
|
|
model_name_arg=args.model_name, |
|
|
after_each_dataset_flush=not args.no_flush, |
|
|
) |
|
|
|
|
|
logger.info("Evaluation complete. See results under: %s", output_dir) |
|
|
|
|
|
|
|
|
logger.info("Aggregating results from all datasets...") |
|
|
combined_df = aggregate_results(result_root_dir=output_dir) |
|
|
|
|
|
if combined_df is not None: |
|
|
logger.info( |
|
|
"Successfully created aggregated results file: %s/all_results.csv", |
|
|
output_dir, |
|
|
) |
|
|
else: |
|
|
logger.warning("No results to aggregate. Check that evaluation completed successfully.") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|