# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Generator, Literal, TypeVar, Union

import torch
from lightning.pytorch.plugins.precision import Precision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.utils import logging

AnyT = TypeVar("AnyT")


def get_optim_config(optimizer: Optimizer):
    """Extract optimizer configurations from a Megatron optimizer.

    Args:
        optimizer: An optimizer wrapper exposing an ``mcore_optimizer`` attribute
            (a Megatron-core optimizer or ChainedOptimizer)

    Yields:
        Optimizer configurations
    """
    extract_config = lambda x: x.config
    try:
        from megatron.core.optimizer import ChainedOptimizer

        if isinstance(optimizer.mcore_optimizer, ChainedOptimizer):
            opts = optimizer.mcore_optimizer.chained_optimizers
        else:
            opts = [optimizer.mcore_optimizer]
        yield from map(extract_config, opts)
    except Exception as exc:
        raise ValueError("Failed to extract optimizer config from module.") from exc


@dataclass
class DtypeConfig:
    """Configuration class for mixed precision training settings.

    Contains settings for FP32/FP16/BF16 training and FP8/FP4 training.
    """

    fp32: bool = False
    fp16: bool = False
    bf16: bool = False
    params_dtype: torch.dtype = None
    pipeline_dtype: torch.dtype = None
    autocast_dtype: torch.dtype = None
    autocast_enabled: bool = False
    grad_reduce_in_fp32: bool = True
    # fp8 related
    fp8: str = None
    fp8_recipe: str = "delayed"
    # fp4 related
    fp4: str = None
    fp4_recipe: str = "nvfp4"
    first_last_layers_bf16: bool = False
    fp8_margin: int = 0
    fp8_amax_history_len: int = 1
    fp8_amax_compute_algo: str = "most_recent"
    fp8_wgrad: bool = True
    fp8_dot_product_attention: bool = False
    fp8_multi_head_attention: bool = False
    fp8_param: bool = True
    fp8_param_gather: bool = True
    # FP16 loss scaling
    loss_scale: float = None
    initial_loss_scale: float = None
    min_loss_scale: float = None
    loss_scale_window: float = None
    hysteresis: float = None
    num_layers_at_start_in_bf16: int = 0
    num_layers_at_end_in_bf16: int = 0
    reuse_grad_buf_for_mxfp8_param_ag: bool = False


class MegatronMixedPrecision(Precision):
    """Plugin for mixed precision training with Megatron models.

    Handles conversion of model parameters and inputs/outputs between different
    precisions, and manages mixed precision training settings.
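
    Example:
        A minimal, illustrative sketch; the surrounding ``Trainer``/strategy setup
        (and any FP8-capable hardware requirements) is assumed and omitted here:

            >>> precision_plugin = MegatronMixedPrecision(precision="bf16-mixed")
            >>> # trainer = lightning.pytorch.Trainer(plugins=precision_plugin, ...)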
""" def __init__( self, precision: Literal["16-mixed", "bf16-mixed", "32"], params_dtype: torch.dtype = None, pipeline_dtype: torch.dtype = None, autocast_dtype: torch.dtype = None, autocast_enabled: bool = False, grad_reduce_in_fp32: bool = True, # fp8 related, fp8: str = None, fp8_recipe: str = "delayed", # "tensorwise", "delayed", "mxfp8" (for Blackwell only) first_last_layers_bf16: bool = False, fp8_margin: int = 0, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = "most_recent", fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, fp8_params: bool = None, fp8_param_gather: bool = None, # fp4 related fp4: str = None, fp4_recipe: str = "nvfp4", fp16_loss_scale: float = None, fp16_initial_loss_scale: float = 4294967296, fp16_min_loss_scale: float = 1.0, fp16_loss_scale_window: int = 1000, fp16_hysteresis: int = 2, num_layers_at_start_in_bf16: int = 0, num_layers_at_end_in_bf16: int = 0, reuse_grad_buf_for_mxfp8_param_ag: bool = False, ) -> None: if fp8_params is not None: logging.warning( "fp8_params is deprecated and will be removed in a future release, use fp8_param_gather instead" ) if fp8_param_gather is not None and fp8_param_gather != fp8_params: raise ValueError( "Getting conflicting values for fp8_params and fp8_param_gather. Please only set fp8_param_gather." ) fp8_param_gather = fp8_params elif fp8_param_gather is None: fp8_param_gather = False if isinstance(precision, int): precision = str(precision) dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 self.dtype_config = DtypeConfig( fp32=precision in ['fp32', '32'], fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'], bf16=precision in ['bf16', 'bf16-mixed'], params_dtype=params_dtype or torch.float32, pipeline_dtype=pipeline_dtype or dtype, autocast_dtype=autocast_dtype or dtype, autocast_enabled=autocast_enabled, grad_reduce_in_fp32=grad_reduce_in_fp32, fp8=fp8, fp8_recipe=fp8_recipe, first_last_layers_bf16=first_last_layers_bf16, fp8_margin=fp8_margin, fp8_amax_history_len=fp8_amax_history_len, fp8_amax_compute_algo=fp8_amax_compute_algo, fp8_wgrad=fp8_wgrad, fp8_dot_product_attention=fp8_dot_product_attention, fp8_multi_head_attention=fp8_multi_head_attention, fp8_param=fp8_param_gather, fp8_param_gather=fp8_param_gather, fp4=fp4, fp4_recipe=fp4_recipe, num_layers_at_start_in_bf16=num_layers_at_start_in_bf16, num_layers_at_end_in_bf16=num_layers_at_end_in_bf16, reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag, # fp16 loss scale loss_scale=fp16_loss_scale, initial_loss_scale=fp16_initial_loss_scale, min_loss_scale=fp16_min_loss_scale, loss_scale_window=fp16_loss_scale_window, hysteresis=fp16_hysteresis, ) super().__init__() if self.dtype_config.fp16: self.precision = "16-mixed" elif self.dtype_config.bf16: self.precision = "bf16-mixed" else: self.precision = "32-true" def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. 
""" from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config if self.dtype_config.fp16 or self.dtype_config.bf16: # Patch config options config = get_model_config(module.module) config.fp16 = self.dtype_config.fp16 config.bf16 = self.dtype_config.bf16 # Avoid rewrapping the module if it's already of type Float16Module if hasattr(module, "module"): if not isinstance(module.module, Float16Module): module.module = Float16Module(config, module.module) elif not isinstance(module, Float16Module): module = Float16Module(config, module) return module def convert_optimizer(self, optimizer: Optimizer) -> Optimizer: """Convert the optimizer parameters to the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. """ for optim_config in get_optim_config(optimizer): assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch" assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch" return optimizer def convert_input(self, data: AnyT) -> AnyT: """Convert model inputs (forward) to the floating point precision type of this plugin. Note: MegatronStrategy will take care of only doing this when: parallel_state.is_pipeline_first_stage() """ return data def convert_output(self, data: AnyT) -> AnyT: """Convert outputs to the floating point precision type expected after model's forward. Note: MegatronStrategy will take care of only doing this when: parallel_state.is_pipeline_last_stage() """ return data @contextmanager def forward_context(self) -> Generator[None, None, None]: """No explicit precision casting. Inputs are supposed to be manually casted.""" try: yield finally: pass def clip_gradients( self, optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm=None, ) -> None: """Clip gradients. Raises error if clip_val > 0, otherwise it is a no-op. Args: optimizer: The optimizer to clip gradients for clip_val: The value to clip gradients to gradient_clip_algorithm: The algorithm to use for clipping Raises: ValueError: If clip_val > 0 since gradient clipping is handled by Mcore's optimizer """ if clip_val > 0.0: raise ValueError( "Gradient clipping is handled in Mcore's optimizer. Use the clip_grad attribute in OptimizerConfig." ) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value - it is a no-op. Args: optimizer: The optimizer to clip gradients for clip_val: The value to clip gradients to """ return def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by norm - it is a no-op. Args: optimizer: The optimizer to clip gradients for clip_val: The value to clip gradients to """ return def update_config_with_dtype_overrides(dtype_config, config): """Update a config object with dtype settings from dtype_config. Args: dtype_config: Source of dtype settings config: Config object to update Returns: Updated config object """ if hasattr(config, "__io__"): config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__) for field in fields(dtype_config): if not hasattr(config, field.name): continue # If we overwrote a value, log a debug message. 
    """
    if hasattr(config, "__io__"):
        config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__)
    for field in fields(dtype_config):
        if not hasattr(config, field.name):
            continue
        # If we overwrote a value, log a debug message.
        old_val = getattr(config, field.name)
        new_val = getattr(dtype_config, field.name)
        if old_val != new_val:
            setattr(config, field.name, new_val)
            logging.debug(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}")
    return config


__all__ = ["MegatronMixedPrecision"]