# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Protocol, Sequence, Type, TypeVar, Union, runtime_checkable

import fiddle as fdl
import lightning.fabric as lb
import lightning.pytorch as pl
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io.mixin import IOMixin, serialization, track_io

if TYPE_CHECKING:
    from megatron.core.optimizer import OptimizerConfig

ModelT = TypeVar("ModelT", bound=nn.Module)


class Fabric(lb.Fabric, IOMixin):
    """A ``lightning.fabric.Fabric`` subclass that adds NeMo IO serialization and
    Megatron-aware helpers for loading, importing, and setting up models."""

    def io_init(self, **kwargs) -> fdl.Config[Self]:
        """Capture the constructor arguments as a ``fdl.Config`` so this Fabric
        instance can be re-serialized by the IO mixin."""
        # Each argument of the trainer can be stateful, so we deep-copy them.
        cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()}

        # Register an IO traverser for any argument type that doesn't have one yet.
        for val in cfg_kwargs.values():
            if not serialization.find_node_traverser(type(val)):
                track_io(type(val))

        return fdl.Config(type(self), **cfg_kwargs)

    def load_model(
        self,
        path: Union[str, Path],
        model: Optional[ModelT] = None,
    ) -> "DistributedModel[ModelT]":
        """Load and set up a model for distributed training.

        This method loads a model from the given path, sets it up for distributed
        training using the current Fabric instance, and returns a DistributedModel.

        Args:
            path (Union[str, Path]): The path to the saved model checkpoint.
            model (Optional[ModelT], optional): An optional pre-instantiated model. If not
                provided, the model will be loaded from the checkpoint. Defaults to None.

        Returns:
            DistributedModel[ModelT]: The loaded and distributed model.

        Example:
            >>> from nemo import lightning as nl
            >>>
            >>> trainer = nl.Trainer(
            ...     devices=2,
            ...     strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),
            ...     plugins=nl.MegatronMixedPrecision(precision='16-mixed')
            ... )
            >>> fabric = trainer.to_fabric()
            >>> distributed_model = fabric.load_model("path/to/checkpoint/dir")
            >>>
            >>> # You can now interact with the parallel model
        """
        self.launch()

        from nemo.lightning.io import load_context

        path = Path(path)
        if model is None:
            # Restore the model definition from the checkpoint's context subdirectory.
            context = load_context(ckpt_to_context_subdir(path))
            model = context.model

        dist_model = self.setup_module(model)
        self.load(path, {"state_dict": dist_model})

        return dist_model
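    # Illustrative sketch (hypothetical names): ``load_model`` also accepts a
    # pre-instantiated model, in which case the checkpoint context is not
    # re-loaded and only the weights are restored into the distributed module:
    #
    #     model = MyModel(my_config)  # hypothetical module, for illustration only
    #     dist_model = fabric.load_model("path/to/checkpoint/dir", model)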
    def import_model(
        self,
        path: Union[str, Path],
        model_type: Type[ModelT],
    ) -> "DistributedModel[ModelT]":
        """Import a model from a given path and set it up for distributed training.

        This method imports a model of the specified type from the given path, loads it,
        and sets it up for distributed training using the current Fabric instance.

        Args:
            path (Union[str, Path]): The path to the model. Can be a local path or a
                Hugging Face model identifier.
            model_type (Type[ModelT]): The type of the model to import. Must be a
                subclass of ConnectorMixin.

        Returns:
            DistributedModel[ModelT]: The imported and distributed model.

        Raises:
            TypeError: If the provided model_type is not a subclass of ConnectorMixin.

        Example:
            >>> from nemo import lightning as nl
            >>> from nemo.collections.llm import MistralModel
            >>>
            >>> trainer = nl.Trainer(
            ...     devices=2,
            ...     strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),
            ...     plugins=nl.MegatronMixedPrecision(precision='16-mixed')
            ... )
            >>> fabric = trainer.to_fabric()
            >>> model = fabric.import_model("hf://mistralai/Mistral-7B-v0.1", MistralModel)
            >>>
            >>> # You can now interact with the parallel model
        """
        from nemo.lightning.io import ConnectorMixin

        if not issubclass(model_type, ConnectorMixin):
            raise TypeError("The provided model class must be a subclass of ConnectorMixin")

        # Run the importer, then load the resulting checkpoint with the imported
        # model instance.
        model: ModelT = model_type.import_from(path)

        return self.load_model(model.ckpt_path, model)

    @override
    def setup_module(self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True):
        """Set up a module for distributed execution, unwrapping the Fabric
        wrapper when the Megatron strategy is used."""
        from nemo.lightning.fabric.strategies import FabricMegatronStrategy

        out = super().setup_module(module, move_to_device=move_to_device, _reapply_compile=_reapply_compile)

        # We don't want to return a _FabricModule for Megatron, since precision
        # conversion should happen only at the beginning and end of the pipeline.
        if isinstance(self.strategy, FabricMegatronStrategy):
            return out._forward_module

        return out

    def setup_datamodule(self, datamodule: pl.LightningDataModule, stage: str = "") -> pl.LightningDataModule:
        """Call ``datamodule.setup(stage)`` and let the strategy post-process the
        datamodule if it implements ``process_datamodule``."""
        datamodule.setup(stage)

        if hasattr(self.strategy, "process_datamodule"):
            datamodule = self.strategy.process_datamodule(datamodule)

        return datamodule


@runtime_checkable
class DistributedModel(Protocol[ModelT]):
    """Protocol for distributed-model wrappers that expose the wrapped ``module``."""

    module: ModelT
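
# A minimal, illustrative check of the protocol above (runs only when this
# module is executed directly). Because ``DistributedModel`` is declared
# ``runtime_checkable``, ``isinstance`` verifies the presence of the ``module``
# attribute, so any wrapper exposing ``module`` satisfies the protocol.
if __name__ == "__main__":

    class _Wrapped(nn.Module):  # hypothetical wrapper, for illustration only
        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

    assert isinstance(_Wrapped(nn.Linear(2, 2)), DistributedModel)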