# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os from dataclasses import dataclass from pathlib import Path, PosixPath, WindowsPath from typing import Optional, Union import lightning.fabric as fl import lightning.pytorch as pl from nemo.lightning import io from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME from nemo.lightning.pytorch.strategies.utils import RestoreConfig from nemo.utils import logging from nemo.utils.app_state import AppState from nemo.utils.model_utils import uninject_model_parallel_rank from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url # Dynamically inherit from the correct Path subclass based on the operating system. if os.name == "nt": BasePath = WindowsPath else: BasePath = PosixPath def _try_restore_tokenizer(model, ckpt_path): from nemo.collections.common.tokenizers import TokenizerSpec from nemo.lightning.io import load_context try: tokenizer = load_context(ckpt_path, "model.tokenizer") except ValueError as e: logging.warning( f"Encountered error while trying to restore tokenizer. Tokenizer is not restored. " f"Original error: {e}" ) return model if isinstance(tokenizer, TokenizerSpec): model.tokenizer = tokenizer model.__io__.tokenizer = tokenizer.__io__ else: # Ignore if the ckpt doesn't have a tokenizer. type(tokenizer)==TrainerContext in this case. logging.warning("Checkpoint does not have model.tokenizer field. Tokenizer is not restored.") return model @dataclass(kw_only=True) class AutoResume: """Class that handles the logic for setting checkpoint paths and restoring from checkpoints in NeMo. Attributes: restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model weights, optimizer states, etc. If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be automatically converted to a NeMo compatible format. resume_from_folder or the run's log_dir takes precedence over restore_config. resume_from_directory (str): Path to the checkpointing directory to restore from. resume_from_path (str): Path to a specific checkpoint to restore from. resume_if_exists (bool): Whether this experiment is resuming from a previous run. If True, it sets trainer._checkpoint_connector._ckpt_path so that the trainer should auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}. Defaults to False. resume_past_end (bool): By default, AutoResume throws an error if resume_if_exists is True and a checkpoint matching ``*end.ckpt`` indicating a previous training run fully completed. Setting resume_past_end=True disables this behavior and loads the last checkpoint. resume_ignore_no_checkpoint (bool): AutoResume throws an error if resume_if_exists is True and no checkpoint could be found. Setting resume_ignore_no_checkpoint=True disables this behavior, in which case exp_manager will print a message and continue without restoring. """ restore_config: Optional[RestoreConfig] = None resume_from_directory: Optional[str] = None resume_from_path: Optional[str] = None resume_if_exists: bool = False resume_past_end: bool = False resume_ignore_no_checkpoint: bool = False WEIGHTS_PATH = "weights" def get_weights_path(self, path) -> Path: """Returns the path to the weights directory within the specified path. Args: path: The checkpoint directory path Returns: Path: A Path object pointing to the weights directory """ return path / self.WEIGHTS_PATH def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None): """Sets up checkpoint restoration for the Pytorch Lightning trainer. This method configures the trainer with the appropriate checkpoint path for resuming training and handles loading model artifacts like tokenizers when specified. Args: trainer: The PyTorch Lightning trainer or Fabric instance model: Optional model instance to load artifacts into Raises: NotImplementedError: If trainer is a Fabric instance (not yet supported) """ if isinstance(trainer, fl.Fabric): raise NotImplementedError("Fabric is not supported yet.") trainer_ckpt_path = self.get_trainer_ckpt_path(model) if trainer_ckpt_path: trainer.ckpt_path = trainer_ckpt_path trainer.checkpoint_callback.last_model_path = trainer_ckpt_path # Load artifacts if getattr(self.restore_config, "load_artifacts", False): if isinstance(trainer_ckpt_path, AdapterPath): # load tokenizer from the base model during peft resume, in case the first peft checkpoint # is deleted before the current peft checkpoint is saved context_path = trainer_ckpt_path.base_model_path / "context" if not context_path.exists(): context_path = trainer_ckpt_path.base_model_path else: context_path = self.get_context_path(model) model = _try_restore_tokenizer(model, context_path) elif self.restore_config: new_path = self._extract_path( path=self.restore_config.path, ) assert not isinstance(new_path, AdapterPath), "AdapterPath is not supported for restore_config" self.restore_config.path = str(new_path) trainer.strategy.restore_config = self.restore_config # Load artifacts if self.restore_config.load_artifacts: if isinstance(new_path, AdapterPath): context_path = Path(new_path.base_model_path) / "context" else: context_path = new_path / "context" if not context_path.is_dir(): context_path = new_path _try_restore_tokenizer(model, context_path) def _extract_path(self, path: str) -> BasePath: if "://" in path: assert path.startswith("nemo://"), "Only NeMo based paths starting with nemo:// are currently supported." _, _path = path.split("://") new_path = os.path.join(NEMO_MODELS_CACHE, _path) else: new_path = path if isinstance(new_path, str): new_path = Path(new_path) return new_path def _get_base_model_path_for_adapter(self, adapter_meta_path, model): with open(adapter_meta_path, "r") as f: metadata = json.load(f) # Use the model_ckpt_path from metadata directly base_model_path = Path(metadata["model_ckpt_path"]) # If base_model_path points to a specific checkpoint file, use its parent directory if not base_model_path.is_dir() and base_model_path.exists(): base_model_path = base_model_path.parent return base_model_path def _find_trainer_ckpt_path(self) -> Optional[Path]: from nemo.utils.exp_manager import NotFoundError, _filter_out_unfinished_checkpoints app_state = AppState() log_dir = app_state.log_dir checkpoint = None # Use /checkpoints/ unless `dirpath` is set if self.resume_from_directory: if is_multistorageclient_url(self.resume_from_directory): msc = import_multistorageclient() checkpoint_dir = msc.Path(self.resume_from_directory) else: checkpoint_dir = Path(self.resume_from_directory) elif log_dir is not None: checkpoint_dir = Path(Path(log_dir) / "checkpoints") else: # ie. if log_dir is None return None # when using distributed checkpointing, checkpoint_dir is a directory of directories # we check for this here dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()] end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")] last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")] end_chkpt_cnt = len(end_dist_checkpoints) end_checkpoints = _filter_out_unfinished_checkpoints(end_dist_checkpoints) finished_end_chkpt_cnt = len(end_checkpoints) if end_chkpt_cnt > 0 and finished_end_chkpt_cnt == 0: raise ValueError( "End checkpoint is unfinished and cannot be used to resume the training." " Please remove the checkpoint manually to avoid unexpected cosequences, such as" " restarting from scratch." ) last_chkpt_cnt = len(last_dist_checkpoints) last_checkpoints = _filter_out_unfinished_checkpoints(last_dist_checkpoints) finished_last_chkpt_cnt = len(last_checkpoints) if last_chkpt_cnt > 0 and finished_last_chkpt_cnt == 0: raise ValueError( "Last checkpoint is unfinished and cannot be used to resume the training." " Please remove the checkpoint manually to avoid unexpected cosequences, such as" " restarting from scratch. Hint: Iteration number can be added to the checkpoint name pattern" " to maximize chance that there is at least one finished last checkpoint to resume from." ) if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0): if self.resume_ignore_no_checkpoint: message = ( f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " f":{checkpoint_dir}. " ) if not self.restore_config: logging.warning(message + "Training from scratch.") else: logging.info(message + "Trying to resume from RestoreConfig.") else: if self.restore_config: # resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore # later instead. return None else: raise NotFoundError( f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " f":{checkpoint_dir}. Cannot resume." ) elif len(end_checkpoints) > 0: if not self.resume_past_end: raise ValueError( f"Found {end_checkpoints[0]} indicating that the last training run has already completed." ) if len(end_checkpoints) > 1: if "mp_rank" in str(end_checkpoints[0]): checkpoint = end_checkpoints[0] else: raise ValueError(f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt.") elif len(last_checkpoints) > 1: if any([s for s in ["mp_rank", "tp_rank", "fsdp_shard"] if s in str(last_checkpoints[0])]): checkpoint = last_checkpoints[0] checkpoint = uninject_model_parallel_rank(checkpoint) else: # Select the checkpoint with the latest modified time checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0] logging.warning( f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest " f"modified time." ) else: checkpoint = last_checkpoints[0] return checkpoint def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]: """Retrieves the path to the context directory of a checkpoint. The context directory contains serialized objects like tokenizers. This method handles both cases where the context is directly in the checkpoint directory or in a subdirectory called "context". Args: model: Optional model instance Returns: Optional[Path]: Path to the context directory if found, None otherwise """ checkpoint = None app_state = AppState() app_state.restore = self.resume_if_exists if self.resume_if_exists: checkpoint = self._find_trainer_ckpt_path() if checkpoint: maybe_context_path = checkpoint / "context" if maybe_context_path.is_dir(): checkpoint = maybe_context_path return checkpoint def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]: """Resolves the path to a checkpoint for resuming training. This method handles various checkpoint sources with the following priority: 1. Explicit path specified in resume_from_path 2. Automatic discovery in the checkpoint directory when resume_if_exists=True For adapter checkpoints (PEFT), it also retrieves the base model path from metadata. Args: model: Optional model instance Returns: Optional[Path]: Path to the checkpoint if found, or AdapterPath for PEFT checkpoints, or None if no checkpoint is found or needed """ if self.resume_from_path: if is_multistorageclient_url(self.resume_from_path): msc = import_multistorageclient() resume_from_path = msc.Path(self.resume_from_path) else: resume_from_path = Path(self.resume_from_path) maybe_weights_path = self.get_weights_path(resume_from_path) if maybe_weights_path.is_dir(): adapter_meta_path = maybe_weights_path / ADAPTER_META_FILENAME if adapter_meta_path.exists(): # the resume_from_path is an adapter checkpoint base_model_path = self._get_base_model_path_for_adapter(adapter_meta_path, model) return AdapterPath(Path(self.resume_from_path), base_model_path=base_model_path) else: # the resume_from_path is not PEFT checkpoint return maybe_weights_path else: return self.resume_from_path checkpoint = None app_state = AppState() app_state.restore = self.resume_if_exists if self.resume_if_exists: checkpoint = self._find_trainer_ckpt_path() if checkpoint: maybe_weights_path = self.get_weights_path(checkpoint) if maybe_weights_path.is_dir(): checkpoint = maybe_weights_path if checkpoint: adapter_meta_path = checkpoint / ADAPTER_META_FILENAME if adapter_meta_path.exists(): base_model_path = self._get_base_model_path_for_adapter(adapter_meta_path, model) return AdapterPath(checkpoint, base_model_path=base_model_path) else: return checkpoint return None class AdapterPath(BasePath): """Path object for adapter paths which include a field for the base model the adapters are trained on to facilitate model loading.""" base_model_path: Optional[Path] def __new__(cls, *args, base_model_path: Optional[Path] = None, **kwargs): output = super().__new__(cls, *args, **kwargs) output.base_model_path = base_model_path return output def __repr__(self): return "{}({!r}, base_model_path={})".format(self.__class__.__name__, self.as_posix(), self.base_model_path)