# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Generator, Literal, TypeVar, Union

import torch
from lightning.pytorch.plugins.precision import Precision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.utils import logging

AnyT = TypeVar("AnyT")
def get_optim_config(optimizer: Optimizer):
    """Extract optimizer configurations from a Megatron optimizer.

    Args:
        optimizer: A torch.optim.Optimizer instance

    Yields:
        Optimizer configurations
    """
    extract_config = lambda x: x.config
    try:
        from megatron.core.optimizer import ChainedOptimizer

        if isinstance(optimizer.mcore_optimizer, ChainedOptimizer):
            opts = optimizer.mcore_optimizer.chained_optimizers
        else:
            opts = [optimizer.mcore_optimizer]
        yield from map(extract_config, opts)
    except Exception as err:
        raise ValueError("Failed to extract optimizer config from module.") from err

@dataclass
class DtypeConfig:
    """Configuration class for mixed precision training settings.

    Contains settings for FP32/FP16/BF16 training as well as FP8/FP4 training.
    """

    fp32: bool = False
    fp16: bool = False
    bf16: bool = False
    params_dtype: torch.dtype = None
    pipeline_dtype: torch.dtype = None
    autocast_dtype: torch.dtype = None
    autocast_enabled: bool = False
    grad_reduce_in_fp32: bool = True
    # fp8 related
    fp8: str = None
    fp8_recipe: str = "delayed"
    # fp4 related
    fp4: str = None
    fp4_recipe: str = "nvfp4"
    first_last_layers_bf16: bool = False
    fp8_margin: int = 0
    fp8_amax_history_len: int = 1
    fp8_amax_compute_algo: str = "most_recent"
    fp8_wgrad: bool = True
    fp8_dot_product_attention: bool = False
    fp8_multi_head_attention: bool = False
    fp8_param: bool = True
    fp8_param_gather: bool = True
    # FP16 loss scaling
    loss_scale: float = None
    initial_loss_scale: float = None
    min_loss_scale: float = None
    loss_scale_window: float = None
    hysteresis: float = None
    num_layers_at_start_in_bf16: int = 0
    num_layers_at_end_in_bf16: int = 0
    reuse_grad_buf_for_mxfp8_param_ag: bool = False
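
# Illustrative sketch (not executed): a DtypeConfig for plain BF16 mixed precision,
# mirroring what MegatronMixedPrecision builds below for precision="bf16-mixed".
#
#     _example_bf16_config = DtypeConfig(
#         bf16=True,
#         params_dtype=torch.float32,
#         pipeline_dtype=torch.bfloat16,
#         autocast_dtype=torch.bfloat16,
#         grad_reduce_in_fp32=True,
#     )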

class MegatronMixedPrecision(Precision):
    """Plugin for mixed precision training with Megatron models.

    Handles conversion of model parameters and inputs/outputs between different precisions,
    and manages mixed precision training settings.
    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed", "32"],
        params_dtype: torch.dtype = None,
        pipeline_dtype: torch.dtype = None,
        autocast_dtype: torch.dtype = None,
        autocast_enabled: bool = False,
        grad_reduce_in_fp32: bool = True,
        # fp8 related
        fp8: str = None,
        fp8_recipe: str = "delayed",  # "tensorwise", "delayed", "mxfp8" (for Blackwell only)
        first_last_layers_bf16: bool = False,
        fp8_margin: int = 0,
        fp8_amax_history_len: int = 1,
        fp8_amax_compute_algo: str = "most_recent",
        fp8_wgrad: bool = True,
        fp8_dot_product_attention: bool = False,
        fp8_multi_head_attention: bool = False,
        fp8_params: bool = None,
        fp8_param_gather: bool = None,
        # fp4 related
        fp4: str = None,
        fp4_recipe: str = "nvfp4",
        fp16_loss_scale: float = None,
        fp16_initial_loss_scale: float = 4294967296,
        fp16_min_loss_scale: float = 1.0,
        fp16_loss_scale_window: int = 1000,
        fp16_hysteresis: int = 2,
        num_layers_at_start_in_bf16: int = 0,
        num_layers_at_end_in_bf16: int = 0,
        reuse_grad_buf_for_mxfp8_param_ag: bool = False,
    ) -> None:
        if fp8_params is not None:
            logging.warning(
                "fp8_params is deprecated and will be removed in a future release, use fp8_param_gather instead"
            )
            if fp8_param_gather is not None and fp8_param_gather != fp8_params:
                raise ValueError(
                    "Getting conflicting values for fp8_params and fp8_param_gather. Please only set fp8_param_gather."
                )
            fp8_param_gather = fp8_params
        elif fp8_param_gather is None:
            fp8_param_gather = False

        if isinstance(precision, int):
            precision = str(precision)

        dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
        self.dtype_config = DtypeConfig(
            fp32=precision in ['fp32', '32'],
            fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
            bf16=precision in ['bf16', 'bf16-mixed'],
            params_dtype=params_dtype or torch.float32,
            pipeline_dtype=pipeline_dtype or dtype,
            autocast_dtype=autocast_dtype or dtype,
            autocast_enabled=autocast_enabled,
            grad_reduce_in_fp32=grad_reduce_in_fp32,
            fp8=fp8,
            fp8_recipe=fp8_recipe,
            first_last_layers_bf16=first_last_layers_bf16,
            fp8_margin=fp8_margin,
            fp8_amax_history_len=fp8_amax_history_len,
            fp8_amax_compute_algo=fp8_amax_compute_algo,
            fp8_wgrad=fp8_wgrad,
            fp8_dot_product_attention=fp8_dot_product_attention,
            fp8_multi_head_attention=fp8_multi_head_attention,
            fp8_param=fp8_param_gather,
            fp8_param_gather=fp8_param_gather,
            fp4=fp4,
            fp4_recipe=fp4_recipe,
            num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
            num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
            reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
            # fp16 loss scaling
            loss_scale=fp16_loss_scale,
            initial_loss_scale=fp16_initial_loss_scale,
            min_loss_scale=fp16_min_loss_scale,
            loss_scale_window=fp16_loss_scale_window,
            hysteresis=fp16_hysteresis,
        )
        super().__init__()
        if self.dtype_config.fp16:
            self.precision = "16-mixed"
        elif self.dtype_config.bf16:
            self.precision = "bf16-mixed"
        else:
            self.precision = "32-true"

    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        from megatron.core.transformer.module import Float16Module
        from megatron.core.utils import get_model_config

        if self.dtype_config.fp16 or self.dtype_config.bf16:
            # Patch config options
            config = get_model_config(module.module)
            config.fp16 = self.dtype_config.fp16
            config.bf16 = self.dtype_config.bf16
            # Avoid rewrapping the module if it's already of type Float16Module
            if hasattr(module, "module"):
                if not isinstance(module.module, Float16Module):
                    module.module = Float16Module(config, module.module)
            elif not isinstance(module, Float16Module):
                module = Float16Module(config, module)

        return module

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert the optimizer parameters to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        for optim_config in get_optim_config(optimizer):
            assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
            assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch"
        return optimizer

    def convert_input(self, data: AnyT) -> AnyT:
        """Convert model inputs (forward) to the floating point precision type of this plugin.

        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_first_stage()
        """
        return data

    def convert_output(self, data: AnyT) -> AnyT:
        """Convert outputs to the floating point precision type expected after model's forward.

        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_last_stage()
        """
        return data

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """No explicit precision casting. Inputs are expected to be cast manually."""
        try:
            yield
        finally:
            pass

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm=None,
    ) -> None:
        """Clip gradients. Raises an error if clip_val > 0; otherwise it is a no-op.

        Args:
            optimizer: The optimizer to clip gradients for
            clip_val: The value to clip gradients to
            gradient_clip_algorithm: The algorithm to use for clipping

        Raises:
            ValueError: If clip_val > 0, since gradient clipping is handled by Mcore's optimizer
        """
        if clip_val > 0.0:
            raise ValueError(
                "Gradient clipping is handled in Mcore's optimizer. Use the clip_grad attribute in OptimizerConfig."
            )
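
    # Illustrative sketch (not executed): gradient clipping must be requested through
    # Megatron's OptimizerConfig rather than through this plugin, e.g.:
    #
    #     from megatron.core.optimizer import OptimizerConfig
    #     optim_cfg = OptimizerConfig(optimizer="adam", lr=3e-4, bf16=True, clip_grad=1.0)
    #
    # The exact set of OptimizerConfig fields depends on the installed megatron-core version.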

    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value - it is a no-op.

        Args:
            optimizer: The optimizer to clip gradients for
            clip_val: The value to clip gradients to
        """
        return

    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm - it is a no-op.

        Args:
            optimizer: The optimizer to clip gradients for
            clip_val: The value to clip gradients to
        """
        return
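
# Illustrative sketch (not executed): typical NeMo 2.0 usage passes this plugin to the
# Lightning trainer. The device count and strategy arguments below are placeholders,
# not a complete recipe.
#
#     from nemo import lightning as nl
#
#     trainer = nl.Trainer(
#         devices=8,
#         strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),
#         plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", grad_reduce_in_fp32=True),
#     )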

def update_config_with_dtype_overrides(dtype_config, config):
    """Update a config object with dtype settings from dtype_config.

    Args:
        dtype_config: Source of dtype settings
        config: Config object to update

    Returns:
        Updated config object
    """
    if hasattr(config, "__io__"):
        config.__io__ = update_config_with_dtype_overrides(dtype_config, config.__io__)
    for field in fields(dtype_config):
        if not hasattr(config, field.name):
            continue
        # If we overwrote a value, log a debug message.
        old_val = getattr(config, field.name)
        new_val = getattr(dtype_config, field.name)
        if old_val != new_val:
            setattr(config, field.name, new_val)
            logging.debug(f"Overwrote {type(config).__name__}.{field.name} {old_val} -> {new_val}")
    return config
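
# Illustrative sketch (not executed): update_config_with_dtype_overrides copies any
# matching dtype fields onto a target config. `model_cfg` stands in for an object,
# such as a TransformerConfig, that shares field names like `bf16` or `pipeline_dtype`;
# it is a hypothetical placeholder, not defined in this module.
#
#     dtype_cfg = MegatronMixedPrecision(precision="bf16-mixed").dtype_config
#     model_cfg = update_config_with_dtype_overrides(dtype_cfg, model_cfg)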

__all__ = ["MegatronMixedPrecision"]