# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import Callback

from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils import logging


def collect_precision(tensor: torch.Tensor) -> Dict[str, str]:
    """Returns the tensor's precision."""
    if isinstance(tensor, torch.Tensor):
        return {"Precision": str(tensor.dtype)}
    else:
        return {"Precision": "not-a-tensor"}


def collect_precision_and_shape(tensor: torch.Tensor) -> Dict[str, str]:
    """Returns the tensor's shape and precision."""
    if isinstance(tensor, torch.Tensor):
        return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
    else:
        return {"Shape": "not-a-tensor", "Precision": "not-a-tensor"}


class ParameterDebugger(Callback):
    """
    Debugging tool to help inspect parameters and gradients at any callback event.

    This callback handles the boilerplate needed to iterate over the model parameters
    and gradients, and applies user-specified functions to them. These functions can
    be used to log attributes or run asserts on the param and grad tensors. Attributes
    are logged in a table, with a row for each parameter name. The default behavior is
    to log the precision and shape of each parameter and its gradient.

    Args:
        param_fn: Function to apply to model parameters. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each parameter.
        grad_fn: Function to apply to model gradients. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each gradient.
        log_on_hooks: PTL callback hook name or list of hook names on which to apply param_fn and grad_fn.
            See the `PTL docs <https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#hooks>`_
            for more info on callback hooks. Note that hooks that fire before the model is
            constructed are not valid here.

    Example:
        >>> fn = lambda x: {"Norm": str(x.norm(2).item())}
        >>> callback = ParameterDebugger(param_fn=fn, log_on_hooks=["on_train_start", "on_train_end"])
        >>> trainer = Trainer(callbacks=[callback])
    """

    def __init__(
        self,
        param_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision_and_shape,
        grad_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision,
        log_on_hooks: Union[List[str], str] = "on_train_start",
    ):
        self.param_fn = param_fn
        self.grad_fn = grad_fn
        valid_hooks = {
            "teardown",
            "on_fit_end",
            "on_sanity_check_start",
            "on_sanity_check_end",
            "on_train_batch_start",
            "on_train_batch_end",
            "on_train_epoch_start",
            "on_train_epoch_end",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_test_epoch_start",
            "on_test_epoch_end",
            "on_predict_epoch_start",
            "on_predict_epoch_end",
            "on_validation_batch_start",
            "on_validation_batch_end",
            "on_test_batch_start",
            "on_test_batch_end",
            "on_predict_batch_start",
            "on_predict_batch_end",
            "on_train_start",
            "on_train_end",
            "on_validation_start",
            "on_validation_end",
            "on_test_start",
            "on_test_end",
            "on_predict_start",
            "on_predict_end",
            "on_exception",
            "on_save_checkpoint",
            "on_load_checkpoint",
            "on_before_backward",
            "on_after_backward",
            "on_before_optimizer_step",
            "on_before_zero_grad",
        }
        if isinstance(log_on_hooks, str):
            log_on_hooks = [log_on_hooks]
        for hook_name in log_on_hooks:
            assert hook_name in valid_hooks, (
                f"Hook {hook_name} supplied to log_on_hooks is not valid or "
                f"cannot be used. Valid hooks are {valid_hooks}"
            )
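            # Dynamically bind the shared handler as a method with the hook's
            # name, so PTL invokes it at that point in the loop.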
setattr(self, hook_name, self._apply_user_funcs)

    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:
        """
        Iterate over the model parameters, find each one's gradient tensor, apply and
        collect the outputs of param_fn and grad_fn, and log those outputs in a table.
        """

        def find_grad_tensor(param: torch.Tensor) -> Optional[torch.Tensor]:
            """If using the MCore optimizer, search the grad buckets for the param's grad tensor."""
            if not isinstance(getattr(pl_module, 'optim', None), MegatronOptimizerModule):
                return param.grad
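            # With the Megatron optimizer, gradients live in bucketed grad
            # buffers rather than in param.grad, so search those instead.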
for buf in pl_module.buffers:
if param in buf.param_to_bucket:
return buf.param_to_bucket[param].grad_data
return None
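
        # Collect one table row per parameter: its shortened name plus the
        # outputs of param_fn and grad_fn.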
names_col, params_output, grads_output = [], [], []
for param_name, param_tensor in pl_module.named_parameters():
grad_tensor = find_grad_tensor(param_tensor)
short_name = param_name.replace("module.", "").replace(".weight", "")
names_col.append(short_name)
for tensor, fn, out_col in zip(
[param_tensor, grad_tensor], [self.param_fn, self.grad_fn], [params_output, grads_output]
):
if fn is not None:
if tensor is not None:
out_col.append(fn(tensor))
else:
out_col.append({})

        # get table column headers
        param_keys, grad_keys = set(), set()
for output in params_output:
if output is not None:
param_keys.update(output.keys())
for output in grads_output:
if output is not None:
grad_keys.update(output.keys())

        # create table only if there is something to print
        if any(param_keys) or any(grad_keys):
            from prettytable import PrettyTable

            debug_table = PrettyTable()
debug_table.add_column("Parameter", names_col)
for prefix, keys, output_list in zip(
["Param ", "Grad "], [param_keys, grad_keys], [params_output, grads_output]
):
for k in keys:
col_to_log = []
for output in output_list:
if output is not None:
col_to_log.append(output.get(k, None))
else:
col_to_log.append(None)
                    if col_to_log:
debug_table.add_column(prefix + k, col_to_log)
debug_table.align = "l"
logging.info("\n" + debug_table.get_string())
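

# Example usage, as a sketch (`MyModule` and `my_data` are placeholders, not
# part of this module):
#
#     callback = ParameterDebugger(
#         param_fn=collect_precision_and_shape,
#         grad_fn=collect_precision,
#         log_on_hooks=["on_train_start", "on_train_end"],
#     )
#     trainer = pl.Trainer(callbacks=[callback])
#     trainer.fit(MyModule(), my_data)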


class ModelTrainingStateCallback(Callback):
    """
    Callback to detect model training-state corruption after the validation loop.

    This callback monitors whether all model components keep their training state
    consistent across validation. It is designed to catch issues where some modules
    are left in eval() mode after the PTL validation loop.

    Args:
        val_check_interval: Interval at which validation occurs. Default: 10.
        strict: If True, raises an exception when corruption is detected. Default: False.
    """

    def __init__(self, val_check_interval: int = 10, strict: bool = False):
        self.val_check_interval = val_check_interval
        self.strict = strict
        self._pre_validation_state = None
        self._expecting_check = False

    def _get_training_state(self, trainer: "pl.Trainer") -> Dict[str, int]:
        """Get training/eval module counts from model chunks."""
        from nemo.lightning.megatron_parallel import MegatronParallel
        from nemo.utils.model_utils import unwrap_model

        # Extract model chunks
model = trainer.model
if isinstance(model, MegatronParallel):
chunks = model.pipeline
model_chunks = chunks if isinstance(chunks, list) else [chunks]
else:
model_chunks = model if isinstance(model, list) else [model]

        # Count training/eval modules
        training_count = sum(sum(1 for m in unwrap_model(chunk).modules() if m.training) for chunk in model_chunks)

        return {'training_modules': training_count, 'num_chunks': len(model_chunks)}

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Monitor training state before/after validation."""
        step = trainer.global_step

        # Check state after validation
if self._expecting_check and self._pre_validation_state:
pre_count = self._pre_validation_state['training_modules']
post_count = self._get_training_state(trainer)['training_modules']
if pre_count != post_count:
msg = (
f"Model training state corruption detected! "
f"Before validation: {pre_count} training modules, "
f"after validation: {post_count} training modules (step {step})"
)
if self.strict:
raise RuntimeError(msg)
else:
logging.warning(msg)
            self._pre_validation_state = None
            self._expecting_check = False

        # Capture state before next validation
        if (step + 1) % self.val_check_interval == 0:
self._pre_validation_state = self._get_training_state(trainer)
self._expecting_check = True
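

# Example usage, as a sketch: keep the callback's interval in sync with the
# Trainer's `val_check_interval` so state is captured on the batch right
# before validation runs.
#
#     state_check = ModelTrainingStateCallback(val_check_interval=10, strict=False)
#     trainer = pl.Trainer(val_check_interval=10, callbacks=[state_check])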