# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import Callback

from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils import logging


def collect_precision(tensor: torch.Tensor) -> Dict[str, str]:
    """Return the tensor's precision (dtype) as a one-entry mapping.

    Args:
        tensor: Object to inspect. Non-tensor inputs are tolerated.

    Returns:
        ``{"Precision": <dtype string>}`` for a ``torch.Tensor``, otherwise
        ``{"Precision": "not-a-tensor"}``.
    """
    if isinstance(tensor, torch.Tensor):
        return {"Precision": str(tensor.dtype)}
    return {"Precision": "not-a-tensor"}


def collect_precision_and_shape(tensor: torch.Tensor) -> Dict[str, str]:
    """Return the tensor's shape and precision (dtype) as a mapping.

    Args:
        tensor: Object to inspect. Non-tensor inputs are tolerated.

    Returns:
        A mapping with ``"Shape"`` and ``"Precision"`` entries; both values are
        ``"not-a-tensor"`` when ``tensor`` is not a ``torch.Tensor``.
    """
    if isinstance(tensor, torch.Tensor):
        return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
    return {"Shape": "not-a-tensor", "Precision": "not-a-tensor"}


class ParameterDebugger(Callback):
    """
    Debugging tool to help inspect parameters and gradients at any callback event.

    This callback handles the boilerplate needed to iterate over the model parameters and gradients,
    and applies user specified functions to them. These functions can be used to log attributes or
    apply asserts on the param and grad tensors. Attributes are logged in a table, with a row for each parameter name.
    Default behavior is to log the precision and shapes of each parameter and its gradient.

    Args:
        param_fn: Function to apply to model parameters. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each parameter.
        grad_fn: Function to apply to model gradients. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each gradient.
        log_on_hooks: PTL callback hook name or list of hook names on which to apply param_fn and grad_fn.
            See `PTL docs <https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#hooks>`_ for more info
            on callback hooks. Note that some hooks that occur before the model is constructed are invalid.

    Example:
        >>> fn = lambda x: {"Norm": str(x.norm(2).item())}
        >>> callback = ParameterDebugger(param_fn=fn, log_on_hooks=["on_train_start", "on_train_end"])
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        param_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision_and_shape,
        grad_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision,
        log_on_hooks: Union[List[str], str] = "on_train_start",
    ):
        """Store the user functions and bind the logging routine to each requested hook.

        Raises:
            AssertionError: If a name in ``log_on_hooks`` is not one of the supported hooks.
        """
        self.param_fn = param_fn
        self.grad_fn = grad_fn

        valid_hooks = {
            "teardown",
            "on_fit_end",
            "on_sanity_check_start",
            "on_sanity_check_end",
            "on_train_batch_start",
            "on_train_batch_end",
            "on_train_epoch_start",
            "on_train_epoch_end",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_test_epoch_start",
            "on_test_epoch_end",
            "on_predict_epoch_start",
            "on_predict_epoch_end",
            "on_validation_batch_start",
            "on_validation_batch_end",
            "on_test_batch_start",
            "on_test_batch_end",
            "on_predict_batch_start",
            "on_predict_batch_end",
            "on_train_start",
            "on_train_end",
            "on_validation_start",
            "on_validation_end",
            "on_test_start",
            "on_test_end",
            "on_predict_start",
            "on_predict_end",
            "on_exception",
            "on_save_checkpoint",
            "on_load_checkpoint",
            "on_before_backward",
            "on_after_backward",
            "on_before_optimizer_step",
            "on_before_zero_grad",
        }

        if isinstance(log_on_hooks, str):
            log_on_hooks = [log_on_hooks]

        for hook_name in log_on_hooks:
            if hook_name not in valid_hooks:
                # Raised explicitly instead of via ``assert`` so the validation is not
                # stripped under ``python -O``; the exception type is kept as
                # AssertionError so existing callers see the same error.
                raise AssertionError(
                    (
                        "Hook {} supplied to log_on_hooks is not valid or "
                        "can not be used. Valid hooks are {}"
                    ).format(hook_name, valid_hooks)
                )
            # Shadow the Callback hook on this instance with the logging routine.
            setattr(self, hook_name, self._apply_user_funcs)

    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:
        """
        Iterate over model parameters, find gradient tensor, apply and collect outputs of
        param_fn and grad_fn, and log outputs in a table.

        Extra positional/keyword arguments are accepted and ignored so that this single
        method can stand in for hooks with different signatures.
        """

        def find_grad_tensor(param: torch.Tensor) -> Optional[torch.Tensor]:
            """If using MCore optimizer, search the grad buckets for param's grad tensor."""
            if not isinstance(getattr(pl_module, 'optim', None), MegatronOptimizerModule):
                return param.grad

            # NOTE(review): assumes `pl_module.buffers` is an iterable of grad buffers that
            # expose `param_to_bucket` (not the stock `nn.Module.buffers` method) -- TODO confirm.
            for buf in pl_module.buffers:
                if param in buf.param_to_bucket:
                    return buf.param_to_bucket[param].grad_data

            return None

        names_col: List[str] = []
        params_output: List[Optional[Dict[str, str]]] = []
        grads_output: List[Optional[Dict[str, str]]] = []

        for param_name, param_tensor in pl_module.named_parameters():
            grad_tensor = find_grad_tensor(param_tensor)
            short_name = param_name.replace("module.", "").replace(".weight", "")
            names_col.append(short_name)

            for tensor, fn, out_col in (
                (param_tensor, self.param_fn, params_output),
                (grad_tensor, self.grad_fn, grads_output),
            ):
                if fn is None:
                    continue
                # A missing tensor (e.g. no gradient yet) contributes an empty row.
                out_col.append(fn(tensor) if tensor is not None else {})

        # Get table column headers. Dicts are used as ordered sets so the columns
        # appear in first-seen order instead of arbitrary set order.
        param_keys: Dict[str, None] = {}
        grad_keys: Dict[str, None] = {}
        for keys, output_list in ((param_keys, params_output), (grad_keys, grads_output)):
            for output in output_list:
                if output is not None:
                    keys.update(dict.fromkeys(output.keys()))

        # Create table only if there is something to print. This checks whether any key
        # was collected at all, not the truthiness of the individual keys.
        if not (param_keys or grad_keys):
            return

        from prettytable import PrettyTable

        debug_table = PrettyTable()
        debug_table.add_column("Parameter", names_col)

        for prefix, keys, output_list in (
            ("Param ", param_keys, params_output),
            ("Grad ", grad_keys, grads_output),
        ):
            for k in keys:
                col_to_log = [output.get(k, None) if output is not None else None for output in output_list]
                if col_to_log:
                    debug_table.add_column(prefix + k, col_to_log)

        debug_table.align = "l"
        logging.info("\n" + debug_table.get_string())


class ModelTrainingStateCallback(Callback):
    """
    Callback to detect model training state corruption after validation loop.

    This callback monitors whether all model components maintain their training state
    consistently before and after validation. Designed to catch issues where some
    modules are left in eval() mode after PTL validation loop.

    Args:
        val_check_interval: Interval at which validation occurs. Default: 10.
            Must be non-zero, since it is used as a modulus on every training batch.
        strict: If True, raises an exception when corruption is detected. Default: False.
    """

    def __init__(self, val_check_interval: int = 10, strict: bool = False):
        self.val_check_interval = val_check_interval
        self.strict = strict
        # Snapshot taken on the batch before an expected validation run, if any.
        self._pre_validation_state = None
        # True while a snapshot is waiting to be compared on the next training batch.
        self._expecting_check = False

    def _get_training_state(self, trainer: "pl.Trainer") -> Dict[str, int]:
        """Get the number of modules in training mode, and the number of model chunks."""
        from nemo.lightning.megatron_parallel import MegatronParallel
        from nemo.utils.model_utils import unwrap_model

        # Extract model chunks
        model = trainer.model
        if isinstance(model, MegatronParallel):
            chunks = model.pipeline
            model_chunks = chunks if isinstance(chunks, list) else [chunks]
        else:
            model_chunks = model if isinstance(model, list) else [model]

        # Count modules currently in training mode across all chunks
        training_count = sum(sum(1 for m in unwrap_model(chunk).modules() if m.training) for chunk in model_chunks)

        return {'training_modules': training_count, 'num_chunks': len(model_chunks)}

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Monitor training state before/after validation.

        Raises:
            RuntimeError: If ``strict`` is set and the number of training-mode modules
                differs from the snapshot taken before validation.
        """
        step = trainer.global_step

        # Check state after validation
        if self._expecting_check and self._pre_validation_state:
            pre_count = self._pre_validation_state['training_modules']
            post_count = self._get_training_state(trainer)['training_modules']

            if pre_count != post_count:
                msg = (
                    f"Model training state corruption detected! "
                    f"Before validation: {pre_count} training modules, "
                    f"after validation: {post_count} training modules (step {step})"
                )
                if self.strict:
                    raise RuntimeError(msg)
                logging.warning(msg)

            self._pre_validation_state = None
            self._expecting_check = False

        # Capture state before next validation
        if (step + 1) % self.val_check_interval == 0:
            self._pre_validation_state = self._get_training_state(trainer)
            self._expecting_check = True