# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

from lightning.fabric.plugins.environments import slurm
from lightning.pytorch import plugins as _pl_plugins

# Import transformer_engine once, up front: doing it here improves launch
# speed when running in debug-mode. This call must stay ahead of the
# nemo.lightning imports below.
from nemo.utils.import_utils import safe_import

safe_import("transformer_engine")

from nemo.lightning.base import get_vocab_size, teardown
from nemo.lightning.fabric.fabric import Fabric
from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision
from nemo.lightning.fabric.strategies import FabricMegatronStrategy
from nemo.lightning.nemo_logger import NeMoLogger
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.optim import (
    LRSchedulerModule,
    MegatronOptimizerModule,
    OptimizerModule,
    PytorchOptimizerModule,
    lr_scheduler,
)
from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy, MegatronStrategy
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
from nemo.lightning.resume import AutoResume


# Lightning's check is monkey patched because NVIDIA uses its own naming
# convention for SLURM jobs.
def _is_slurm_interactive_mode():
    """Report whether the current SLURM job counts as interactive.

    The job name is read from ``slurm.SLURMEnvironment.job_name()``.

    Returns:
        ``True`` when no job name is available, or when the job name ends
        with ``"bash"`` or ``"interactive"``; ``False`` otherwise.
    """
    name = slurm.SLURMEnvironment.job_name()
    if name is None:
        # No job name at all is treated the same as an interactive session.
        return True
    # A tuple of suffixes matches if the name ends with any one of them.
    return name.endswith(("bash", "interactive"))


# Install the replacement over the private helper in Lightning's slurm module.
slurm._is_slurm_interactive_mode = _is_slurm_interactive_mode  # noqa: SLF001

# Widen Lightning's plugin-input alias so it also admits a DataSampler.
_pl_plugins._PLUGIN_INPUT = Union[_pl_plugins._PLUGIN_INPUT, _data_sampler.DataSampler]  # noqa: SLF001


__all__ = [
    "AutoResume",
    "Fabric",
    "FabricMegatronMixedPrecision",
    "FabricMegatronStrategy",
    "LRSchedulerModule",
    "MegatronStrategy",
    "MegatronDataSampler",
    "MegatronMixedPrecision",
    "MegatronOptimizerModule",
    "PytorchOptimizerModule",
    "FSDPStrategy",
    "FSDP2Strategy",
    "RestoreConfig",
    "lr_scheduler",
    "NeMoLogger",
    "ModelCheckpoint",
    "OptimizerModule",
    "Trainer",
    "configure_no_restart_validation_training_loop",
    "get_vocab_size",
    "teardown",
]