# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=missing-class-docstring # pylint: disable=missing-function-docstring from contextlib import contextmanager from typing import TYPE_CHECKING, Generator, Literal, Optional, TypeVar import torch from lightning.fabric.plugins.precision import MixedPrecision from torch import nn from torch.optim import Optimizer from nemo.lightning.fabric.conversion import to_fabric from nemo.lightning.pytorch.plugins.mixed_precision import ( DtypeConfig, MegatronMixedPrecision, get_optim_config, update_config_with_dtype_overrides, ) from nemo.utils import logging if TYPE_CHECKING: from megatron.core.model_parallel_config import ModelParallelConfig AnyT = TypeVar("AnyT") ConfigT = TypeVar("ConfigT", bound="ModelParallelConfig") class FabricMegatronMixedPrecision(MixedPrecision): """Fabric plugin for mixed precision training with Megatron models. Handles precision conversions and mixed precision training settings in the Fabric training framework. """ def __init__( self, precision: Literal["16-mixed", "bf16-mixed", "32"], params_dtype: torch.dtype = None, pipeline_dtype: torch.dtype = None, autocast_dtype: torch.dtype = None, autocast_enabled: bool = False, grad_reduce_in_fp32: bool = True, # fp8 related, fp8: str = None, fp8_recipe: Optional[str] = None, first_last_layers_bf16: bool = False, num_layers_at_start_in_bf16: int = 0, num_layers_at_end_in_bf16: int = 0, reuse_grad_buf_for_mxfp8_param_ag: bool = False, fp8_margin: int = 0, fp8_amax_history_len: int = 1, fp8_amax_compute_algo: str = "most_recent", fp8_wgrad: bool = True, fp8_dot_product_attention: bool = False, fp8_multi_head_attention: bool = False, fp8_params: bool = None, fp8_param_gather: bool = None, fp16_loss_scale: float = None, fp16_initial_loss_scale: float = 4294967296, fp16_min_loss_scale: float = 1.0, fp16_loss_scale_window: int = 1000, fp16_hysteresis: int = 2, ) -> None: if fp8_params is not None: logging.warning( "fp8_params is deprecated and will be removed in a future release, use fp8_param_gather instead" ) if fp8_param_gather is not None and fp8_param_gather != fp8_params: raise ValueError( "Getting conflicting values for fp8_params and fp8_param_gather. Please only set fp8_param_gather." ) fp8_param_gather = fp8_params elif fp8_param_gather is None: fp8_param_gather = False if isinstance(precision, int): precision = str(precision) dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32 self.dtype_config = DtypeConfig( fp32=precision in ['fp32', '32'], fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'], bf16=precision in ['bf16', 'bf16-mixed'], params_dtype=params_dtype or torch.float32, pipeline_dtype=pipeline_dtype or dtype, autocast_dtype=autocast_dtype or dtype, autocast_enabled=autocast_enabled, grad_reduce_in_fp32=grad_reduce_in_fp32, fp8=fp8, fp8_recipe=fp8_recipe, first_last_layers_bf16=first_last_layers_bf16, num_layers_at_start_in_bf16=num_layers_at_start_in_bf16, num_layers_at_end_in_bf16=num_layers_at_end_in_bf16, reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag, fp8_margin=fp8_margin, fp8_amax_history_len=fp8_amax_history_len, fp8_amax_compute_algo=fp8_amax_compute_algo, fp8_wgrad=fp8_wgrad, fp8_dot_product_attention=fp8_dot_product_attention, fp8_multi_head_attention=fp8_multi_head_attention, fp8_param=fp8_param_gather, fp8_param_gather=fp8_param_gather, # fp16 loss scale loss_scale=fp16_loss_scale, initial_loss_scale=fp16_initial_loss_scale, min_loss_scale=fp16_min_loss_scale, loss_scale_window=fp16_loss_scale_window, hysteresis=fp16_hysteresis, ) if self.dtype_config.fp16: self.precision = "16-mixed" elif self.dtype_config.bf16: self.precision = "bf16-mixed" else: self.precision = "32-true" self.scaler = None def convert_input(self, data: AnyT) -> AnyT: """Convert model inputs (forward) to the floating point precision type of this plugin. Note: MegatronStrategy will take care of only doing this when: mpu.is_pipeline_first_stage() """ return data def convert_output(self, data: AnyT) -> AnyT: """Convert outputs to the floating point precision type expected after model's forward. Note: MegatronStrategy will take care of only doing this when: mpu.is_pipeline_first_stage() """ return data def convert_config(self, config: ConfigT) -> ConfigT: """Convert the config to the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. """ return update_config_with_dtype_overrides(self.dtype_config, config) def convert_module(self, module: nn.Module) -> nn.Module: """Convert the module parameters to the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. """ if not hasattr(module, "module"): return module from megatron.core.transformer.module import Float16Module from megatron.core.utils import get_model_config if self.dtype_config.fp16 or self.dtype_config.bf16: # Patch config options config = get_model_config(module.module) config.fp16 = self.dtype_config.fp16 config.bf16 = self.dtype_config.bf16 # Avoid rewrapping the module if it's already of type Float16Module if hasattr(module, "module"): if not isinstance(module.module, Float16Module): module.module = Float16Module(config, module.module) elif not isinstance(module, Float16Module): module = Float16Module(config, module) return module def convert_optimizer(self, optimizer: Optimizer) -> Optimizer: """Convert the optimizer parameters to the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. """ for optim_config in get_optim_config(optimizer): assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch" assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch" return optimizer @contextmanager def forward_context(self) -> Generator[None, None, None]: """No explicit precision casting. Inputs are supposed to be manually casted.""" try: yield finally: pass @to_fabric.register(MegatronMixedPrecision) def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision: return FabricMegatronMixedPrecision( precision=plugin.precision, )