# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Generator, Literal, TypeVar, Union
import torch
from lightning.pytorch.plugins.precision import Precision
from torch.nn import Module
from torch.optim import Optimizer
from nemo.utils import logging
AnyT = TypeVar("AnyT")
def get_optim_config(optimizer: Optimizer):
"""Extract optimizer configurations from a Megatron optimizer.
Args:
optimizer: A torch.optim.Optimizer instance
Yields:
Optimizer configurations
"""
extract_config = lambda x: x.config
try:
from megatron.core.optimizer import ChainedOptimizer
if isinstance(optimizer.mcore_optimizer, ChainedOptimizer):
opts = optimizer.mcore_optimizer.chained_optimizers
else:
opts = [optimizer.mcore_optimizer]
yield from map(extract_config, opts)
    except Exception as exc:
        raise ValueError("Failed to extract optimizer config from module.") from exc
@dataclass
class DtypeConfig:
"""Configuration class for mixed precision training settings.
    Contains settings for FP32/FP16/BF16 training as well as FP8 and FP4 recipes.
"""
fp32: bool = False
fp16: bool = False
bf16: bool = False
params_dtype: torch.dtype = None
pipeline_dtype: torch.dtype = None
autocast_dtype: torch.dtype = None
autocast_enabled: bool = False
grad_reduce_in_fp32: bool = True
# fp8 related
fp8: str = None
fp8_recipe: str = "delayed"
# fp4 related
fp4: str = None
fp4_recipe: str = "nvfp4"
first_last_layers_bf16: bool = False
fp8_margin: int = 0
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
fp8_wgrad: bool = True
fp8_dot_product_attention: bool = False
fp8_multi_head_attention: bool = False
fp8_param: bool = True
fp8_param_gather: bool = True
# FP16 Loss scaling
    loss_scale: float = None
    initial_loss_scale: float = None
    min_loss_scale: float = None
    loss_scale_window: float = None
    hysteresis: float = None
num_layers_at_start_in_bf16: int = 0
num_layers_at_end_in_bf16: int = 0
reuse_grad_buf_for_mxfp8_param_ag: bool = False
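
# Example of a hand-built ``DtypeConfig`` (illustrative only; in normal use these
# values are derived by MegatronMixedPrecision from the ``precision`` string and
# related arguments rather than constructed directly):
#
#     bf16_cfg = DtypeConfig(
#         bf16=True,
#         params_dtype=torch.float32,
#         pipeline_dtype=torch.bfloat16,
#         autocast_dtype=torch.bfloat16,
#     )
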
class MegatronMixedPrecision(Precision):
"""Plugin for mixed precision training with Megatron models.
Handles conversion of model parameters and inputs/outputs between different precisions,
and manages mixed precision training settings.
"""
def __init__(
self,
precision: Literal["16-mixed", "bf16-mixed", "32"],
params_dtype: torch.dtype = None,
pipeline_dtype: torch.dtype = None,
autocast_dtype: torch.dtype = None,
autocast_enabled: bool = False,
grad_reduce_in_fp32: bool = True,
# fp8 related,
fp8: str = None,
fp8_recipe: str = "delayed", # "tensorwise", "delayed", "mxfp8" (for Blackwell only)
first_last_layers_bf16: bool = False,
fp8_margin: int = 0,
fp8_amax_history_len: int = 1,
fp8_amax_compute_algo: str = "most_recent",
fp8_wgrad: bool = True,
fp8_dot_product_attention: bool = False,
fp8_multi_head_attention: bool = False,
fp8_params: bool = None,
fp8_param_gather: bool = None,
# fp4 related
fp4: str = None,
fp4_recipe: str = "nvfp4",
fp16_loss_scale: float = None,
fp16_initial_loss_scale: float = 4294967296,
fp16_min_loss_scale: float = 1.0,
fp16_loss_scale_window: int = 1000,
fp16_hysteresis: int = 2,
num_layers_at_start_in_bf16: int = 0,
num_layers_at_end_in_bf16: int = 0,
reuse_grad_buf_for_mxfp8_param_ag: bool = False,
) -> None:
if fp8_params is not None:
logging.warning(
"fp8_params is deprecated and will be removed in a future release, use fp8_param_gather instead"
)
if fp8_param_gather is not None and fp8_param_gather != fp8_params:
raise ValueError(
"Getting conflicting values for fp8_params and fp8_param_gather. Please only set fp8_param_gather."
)
fp8_param_gather = fp8_params
elif fp8_param_gather is None:
fp8_param_gather = False
if isinstance(precision, int):
precision = str(precision)
dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
self.dtype_config = DtypeConfig(
fp32=precision in ['fp32', '32'],
fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
bf16=precision in ['bf16', 'bf16-mixed'],
params_dtype=params_dtype or torch.float32,
pipeline_dtype=pipeline_dtype or dtype,
autocast_dtype=autocast_dtype or dtype,
autocast_enabled=autocast_enabled,
grad_reduce_in_fp32=grad_reduce_in_fp32,
fp8=fp8,
fp8_recipe=fp8_recipe,
first_last_layers_bf16=first_last_layers_bf16,
fp8_margin=fp8_margin,
fp8_amax_history_len=fp8_amax_history_len,
fp8_amax_compute_algo=fp8_amax_compute_algo,
fp8_wgrad=fp8_wgrad,
fp8_dot_product_attention=fp8_dot_product_attention,
fp8_multi_head_attention=fp8_multi_head_attention,
fp8_param=fp8_param_gather,
fp8_param_gather=fp8_param_gather,
fp4=fp4,
fp4_recipe=fp4_recipe,
num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
# fp16 loss scale
loss_scale=fp16_loss_scale,
initial_loss_scale=fp16_initial_loss_scale,
min_loss_scale=fp16_min_loss_scale,
loss_scale_window=fp16_loss_scale_window,
hysteresis=fp16_hysteresis,
)
super().__init__()
if self.dtype_config.fp16:
self.precision = "16-mixed"
elif self.dtype_config.bf16:
self.precision = "bf16-mixed"
else:
self.precision = "32-true"
def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
"""
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config
if self.dtype_config.fp16 or self.dtype_config.bf16:
# Patch config options
config = get_model_config(module.module)
config.fp16 = self.dtype_config.fp16
config.bf16 = self.dtype_config.bf16
# Avoid rewrapping the module if it's already of type Float16Module
if hasattr(module, "module"):
if not isinstance(module.module, Float16Module):
module.module = Float16Module(config, module.module)
elif not isinstance(module, Float16Module):
module = Float16Module(config, module)
return module
def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Convert the optimizer parameters to the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
"""
for optim_config in get_optim_config(optimizer):
assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch"
return optimizer
def convert_input(self, data: AnyT) -> AnyT:
"""Convert model inputs (forward) to the floating point precision type of this plugin.
Note: MegatronStrategy will take care of only doing this when:
parallel_state.is_pipeline_first_stage()
"""
return data
def convert_output(self, data: AnyT) -> AnyT:
"""Convert outputs to the floating point precision type expected after model's forward.
Note: MegatronStrategy will take care of only doing this when:
parallel_state.is_pipeline_last_stage()
"""
return data
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""No explicit precision casting. Inputs are supposed to be manually casted."""
try:
yield
finally:
pass
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm=None,
) -> None:
"""Clip gradients. Raises error if clip_val > 0, otherwise it is a no-op.
Args:
optimizer: The optimizer to clip gradients for
clip_val: The value to clip gradients to
gradient_clip_algorithm: The algorithm to use for clipping
Raises:
ValueError: If clip_val > 0 since gradient clipping is handled by Mcore's optimizer
"""
if clip_val > 0.0:
raise ValueError(
"Gradient clipping is handled in Mcore's optimizer. Use the clip_grad attribute in OptimizerConfig."
)
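
    # Instead of clipping through this plugin, set ``clip_grad`` on Megatron's
    # OptimizerConfig (sketch; the exact place this config is passed depends on how
    # the optimizer module is constructed in your recipe):
    #
    #     from megatron.core.optimizer import OptimizerConfig
    #     opt_cfg = OptimizerConfig(optimizer="adam", lr=3e-4, bf16=True, clip_grad=1.0)
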
def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value - it is a no-op.
Args:
optimizer: The optimizer to clip gradients for
clip_val: The value to clip gradients to
"""
return
def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm - it is a no-op.
Args:
optimizer: The optimizer to clip gradients for
clip_val: The value to clip gradients to
"""
return
def update_config_with_dtype_overrides(dtype_config, config):
"""Update a config object with dtype settings from dtype_config.
Args:
dtype_config: Source of dtype settings
config: Config object to update
Returns:
Updated config object
"""
if hasattr(config, "__io__"):
config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__)
for field in fields(dtype_config):
if not hasattr(config, field.name):
continue
# If we overwrote a value, log a debug message.
old_val = getattr(config, field.name)
new_val = getattr(dtype_config, field.name)
if old_val != new_val:
setattr(config, field.name, new_val)
logging.debug(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}")
return config
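
# Example usage of ``update_config_with_dtype_overrides`` (sketch; ``transformer_cfg``
# stands for any config object carrying dtype fields such as ``bf16`` or
# ``params_dtype``, e.g. a Megatron TransformerConfig):
#
#     plugin = MegatronMixedPrecision(precision="bf16-mixed")
#     transformer_cfg = update_config_with_dtype_overrides(plugin.dtype_config, transformer_cfg)
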
__all__ = ["MegatronMixedPrecision"]