# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import itertools
import os
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Mapping, Optional, Protocol, TypeVar

import torch
from torch import nn

from nemo.lightning.megatron_init import initialize_model_parallel_for_nemo
from nemo.utils import logging

NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"

if TYPE_CHECKING:
    from lightning.fabric.utilities.types import Optimizable
    from megatron.core.model_parallel_config import ModelParallelConfig


class SharedStateDictProtocol(Protocol):
    """Protocol for objects that expose a ``sharded_state_dict`` method for distributed checkpointing."""

    def sharded_state_dict(self, prefix="", metadata: Optional[dict] = None):
        """Returns a sharded state dict, optionally namespaced by ``prefix``."""
        ...


def init_parallel_ranks(
    world_size: int,
    global_rank: int,
    local_rank: int,
    parallel_config: "ModelParallelConfig",
    seed=1234,
    fp8=False,
) -> None:
    """
    Initializes the parallel ranks for distributed training.

    This function sets up the parallel ranks based on the provided world size, global rank, local rank,
    and parallel configuration. It also sets the seed for random number generation and determines whether
    to use fp8 precision.

    Args:
        world_size (int): The total number of processes participating in the distributed training.
        global_rank (int): The rank of the current process in the distributed training setup.
        local_rank (int): The rank of the current process within its machine.
        parallel_config (ModelParallelConfig): The configuration object containing settings for model parallelism.
        seed (int, optional): The seed for random number generation. Defaults to 1234.
        fp8 (bool, optional): Whether to use fp8 precision for model parameters. Defaults to False.
    """
    from nemo.utils import AppState

    app_state = AppState()

    if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true":
        init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size
        init_global_rank = app_state.global_rank
        init_local_rank = app_state.local_rank
    else:
        init_world_size = world_size
        pp = parallel_config.pipeline_model_parallel_size or 1
        if world_size < pp:
            raise ValueError(f"Expected world_size ({world_size}) to be greater than/equal to pipeline size ({pp})")
        init_global_rank = global_rank
        init_local_rank = local_rank

    initialize_model_parallel_for_nemo(
        world_size=init_world_size,
        global_rank=init_global_rank,
        local_rank=init_local_rank,
        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
        expert_model_parallel_size=parallel_config.expert_model_parallel_size,
        expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size,
        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
        pipeline_model_parallel_comm_backend=parallel_config.pipeline_model_parallel_comm_backend,
        virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
        context_parallel_size=parallel_config.context_parallel_size,
        seed=seed,
        use_fp8=fp8,
        init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False)
        and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi',
        use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False),
        use_sharp=getattr(parallel_config, "use_sharp", False),
        use_tp_pp_dp_mapping=getattr(parallel_config, "use_tp_pp_dp_mapping", False),
        num_distributed_optimizer_instances=getattr(parallel_config, "num_distributed_optimizer_instances", 1),
        nccl_communicator_config_path=getattr(parallel_config, "nccl_communicator_config_path", None),
        use_gloo_process_groups=getattr(parallel_config, "use_gloo_process_groups", True),
        # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
    )
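

# Illustrative sketch (not part of the original module): a minimal call into
# ``init_parallel_ranks`` as a training strategy might issue it during environment setup,
# before any process groups exist. The environment-variable names used for rank discovery
# below are assumptions for the example, not something this module defines.
def _example_init_parallel_ranks() -> None:
    from megatron.core.model_parallel_config import ModelParallelConfig

    parallel_config = ModelParallelConfig(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
    )
    init_parallel_ranks(
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        global_rank=int(os.environ.get("RANK", "0")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        parallel_config=parallel_config,
        seed=1234,
        fp8=False,
    )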


def init_model_parallel(model: Optional[nn.Module] = None) -> None:
    """Initializes Megatron-LM model parallel if using model parallelism."""
    import torch.distributed
    from megatron.core import parallel_state

    from nemo.utils import AppState

    app_state = AppState()

    # we initialize megatron-lm model parallel and data parallel groups
    # after initializing DDP with PTL.
    if app_state.model_parallel_size is not None:
        # destroy groups in case they have already been created
        # this happens with multiple calls to trainer.test for example
        parallel_state.destroy_model_parallel()
        if torch.distributed.is_initialized():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size=app_state.tensor_model_parallel_size,
                pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
                virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
                pipeline_model_parallel_comm_backend=app_state.pipeline_model_parallel_comm_backend,
                context_parallel_size=app_state.context_parallel_size,
                expert_model_parallel_size=app_state.expert_model_parallel_size,
                expert_tensor_parallel_size=app_state.expert_tensor_parallel_size,
                use_sharp=app_state.use_sharp,
                order="tp-cp-ep-pp-dp" if app_state.use_tp_pp_dp_mapping else "tp-cp-ep-dp-pp",
                num_distributed_optimizer_instances=app_state.num_distributed_optimizer_instances,
                nccl_communicator_config_path=app_state.nccl_communicator_config_path,
                create_gloo_process_groups=app_state.use_gloo_process_groups,
            )

            # assert that fake tp and pp rank match after model parallel init
            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()
            assert app_state.expert_tensor_parallel_rank == parallel_state.get_expert_tensor_parallel_rank()

            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()

            # create MPI process group for UCX-based communication APIs
            if app_state.init_mpi_proc_group:
                torch.distributed.new_group(backend="mpi")
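

# Illustrative sketch (not part of the original module): the expected ordering of the two
# initialization helpers above. ``init_parallel_ranks`` runs before the default torch
# distributed process group exists; ``init_model_parallel`` runs after it has been created
# (e.g. by Lightning's DDP setup). The direct ``init_process_group`` call below is an
# assumption for the example; a launcher or strategy would normally do this.
def _example_parallel_setup_order(parallel_config: "ModelParallelConfig") -> None:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    global_rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    # 1. Record the parallel layout in AppState before process groups exist.
    init_parallel_ranks(world_size, global_rank, local_rank, parallel_config)

    # 2. Create the default process group (normally done by the launcher/strategy).
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=global_rank)

    # 3. Build the Megatron model-parallel groups on top of it.
    init_model_parallel()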


def set_model_parallel_attributes(model, parallelism):
    """Copies matching parallelism settings onto the model's Megatron config, returning the updated config or ``None``."""
    # Right now mcore sub-classes ModelParallelConfig, we should remove that
    # Given Lightning's structure it would be better if parallelism is a different object
    # Since then it can be passed to the Strategy
    # Note: Importing nemo.lightning.pytorch.strategies creates an import cycle.
    from megatron.core.transformer.transformer_config import TransformerConfig

    has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig)
    if has_mcore_config and hasattr(model, "configure_model"):
        config: TransformerConfig = model.config
        for attr_name in filter(lambda x: not x.startswith('__'), dir(parallelism)):
            if not hasattr(config, attr_name):
                continue
            setattr(config, attr_name, getattr(parallelism, attr_name))
            if hasattr(config, "__io__"):
                setattr(config.__io__, attr_name, getattr(parallelism, attr_name))
        if hasattr(config, '__post_init__'):
            # MCore does not use args in __post_init__
            # @akoumparouli: is there a better way (e.g. reinit config)?
            config.__post_init__()
        return config

    return None
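

# Illustrative sketch (not part of the original module): applying a parallelism spec to a
# model before it is materialized. The ``_ParallelismSpec`` container below is hypothetical;
# any object whose attribute names match fields on the model's TransformerConfig
# (e.g. ``tensor_model_parallel_size``) works the same way.
def _example_set_model_parallel_attributes(model: nn.Module) -> None:
    from dataclasses import dataclass

    @dataclass
    class _ParallelismSpec:  # hypothetical stand-in for the strategy's parallelism settings
        tensor_model_parallel_size: int = 2
        pipeline_model_parallel_size: int = 1
        context_parallel_size: int = 1

    updated_config = set_model_parallel_attributes(model, _ParallelismSpec())
    if updated_config is None:
        logging.info("Model has no Megatron TransformerConfig; nothing was updated.")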


@contextmanager
def megatron_lazy_init_context(config) -> Generator[None, None, None]:
    """Context manager that disables weight initialization so modules can be built lazily (meta device / CPU)."""
    try:
        from megatron.core.extensions import transformer_engine as _te

        original = _te._get_extra_te_kwargs  # noqa: SLF001

        def _get_extra_te_kwargs_meta(c):
            """Forces device to meta"""
            kwargs = original(c)
            kwargs['device'] = 'meta'
            return kwargs

        _te._get_extra_te_kwargs = _get_extra_te_kwargs_meta  # noqa: SLF001
    except ImportError:
        pass

    _orig_perform_initialization = config.perform_initialization
    _orig_use_cpu_initialization = config.use_cpu_initialization

    config.perform_initialization = False
    config.use_cpu_initialization = True

    yield

    try:
        from megatron.core.extensions import transformer_engine as _te

        _te._get_extra_te_kwargs = original  # noqa: SLF001
    except ImportError:
        pass

    config.perform_initialization = _orig_perform_initialization
    config.use_cpu_initialization = _orig_use_cpu_initialization


@contextmanager
def megatron_cpu_init_context(config) -> Generator[None, None, None]:
    """Context manager that forces CPU initialization of model parameters, restoring the original flag on exit."""
    _orig_use_cpu_initialization = config.use_cpu_initialization

    config.use_cpu_initialization = True

    yield

    config.use_cpu_initialization = _orig_use_cpu_initialization
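

# Illustrative sketch (not part of the original module): building a model lazily under
# ``megatron_lazy_init_context`` so parameters are created without real initialization, then
# loading weights afterwards. ``model.configure_model()`` is assumed to construct the
# underlying Megatron module from ``model.config``, mirroring how NeMo 2.0 model wrappers behave.
def _example_lazy_build(model) -> None:
    with megatron_lazy_init_context(model.config):
        model.configure_model()  # parameters are allocated but not randomly initialized
    # ... weights would then be restored from a (distributed) checkpoint instead of random init.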


ModelT = TypeVar("ModelT", bound=nn.Module)


class GradScaler(torch.cuda.amp.GradScaler):
    """
    Gradient scaler for model-parallel inf check. The inf in gradients are checked across tensor-parallel
    ranks in (1) executing optimizer step and (2) gradient scaler update.
    """

    def __init__(
        self,
        init_scale=2.0**16,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
        enabled=True,
        hysteresis=1,
    ):
        super().__init__(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        self.optimizer_update_skipped: Optional[bool] = None
        self.hysteresis = hysteresis
        self._hysteresis_tracker = self.hysteresis

    def _unscale_grads_(self, optimizer, *args):
        if getattr(optimizer, "_custom_amp_unscale_grads", False):
            return optimizer.unscale_grads(*args)
        else:
            return super()._unscale_grads_(optimizer, *args)

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        from megatron.core import parallel_state

        retval = None
        found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
            found_inf,
            op=torch.distributed.ReduceOp.MAX,
            group=parallel_state.get_model_parallel_group(),
        )

        if found_inf.item() == 0:
            retval = optimizer.step(*args, **kwargs)
            self.optimizer_update_skipped = False
        else:
            self.optimizer_update_skipped = True
        return retval

    def update(self, new_scale=None):
        """
        Updates to native grad scaler update function.

        1. Check inf across model-parallel ranks.
        2. Update hysteresis tracker.
        3. Apply hysteresis to grad scale update.
        """
        from megatron.core import parallel_state

        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]

            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined,
                op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_model_parallel_group(),
            )

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf,
                        op=torch.distributed.ReduceOp.MAX,
                        group=parallel_state.get_model_parallel_group(),
                    )
                    found_inf_combined += found_inf

            if found_inf_combined > 0:
                self._hysteresis_tracker -= 1
                if self._hysteresis_tracker <= 0:
                    # When hysteresis becomes zero, follow the native grad scale update rule.
                    # Increase scale and reset growth tracker
                    torch._amp_update_scale_(  # noqa: SLF001
                        _scale,
                        _growth_tracker,
                        found_inf_combined,
                        self._growth_factor,
                        self._backoff_factor,
                        self._growth_interval,
                    )
                else:
                    # Only reset the growth tracker when hysteresis is larger than zero
                    _growth_tracker.fill_(0.0)
            else:
                # When no inf found, follow the native grad scale update rule.
                # Increment growth_tracker, update scale when growth tracker reaches the interval, and
                # reset the hysteresis tracker.
                torch._amp_update_scale_(  # noqa: SLF001
                    _scale,
                    _growth_tracker,
                    found_inf_combined,
                    self._growth_factor,
                    self._backoff_factor,
                    self._growth_interval,
                )
                self._hysteresis_tracker = self.hysteresis

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(
            torch.cuda.amp.grad_scaler._refresh_per_optimizer_state  # noqa: SLF001
        )

    def state_dict(self):
        """
        Add hysteresis_tracker to the native functions' state_dict.
        """
        return (
            {
                "scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker(),
                "_hysteresis_tracker": self._hysteresis_tracker,
            }
            if self._enabled
            else {}
        )

    def load_state_dict(self, state_dict):
        """
        Load hysteresis_tracker in addition to the state dict of the native function.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The source state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])
        # Note: the key name must match what ``state_dict`` saves ("_hysteresis_tracker").
        if "_hysteresis_tracker" in state_dict:
            self._hysteresis_tracker = state_dict["_hysteresis_tracker"]
        else:
            self._hysteresis_tracker = 1
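

# Illustrative sketch (not part of the original module): using the model-parallel-aware
# GradScaler in a training step. With ``hysteresis=2`` the loss scale is only backed off
# after two consecutive iterations with inf/NaN gradients. The surrounding training-step
# code is an assumption; ``scaler`` would be created once, e.g.
# ``GradScaler(init_scale=2.0**16, hysteresis=2)``.
def _example_grad_scaler_step(scaler: GradScaler, optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> None:
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skips the step if an inf/NaN was found on any model-parallel rank
    scaler.update()  # applies the hysteresis rule before adjusting the scale
    if scaler.optimizer_update_skipped:
        logging.warning("Optimizer step skipped due to inf/NaN gradients.")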


def enable_nvidia_optimizations() -> None:
    """These optimizations are present in NVIDIA NGC PyTorch Containers."""
    # NVIDIA container version check
    nvidia_torch_version = os.getenv("NVIDIA_PYTORCH_VERSION", None)
    if nvidia_torch_version is not None:
        try:
            NVIDIA_TORCH_MAJOR = int(nvidia_torch_version.split(".")[0])
        except Exception:
            NVIDIA_TORCH_MAJOR = 0
        try:
            NVIDIA_TORCH_MINOR = int(nvidia_torch_version.split(".")[1])
        except Exception:
            NVIDIA_TORCH_MINOR = 0

        # NVFUSER available starting with 21.11
        if NVIDIA_TORCH_MAJOR > 21 or (NVIDIA_TORCH_MAJOR == 21 and NVIDIA_TORCH_MINOR >= 11):
            # NVFUSER
            torch._C._jit_set_profiling_executor(True)  # noqa: SLF001
            torch._C._jit_set_profiling_mode(True)  # noqa: SLF001
            torch._C._jit_override_can_fuse_on_cpu(False)  # noqa: SLF001
            torch._C._jit_override_can_fuse_on_gpu(False)  # noqa: SLF001
            torch._C._jit_set_texpr_fuser_enabled(False)  # noqa: SLF001
            # torch._C._jit_set_nvfuser_enabled(True)
            torch._C._debug_set_autodiff_subgraph_inlining(False)  # noqa: SLF001
    else:
        # Not a Nvidia container. NVFUSER Dependency check is on users
        pass


def optimizer_sharded_state_dict(
    model: SharedStateDictProtocol,
    optimizer: "Optimizable",
    is_loading: bool = False,
    sharding_type: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> Dict[str, torch.Tensor]:
    """
    Sharded state dictionary for a MainParamsOptimizerWrapper.
    Used to save and load the optimizer state when training with distributed_checkpoint.

    Args:
        model (SharedStateDictProtocol): model with a `sharded_state_dict` method
        optimizer (Optimizable): optimizer to get the state dict of
        is_loading (bool, optional): set to True if the sharded state dict is intended
            for checkpoint loading (as opposed to saving). Defaults to False.
        sharding_type (str, optional): deprecated, use metadata flags instead.
        metadata (dict, optional): sharded state dict metadata passed from the framework.
            Used to control the details of sharded state dict creation, in particular
            the state dict format of the DistributedOptimizer with the flag
            `distrib_optim_sharding_type`. Defaults to None (empty metadata).

    Returns:
        dict: The sharded state dictionary for the optimizer

    Raises:
        ValueError: If a parameter ID does not match any model sharded parameter.
    """
    from megatron.core.dist_checkpointing.optimizer import (
        get_param_id_to_sharded_param_map,
        make_sharded_optimizer_tensor,
        optim_state_to_sharding_state,
    )

    from nemo.core.optim import MainParamsOptimizerWrapper
    from nemo.core.optim.optimizers import init_optimizer_states

    model_sharded_state_dict = model.sharded_state_dict(metadata=metadata)

    # remove _extra_state
    model_sharded_state_dict = {
        key: value for key, value in model_sharded_state_dict.items() if not key.endswith("_extra_state")
    }

    if sharding_type is not None:
        logging.warning("sharding_type is deprecated, please use `metadata['distrib_optim_sharding_type']` instead")
        if metadata is None:
            metadata = {}
        if 'distrib_optim_sharding_type' not in metadata:
            metadata["distrib_optim_sharding_type"] = sharding_type

    if hasattr(optimizer, "sharded_state_dict"):
        return optimizer.sharded_state_dict(
            model_sharded_state_dict,
            is_loading=is_loading,
            metadata=metadata,
        )

    if not isinstance(optimizer, MainParamsOptimizerWrapper):
        # Regular optimizer, e.g. Adam or FusedAdam
        init_optimizer_states(optimizer)

        optimizer_state_dict = optimizer.state_dict()

        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict=model_sharded_state_dict,
            optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
        )

        optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)

        return optimizer_state_dict

    optimizer_state_dict: Dict[str, Any] = optimizer.state_dict()

    id_to_sharded_param_map = get_param_id_to_sharded_param_map(
        model_sharded_state_dict=model_sharded_state_dict,
        optim_params_iter=itertools.chain.from_iterable(g for g in optimizer.float16_groups),
    )

    # Convert fp32_from_fp16_params
    assert len(optimizer_state_dict["fp32_from_fp16_params"]) == len(optimizer_state_dict["optimizer"]["param_groups"])

    def get_safe(param_id):
        try:
            return id_to_sharded_param_map[param_id]
        except KeyError as e:
            raise ValueError(f"Param id {param_id} does not match any model sharded param") from e

    optimizer_state_dict["fp32_from_fp16_params"] = [
        [
            make_sharded_optimizer_tensor(get_safe(param_id), fp32_param, prefix="optimizer.state.fp32_param")
            for param_id, fp32_param in zip(state_group["params"], fp32_group)
        ]
        for fp32_group, state_group in zip(
            optimizer_state_dict["fp32_from_fp16_params"],
            optimizer_state_dict["optimizer"]["param_groups"],
        )
    ]

    # Convert state
    optim_state_to_sharding_state(optimizer_state_dict["optimizer"], id_to_sharded_param_map)

    return optimizer_state_dict
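

# Illustrative sketch (not part of the original module): how a checkpoint-IO layer might
# collect the optimizer's sharded state dict alongside the model's and write both with
# megatron-core distributed checkpointing. The checkpoint directory is an assumption.
def _example_save_optimizer_sharded_state(model, optimizer) -> None:
    from megatron.core import dist_checkpointing

    sharded_state = {
        "model": model.sharded_state_dict(),
        "optimizer": optimizer_sharded_state_dict(model, optimizer, is_loading=False),
    }
    dist_checkpointing.save(sharded_state, "/tmp/example_checkpoint")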


def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None:
    """Loads a checkpoint's model weights into each module of a MegatronParallel wrapper, adjusting ``module.`` prefixes as needed."""
    from megatron.core import parallel_state
    from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag

    # convert from StrictHandling to bool for PTL
    if strict is not None and not isinstance(strict, bool):
        strict = parse_strict_flag(strict)
        strict_options = [
            StrictHandling.ASSUME_OK_UNEXPECTED,
            StrictHandling.RAISE_UNEXPECTED,
            StrictHandling.RAISE_ALL,
        ]
        strict = strict in strict_options

    try:
        from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel

        have_custom_fsdp = True
    except (ImportError, ModuleNotFoundError):
        have_custom_fsdp = False

    try:
        from megatron.core.distributed import FullyShardedDataParallel

        have_megatron_fsdp = True
    except (ImportError, ModuleNotFoundError):
        have_megatron_fsdp = False

    for index, module in enumerate(megatron_parallel):
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            if "state_dict" in checkpoint:
                checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"]
            else:
                checkpoint_state_dict = checkpoint[f"model_{index}"]
        else:
            if "state_dict" in checkpoint:
                checkpoint_state_dict = checkpoint["state_dict"]
            else:
                checkpoint_state_dict = checkpoint

        n_nesting = 0
        mcore_model = megatron_parallel.module
        while hasattr(mcore_model, "module"):
            mcore_model = mcore_model.module
            n_nesting += 1

        _state_dict = {}
        for key, value in checkpoint_state_dict.items():
            # Count the number of "module." at the start of the key
            count, _key = 0, key
            while _key.startswith("module."):
                _key = _key[len("module.") :]
                count += 1

            # Adjust the number of "module." prefixes
            if count < n_nesting:
                to_add = "module." * (n_nesting - count)
                _state_dict[f"{to_add}{key}"] = value
            elif count > n_nesting:
                to_remove = "module." * (count - n_nesting)
                _state_dict[key[len(to_remove) :]] = value
            else:
                _state_dict[key] = value

        if have_custom_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
            module.module.load_state_dict(_state_dict, strict=strict)
            continue
        elif have_megatron_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
            module.module.load_state_dict(_state_dict, strict=strict)
            continue

        try:
            module.load_state_dict(_state_dict, strict=strict)
        except RuntimeError as e:
            missing_keys, unexpected_keys = module.load_state_dict(checkpoint_state_dict, strict=False)
            if all(s.endswith('_extra_state') for s in missing_keys):
                logging.warning(
                    f'Loading checkpoint created with Transformer Engine version lower than 1.13. '
                    f'Missing layers {missing_keys} will be ignored.'
                )
            else:
                raise e
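

# Illustrative sketch (not part of the original module): restoring weights by reading a
# distributed checkpoint into the model's sharded state dict and handing the result to
# ``load_model_state_dict``. The checkpoint directory and the ``megatron_parallel`` wrapper
# (an iterable of model chunks exposing ``sharded_state_dict``) are assumptions for the example.
def _example_restore_model(megatron_parallel) -> None:
    from megatron.core import dist_checkpointing

    sharded_sd = megatron_parallel.sharded_state_dict()
    checkpoint = dist_checkpointing.load(sharded_sd, "/tmp/example_checkpoint")
    load_model_state_dict(megatron_parallel, checkpoint, strict=True)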


def _sync_from_last_pipeline_stage(value: torch.Tensor, broadcast: bool = False):
    """
    When pipeline parallelism is enabled, sends a tensor defined on the last pipeline stage to other ranks.

    Args:
        value (torch.Tensor): A tensor to be sent from the final pipeline stage of
            a pipeline parallelism group (e.g. loss).
            Note that this tensor should already be defined on the target rank(s) to fill with received data.
        broadcast (bool): When True, broadcasts value from the final pipeline stage rank to all ranks in its group.
            When False, only rank zero receives value from the final pipeline stage rank in its group.
            This mode exists to avoid slow one-to-many communication when not necessary. Defaults to False.
    """
    from megatron.core import parallel_state

    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        src_rank = parallel_state.get_pipeline_model_parallel_last_rank()

        if not broadcast:
            pp_ranks = torch.distributed.get_process_group_ranks(parallel_state.get_pipeline_model_parallel_group())
            if torch.distributed.get_rank() == src_rank and 0 in pp_ranks:
                torch.distributed.send(value, 0)
            elif torch.distributed.get_rank() == 0:
                torch.distributed.recv(value, src_rank)
        else:
            torch.distributed.broadcast(
                value,
                src_rank,
                group=parallel_state.get_pipeline_model_parallel_group(),
            )
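

# Illustrative sketch (not part of the original module): making a loss computed on the last
# pipeline stage visible on global rank 0 for logging. Ranks that do not hold the loss pass a
# preallocated placeholder tensor that is filled by the receive.
def _example_sync_loss_for_logging(loss_on_last_stage: Optional[torch.Tensor]) -> torch.Tensor:
    value = loss_on_last_stage if loss_on_last_stage is not None else torch.zeros(1, device="cuda")
    _sync_from_last_pipeline_stage(value, broadcast=False)  # only rank 0 receives
    return value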


def setup_megatron_optimizer(
    model,
    config,
    no_weight_decay_cond: Optional[Callable] = None,
    scale_lr_cond: Optional[Callable] = None,
    lr_mult: float = 1.0,
):
    """Builds a Megatron-core optimizer for ``model`` and wraps it for use with NeMo/Lightning."""
    from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

    from nemo.core.optim import McoreDistributedOptimizer
    from nemo.utils import AppState

    app_state = AppState()

    assert isinstance(config, OptimizerConfig), f"Expected OptimizerConfig, got {type(config)}"

    class McoreOpt(McoreDistributedOptimizer):
        """Thin wrapper that adapts ``sharded_state_dict`` to the signature of the underlying mcore optimizer."""

        def sharded_state_dict(
            self,
            model_sharded_state_dict,
            optimizer_state_dict=None,
            is_loading=False,
            sharding_type='fully_sharded_model_space',
            metadata=None,
        ):
            mcore_optimizer_sig = inspect.signature(self.mcore_optimizer.sharded_state_dict).parameters
            distrib_optim_kwargs = {}
            if "metadata" in mcore_optimizer_sig or "kwargs" in mcore_optimizer_sig:
                distrib_optim_kwargs["metadata"] = metadata
            elif "sharding_type" in mcore_optimizer_sig:
                distrib_optim_kwargs["sharding_type"] = sharding_type
            state_dict = self.mcore_optimizer.sharded_state_dict(
                model_sharded_state_dict, is_loading=is_loading, **distrib_optim_kwargs
            )
            return state_dict

    # megatron optimizer expects McoreDDP
    ddp_modules = [m.module for m in model]
    mcore_opt = get_megatron_optimizer(
        config,
        ddp_modules,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
        use_gloo_process_groups=app_state.use_gloo_process_groups,
    )

    # Pytorch does not have the concept of an `lr_mult` or a `wd_mult` but these are added to param
    # groups in megatron to control which sub-modules have different learning rates or weight
    # decays. Apply the multipliers here to each param_group's lr and wd, and to reduce confusion
    # change the name of these variables. We need this because nemo does not use the custom
    # megatron scheduler, and the megatron scheduler is what makes use of these mult parameters:
    # https://github.com/NVIDIA/Megatron-LM/blob/044e2ad5/megatron/core/optimizer_param_scheduler.py#L192-L193
    for pg in mcore_opt.param_groups:
        if 'pre_lr_mult' in pg or 'pre_mult_wd' in pg:
            # User has already applied custom lr and wd multipliers, don't apply `lr_mult` and
            # `wd_mult` again. This case may be encountered when resuming training.
            continue
        pg['pre_mult_lr'] = pg["lr"]
        pg['pre_mult_wd'] = pg['weight_decay']
        new_lr = pg["lr"] * pg.get('lr_mult', 1.0)
        new_wd = pg["weight_decay"] * pg.get("wd_mult", 1.0)
        pg['lr'] = new_lr
        pg['weight_decay'] = new_wd
        # In case a future implementation makes use of `lr_mult` and `wd_mult` directly in the
        # scheduler, but accidentally also uses this function, remove `lr_mult` and `wd_mult` from
        # the param groups so that the default value of 1.0 gets applied.
        if 'lr_mult' in pg:
            pg['pre_lr_mult'] = pg['lr_mult']
            del pg['lr_mult']  # remove so downstream methods do not apply again.
        if 'wd_mult' in pg:
            pg['pre_wd_mult'] = pg['wd_mult']
            del pg['wd_mult']  # remove so downstream methods do not apply again

    if getattr(model.ddp_config, "overlap_param_gather", False) and getattr(
        model.ddp_config, "align_param_gather", False
    ):
        param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
        param_sync_func = param_sync_func[0] if len(model) == 1 else param_sync_func
        for module in model:
            module.config.param_sync_func = param_sync_func

    return McoreOpt(mcore_opt)
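

# Illustrative sketch (not part of the original module): wiring the helper into a training
# setup. The OptimizerConfig fields shown are standard megatron-core options; the
# ``megatron_parallel`` wrapper (an iterable of DDP-wrapped model chunks with a ``ddp_config``
# attribute) is an assumption for the example.
def _example_setup_megatron_optimizer(megatron_parallel):
    from megatron.core.optimizer import OptimizerConfig

    opt_config = OptimizerConfig(
        optimizer="adam",
        lr=1e-4,
        weight_decay=0.1,
        use_distributed_optimizer=True,
        bf16=True,
    )
    return setup_megatron_optimizer(megatron_parallel, opt_config, lr_mult=1.0)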