# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Literal, Optional

import torch
from megatron.core import parallel_state

from nemo.utils.import_utils import safe_import

if torch.cuda.is_available():
    bitsandbytes, HAVE_BNB = safe_import("bitsandbytes")
else:
    bitsandbytes = None
    HAVE_BNB = False

import torch.nn.functional as F
from torch import nn

from nemo.utils.import_utils import safe_import_from

te, HAVE_TE = safe_import_from("transformer_engine", "pytorch")

from nemo.collections.llm.peft.module_matcher import ModuleMatcher
from nemo.collections.llm.peft.utils import get_adapter_attributes_from_linear, is_expert_linear
from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper
from nemo.utils import logging
from nemo.utils.te_utils import te_version


class LoRALinear(AdapterWrapper):
    """An adapter wrapper that adds the output of the adapter to the output of the wrapped module.

    This class is designed to be used with LoRA (Low-Rank Adaptation) and similar techniques
    where the adapter's output is added to the main module's output. It extends the AdapterWrapper
    class to provide a specific implementation of the forward method.
    """

    def forward(
        self,
        x: torch.Tensor,
        *args,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # pylint: disable=C0115,C0116
        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
        adapter_output = self.adapter(layernorm_output.contiguous())
        adapter_output = adapter_output.reshape(linear_output.shape)
        return linear_output + adapter_output, bias


# Fused LoRA requires Transformer Engine 2.7+
HAVE_TE_FUSED_LORA: bool = HAVE_TE and te_version() >= (2, 7)

if HAVE_TE_FUSED_LORA:

    class TEFusedLoRALinear(LoRALinear):
        """LoRA adapter wrapper using Transformer Engine operation fuser"""

        def __init__(self, to_wrap: nn.Module, adapter: nn.Module):
            super().__init__(to_wrap, adapter)
            self._fused_branches: Optional[tuple[te.ops.Sequential, te.ops.Sequential]] = None

        def _make_fused_branches(self) -> tuple[te.ops.Sequential, te.ops.Sequential]:
            """Construct fused modules for main and LoRA branches"""

            # Extract layer size and tensor parallel config
            kwargs = {
                "in_features": self.to_wrap.weight.size(1),
                "out_features": self.to_wrap.weight.size(0),
                "tensor_parallel_mode": None,
                "tensor_parallel_group": None,
                "sequence_parallel": False,
            }
            tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
            if tensor_parallel_size > 1:
                kwargs["tensor_parallel_group"] = parallel_state.get_tensor_model_parallel_group()
                if isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear)):
                    kwargs["tensor_parallel_mode"] = self.to_wrap.parallel_mode
                    kwargs["sequence_parallel"] = self.to_wrap.sequence_parallel
                if kwargs["tensor_parallel_mode"] == "row":
                    kwargs["in_features"] *= tensor_parallel_size
                elif kwargs["tensor_parallel_mode"] == "column":
                    kwargs["out_features"] *= tensor_parallel_size

            # wgrad accumulation fusion
            accumulate_into_main_grad = False
            if isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear)):
                accumulate_into_main_grad = self.to_wrap.fuse_wgrad_accumulation
            kwargs["accumulate_into_main_grad"] = accumulate_into_main_grad

            # Construct fused branches
            main_branch = self._make_main_branch(**kwargs)
            lora_branch = self._make_lora_branch(**kwargs)

            # Get submodule forward hooks
            forward_pre_hooks = []
            forward_post_hooks = []
            for submodule in self.modules():
                for hook in submodule._forward_pre_hooks.values():
                    forward_pre_hooks.append((submodule, hook))
                for hook in submodule._forward_hooks.values():
                    forward_post_hooks.append((submodule, hook))

            # Attempt to emulate submodule forward hooks if needed
            # Note: Assume hooks do not interact with submodule inputs
            # or outputs since they are internal to the op fuser.
            if forward_pre_hooks:

                def forward_pre_hook(module, *_) -> None:
                    for submodule, hook in forward_pre_hooks:
                        # Assume that hook does not interact with input
                        hook(submodule, None)

                main_branch.register_forward_pre_hook(forward_pre_hook)
            if forward_post_hooks:

                def forward_post_hook(module, *_) -> None:
                    for submodule, hook in forward_post_hooks:
                        # Assume that hook does not interact with input or output
                        hook(submodule, None, None)

                lora_branch.register_forward_hook(forward_post_hook)

            return main_branch, lora_branch
        def _make_main_branch(
            self,
            *,
            in_features: int,
            out_features: int,
            tensor_parallel_mode: Optional[str],
            tensor_parallel_group: Optional[torch.distributed.ProcessGroup],
            sequence_parallel: bool,
            accumulate_into_main_grad: bool,
        ) -> te.ops.Sequential:
            """Construct fused module for main branch (norm + fork + linear)"""

            # Check wrapped linear class
            if not isinstance(self.to_wrap, (te.Linear, te.LayerNormLinear, torch.nn.Linear)):
                raise ValueError(f"Unsupported class for wrapped linear ({self.to_wrap.__class__.__name__})")

            # Ops in main branch
            main_branch = te.ops.Sequential()

            # Norm op
            if isinstance(self.to_wrap, te.LayerNormLinear):
                norm_type = self.to_wrap.normalization
                kwargs = {
                    "eps": self.to_wrap.eps,
                    "device": "meta",
                    "dtype": self.to_wrap.layer_norm_weight.dtype,
                    "zero_centered_gamma": self.to_wrap.zero_centered_gamma,
                }
                op = None
                if norm_type == "LayerNorm":
                    op = te.ops.LayerNorm(in_features, **kwargs)
                    op.weight = self.to_wrap.layer_norm_weight
                    op.bias = self.to_wrap.layer_norm_bias
                elif norm_type == "RMSNorm":
                    op = te.ops.RMSNorm(in_features, **kwargs)
                    op.weight = self.to_wrap.layer_norm_weight
                else:
                    raise ValueError(f"Unsupported normalization ({norm_type})")
                main_branch.append(op)
                main_branch.append(te.ops.Quantize(forward=True, backward=False))

            # Fork to LoRA branch
            # Note: GEMM with beta=1 in backward pass
            main_branch.append(te.ops.MakeExtraOutput(in_place=True))

            # Linear op
            weight = self.to_wrap.weight
            bias = self.to_wrap.bias
            if isinstance(bias, torch.Tensor) and bias.numel() == 0:
                bias = None
            op = te.ops.Linear(
                in_features,
                out_features,
                bias=bias is not None,
                device="meta",
                dtype=weight.dtype,
                tensor_parallel_mode=tensor_parallel_mode,
                tensor_parallel_group=tensor_parallel_group,
                sequence_parallel=sequence_parallel,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
            op.weight = weight
            op.bias = bias
            main_branch.append(op)

            return main_branch

        def _make_lora_branch(
            self,
            *,
            in_features: int,
            out_features: int,
            tensor_parallel_mode: Optional[str],
            tensor_parallel_group: Optional[torch.distributed.ProcessGroup],
            sequence_parallel: bool,
            accumulate_into_main_grad: bool,
        ) -> te.ops.Sequential:
            """Construct fused module for LoRA branch (lora_a + lora_b + add)"""
            from nemo.collections.llm.peft.utils import ParallelLinearAdapter

            # Extract params from LoRA adapter
            lora_a_weight = None
            lora_b_weight = None
            lora_dim = None
            dropout = 0
            dropout_position = None
            scale = None
            if isinstance(self.adapter, (LinearAdapter, TELinearAdapter)):
                lora_a_weight = self.adapter.lora_a.weight
                lora_b_weight = self.adapter.lora_b.weight
                lora_dim = lora_b_weight.size(1)
                dropout = self.adapter.dropout.p
                dropout_position = self.adapter.dropout_position
                scale = self.adapter.scale
            elif isinstance(self.adapter, ParallelLinearAdapter):
                lora_a_weight = self.adapter.linear_in.weight
                lora_b_weight = self.adapter.linear_out.weight
                lora_dim = lora_b_weight.size(1)
                if self.adapter.dropout is not None:
                    dropout = self.adapter.dropout.p
                dropout_position = self.adapter.dropout_position
                scale = self.adapter.alpha / self.adapter.dim
            else:
                raise ValueError(f"Unsupported class for LoRA adapter ({self.adapter.__class__.__name__})")

            # Ops in LoRA branch
            lora_branch = te.ops.Sequential()

            # LoRA pre-processing
            if dropout > 0 and dropout_position == "pre":
                lora_branch.append(te.ops.Dropout(dropout))

            # LoRA A linear op
            op = te.ops.Linear(
                in_features,
                lora_dim,
                bias=False,
                device="meta",
                dtype=lora_a_weight.dtype,
                tensor_parallel_mode=tensor_parallel_mode,
                tensor_parallel_group=tensor_parallel_group,
                sequence_parallel=sequence_parallel,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
            op.weight = lora_a_weight
            lora_branch.append(op)

            # LoRA B linear op
            if tensor_parallel_mode == "column":
                # All-gather along dim -1
                raise NotImplementedError("Column tensor parallelism is not yet supported")
            op = te.ops.Linear(
                lora_dim,
                out_features,
                bias=False,
                device="meta",
                dtype=lora_b_weight.dtype,
                tensor_parallel_mode=None if tensor_parallel_mode is None else "column",
                tensor_parallel_group=tensor_parallel_group,
                sequence_parallel=False,
                accumulate_into_main_grad=accumulate_into_main_grad,
            )
            op.weight = lora_b_weight
            lora_branch.append(op)

            # LoRA post-processing
            if scale != 1:
                lora_branch.append(te.ops.ConstantScale(scale))
            if dropout > 0 and dropout_position == "post":
                lora_branch.append(te.ops.Dropout(dropout))
            if tensor_parallel_mode == "row":
                # All-gather along dim -1
                raise NotImplementedError("Row tensor parallelism is not yet supported")

            # Add with main branch
            # Note: GEMM with beta=1 in forward pass
            lora_branch.append(te.ops.AddExtraInput(in_place=True))

            return lora_branch

        def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
            # pylint: disable=C0115,C0116

            # Construct fused impl if needed
            # Note: We initialize during the first forward pass in
            # case the params are modified after the constructor.
            # Note: The fused impl is stored in a tuple to avoid
            # registering submodules.
            if self._fused_branches is None:
                self._fused_branches = self._make_fused_branches()

            # Apply fused impl
            main_branch, lora_branch = self._fused_branches
            linear_output, linear_input = main_branch(x)
            with te.fp8_autocast(enabled=False):
                out = lora_branch(linear_input, linear_output)
            return out, None
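
# Note: the fused wrapper above is only selected by `LoRA.transform` when the wrapped module's
# config sets `use_transformer_engine_op_fuser` and tensor parallelism is disabled (see
# `enable_op_fuser` below); the fused LoRA branch raises NotImplementedError for row- and
# column-parallel layouts, so all other cases fall back to the unfused `LoRALinear` wrapper.
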
if HAVE_TE:

    class TELinearAdapter(te.Linear):
        """
        TELinear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN)

        The `_init_adapter` method provides the LoRA set-up functionality. We want to be able to use
        it inside the adapter class but also for monkey-patching modules, without repeating the same
        code -> therefore it is decorated with @staticmethod.

        Args:
            orig_linear (nn.Module): the linear module to augment.
            dim (int): lora's dim in_features -> dim -> out_features.
            alpha (int): lora's scaling alpha.
            dropout (float): dropout prob (default: 0.0).
            dropout_position (str): where to apply dropout rel. to lora (choices=['pre', 'post'], default='post').
            lora_A_init_method (str): init method for lora_A (choices=['xavier', 'uniform']).
            lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
                are quantized weights (e.g. 4bit) needs to be specified explicitly.
        """

        def __init__(
            self,
            orig_linear,
            dim=8,
            alpha=32,
            dropout=0.0,
            dropout_position='post',
            lora_A_init_method='xavier',
            lora_dtype=None,
        ):
            assert orig_linear.__class__ == te.Linear
            # TELinear has bias set to empty tensor
            has_bias = orig_linear.bias is not None and orig_linear.bias.shape[0] != 0
            super(TELinearAdapter, self).__init__(
                in_features=orig_linear.in_features,
                out_features=orig_linear.out_features,
                bias=has_bias,
                device=orig_linear.weight.device,
                params_dtype=orig_linear.weight.dtype,
            )
            # copy weights
            self.weight.data.copy_(orig_linear.weight.data)
            if has_bias:
                self.bias.data.copy_(orig_linear.bias.data)
            # initialize the adapter
            TELinearAdapter._init_adapter(
                self,
                dim=dim,
                alpha=alpha,
                dropout=dropout,
                dropout_position=dropout_position,
                lora_A_init_method=lora_A_init_method,
                lora_dtype=lora_dtype,
            )

        @torch.no_grad
        @staticmethod
        def _init_adapter(
            obj,
            dim=8,
            alpha=32,
            dropout=0.0,
            dropout_position='post',
            lora_A_init_method='xavier',
            lora_dtype=None,
        ):
            """Adds LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module
            (when monkey-patching).

            Args:
                obj (LinearAdapter | nn.Module): input module to adapt.
                dim (int): lora's dim in_features -> dim -> out_features.
                alpha (int): lora's scaling alpha.
                dropout (float): dropout prob (default: 0.0).
                dropout_position (str): where to apply dropout rel. to lora (choices=['pre', 'post'], default='post').
                lora_A_init_method (str): init method for lora_A (choices=['xavier', 'uniform']).
                lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
                    are quantized weights (e.g. 4bit) needs to be specified explicitly.
            """
            obj.dim = dim
            obj.scale = alpha / dim

            # Freezer
            device = obj.weight.device
            obj.weight.requires_grad = False
            if obj.bias is not None:
                obj.bias.requires_grad = False

            in_features = obj.in_features
            out_features = obj.out_features
            dtype = lora_dtype or obj.weight.dtype

            obj.lora_a = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
            obj.lora_b = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
            if lora_A_init_method == 'xavier':
                torch.nn.init.uniform_(obj.lora_a.weight.data)
            else:
                nn.init.kaiming_uniform_(obj.lora_a.weight.data, a=math.sqrt(5))
            obj.lora_b.weight.data.fill_(0)

            obj.dropout = nn.Dropout(p=dropout)
            assert dropout_position in ['pre', 'post'], dropout_position
            obj.dropout_position = dropout_position

        def forward(self, x):
            # pylint: disable=C0115,C0116
            res = super(TELinearAdapter, self).forward(x)
            if self.dropout_position == 'pre':
                x = self.dropout(x)
            # LoRA fwd is performed in original precision regardless of FP8 enabled
            lora_res = self.lora_b(self.lora_a(x))
            lora_res = lora_res * self.scale
            if self.dropout_position == 'post':
                lora_res = self.dropout(lora_res)
            return res + lora_res


class LinearAdapter(nn.Linear):
    """
    Linear + LoRA, maintains ckpts structure (i.e. Linear's weight/bias remain at the same FQN)

    The `_init_adapter` method provides the LoRA set-up functionality. We want to be able to use
    it inside LinearAdapter but also for monkey-patching modules, without repeating the same
    code -> therefore it is decorated with @staticmethod.

    Args:
        orig_linear (nn.Module): the linear module to augment.
        dim (int): lora's dim in_features -> dim -> out_features.
        alpha (int): lora's scaling alpha.
        dropout (float): dropout prob (default: 0.0).
        dropout_position (str): where to apply dropout rel. to lora (choices=['pre', 'post'], default='post').
        lora_A_init_method (str): init method for lora_A (choices=['xavier', 'uniform']).
        lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
            are quantized weights (e.g. 4bit) needs to be specified explicitly.
    """
    def __init__(
        self,
        orig_linear,
        dim=8,
        alpha=32,
        dropout=0.0,
        dropout_position='post',
        lora_A_init_method='xavier',
        lora_dtype=None,
    ):
        assert isinstance(orig_linear, nn.Linear)
        super(LinearAdapter, self).__init__(
            in_features=orig_linear.in_features,
            out_features=orig_linear.out_features,
            bias=orig_linear.bias is not None,
            device=orig_linear.weight.device,
            dtype=orig_linear.weight.dtype,
        )
        # copy weights
        self.weight.data.copy_(orig_linear.weight.data)
        if orig_linear.bias is not None:
            self.bias.data.copy_(orig_linear.bias.data)
        # initialize the adapter
        LinearAdapter._init_adapter(
            self,
            dim=dim,
            alpha=alpha,
            dropout=dropout,
            dropout_position=dropout_position,
            lora_A_init_method=lora_A_init_method,
            lora_dtype=lora_dtype,
        )

    @torch.no_grad
    @staticmethod
    def _init_adapter(
        obj,
        dim=8,
        alpha=32,
        dropout=0.0,
        dropout_position='post',
        lora_A_init_method='xavier',
        lora_dtype=None,
    ):
        """Adds LoRA weights to obj. The obj is either a LinearAdapter or an nn.Module
        (when monkey-patching).

        Args:
            obj (LinearAdapter | nn.Module): input module to adapt.
            dim (int): lora's dim in_features -> dim -> out_features.
            alpha (int): lora's scaling alpha.
            dropout (float): dropout prob (default: 0.0).
            dropout_position (str): where to apply dropout rel. to lora (choices=['pre', 'post'], default='post').
            lora_A_init_method (str): init method for lora_A (choices=['xavier', 'uniform']).
            lora_dtype (torch.dtype): weight's dtype, by default will use orig_linear's but if they
                are quantized weights (e.g. 4bit) needs to be specified explicitly.
        """
        obj.dim = dim
        obj.scale = alpha / dim

        # Freezer
        device = obj.weight.device
        obj.weight.requires_grad = False
        if obj.bias is not None:
            obj.bias.requires_grad = False

        in_features = obj.in_features
        out_features = obj.out_features
        dtype = lora_dtype or obj.weight.dtype

        obj.lora_a = nn.Linear(in_features, dim, bias=False, dtype=dtype, device=device)
        obj.lora_b = nn.Linear(dim, out_features, bias=False, dtype=dtype, device=device)
        if lora_A_init_method == 'xavier':
            torch.nn.init.uniform_(obj.lora_a.weight.data)
        else:
            nn.init.kaiming_uniform_(obj.lora_a.weight.data, a=math.sqrt(5))
        obj.lora_b.weight.data.fill_(0)

        obj.dropout = nn.Dropout(p=dropout)
        assert dropout_position in ['pre', 'post'], dropout_position
        obj.dropout_position = dropout_position
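
    # Note: `lora_b` is zero-initialized in `_init_adapter`, so a freshly adapted layer reproduces
    # the original layer's outputs exactly; the low-rank update only takes effect once training
    # moves `lora_b` away from zero.
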
    def forward(self, x):
        # pylint: disable=C0115,C0116
        # If LinearAdapter is used to monkey-patch a nn.Linear module, we want to use nn.Linear's
        # forward in the case where it uses quantized weights. We store a reference to nn.Linear's
        # forward in the `super_fwd` attribute. If the attribute does not exist, we do the usual linear.
        if (fwd := getattr(self, 'super_fwd', None)) is not None:
            assert fwd != self.forward
            res = fwd(x)
        else:
            res = F.linear(x, self.weight, self.bias)

        if self.dropout_position == 'pre':
            x = self.dropout(x)
        lora_res = self.lora_b(self.lora_a(x))
        lora_res = lora_res * self.scale
        if self.dropout_position == 'post':
            lora_res = self.dropout(lora_res)
        return res + lora_res


def patch_linear_module(
    orig_linear,
    dim=8,
    alpha=32,
    dropout=0.0,
    dropout_position='post',
    lora_A_init_method='xavier',
    lora_dtype=None,
):
    """Monkey-patches a nn.Linear (orig_linear param) to be a LinearAdapter; for all purposes,
    think of this function as replacing a nn.Linear with a LinearAdapter defined above.

    The orig_linear might not contain valid weights, for example, when the given orig_linear was
    initialized within a context manager that uses a "meta" device. Therefore, we cannot copy the
    weight/bias from the orig_linear to the LinearAdapter, since those have not been allocated.
    To circumvent this scenario, LinearAdapter's additional set-up functionality (`_init_adapter`)
    is implemented as a static method, so that it can be used both for patching and when allocating
    a new LinearAdapter object.

    Args:
        orig_linear (nn.Linear): the module we add adapter to.
        dim (int, optional): Lora dim. Defaults to 8.
        alpha (int, optional): Lora alpha scale. Defaults to 32.
        dropout (float, optional): dropout prob. Defaults to 0.0.
        dropout_position (str, optional): location to apply dropout wrt lora.
            Defaults to 'post' (choices: 'pre', 'post').
        lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
        lora_dtype (torch.dtype, optional): Lora weights' dtype. By default will use orig_linear's dtype
            but orig_linear might use a non-trainable dtype (e.g., 4bit), in which case the user must
            specify the dtype manually. Defaults to None.

    Returns:
        (nn.Module): the monkey-patched (nn.Linear + LoRA) nn.Module
    """

    assert isinstance(orig_linear, nn.Linear) or orig_linear.__class__ == te.Linear
    assert not hasattr(orig_linear, 'super_fwd'), orig_linear.super_fwd

    if isinstance(orig_linear, nn.Linear):
        LinearAdapter._init_adapter(orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype)
        cls = orig_linear.__class__
        new_cls = type('PatchedLinearAdapter', (LinearAdapter, cls), {})
    elif orig_linear.__class__ == te.Linear:
        TELinearAdapter._init_adapter(
            orig_linear, dim, alpha, dropout, dropout_position, lora_A_init_method, lora_dtype
        )
        cls = orig_linear.__class__
        new_cls = type('PatchedTELinearAdapter', (TELinearAdapter, cls), {})
    else:
        raise NotImplementedError("Expected isinstance(orig_linear, (nn.Linear, te.Linear))")

    # If the model uses quantized weights, we want to use orig_linear's forward
    if (
        getattr(orig_linear, 'quant_state', None) is not None
        and orig_linear.quant_state.__class__ == bitsandbytes.functional.QuantState
    ):
        orig_linear.super_fwd = orig_linear.forward

    orig_linear.__class__ = new_cls
    return orig_linear
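
# Illustrative usage of `patch_linear_module` (a sketch; the layer below is hypothetical):
#
#     layer = nn.Linear(1024, 1024)
#     layer = patch_linear_module(layer, dim=8, alpha=32)
#
# `layer` is now a PatchedLinearAdapter: the frozen base projection plus a trainable low-rank
# branch, with the original weight/bias FQNs preserved for checkpoint loading.
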
@dataclass
class LoRA(PEFT, ModuleMatcher):
    """
    Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning.

    LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task.
    This class facilitates the application of LoRA to specific modules within the model architecture.

    Args:
        target_modules (list[str], optional): A list of module names to apply LoRA to.
            Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'].
                - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections
                  in self-attention.
                - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention.
                - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP.
                - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP.
            Target modules can also contain wildcards. For example, you can specify
            target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv
            on the first two layers.
        exclude_modules (list[str], optional): A list of module names not to apply LoRA to. It will match all
            nn.Linear & nn.Linear-adjacent modules whose name does not match any string in exclude_modules.
            If used, will require target_modules to be an empty list or None.
        dim (int): Dimension of the low-rank projection space. Defaults to 32.
        alpha (int): Weighting factor for the low-rank projection. Defaults to 32.
        dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0.
        dropout_position (Literal['pre', 'post'], optional): Position for applying dropout.
            Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'.
        a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy.
            Defaults to False.
        dropout_recompute (bool): Enables dropout recompute using Thunder JIT compilation. When True, applies
            thunder.jit() to the dropout layer for memory-efficient training by recomputing dropout activations
            during the backward pass instead of storing them.
        lora_dtype (torch.dtype): Parameter data type for LoRA weights. Default None (will use model's dtype).

    Example:
    --------
        >>> from nemo.collections import llm
        >>> lora = llm.peft.LoRA(target_modules=['linear_qkv', 'linear_proj'], dim=32)
        >>> model = llm.Mistral7BModel(model_transform=lora)
        >>> # (set up trainer and data)
        >>> trainer.fit(model, data)

    References:
    -----------
        Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., & Chen, W. (2021).
        LoRA: Low-Rank Adaptation of Large Language Models. arXiv preprint arXiv:2106.09685.
        https://arxiv.org/abs/2106.09685
    """

    target_modules: list[str] = field(
        default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']
    )
    dim: int = 32
    alpha: int = 32
    dropout: float = 0.0
    dropout_position: Literal['pre', 'post'] = 'pre'
    lora_A_init_method: str = "xavier"
    lora_B_init_method: str = "zero"
    a2a_experimental: bool = False
    lora_dtype: torch.dtype = None
    dropout_recompute: bool = False

    def transform(self, m: nn.Module, name=None, prefix=None):
        """
        Applies LoRA to a specific module within the model architecture.

        Args:
            m (nn.Module): The module to apply LoRA to.
            name (str, optional): Name of the module (if applicable). Defaults to None.
            prefix (str, optional): Prefix for the module name (if applicable). Defaults to None.

        Returns:
            nn.Module: The modified module with LoRA applied, or the original module if not a target.
        """
        from nemo.collections.llm.peft.utils import ParallelLinearAdapter

        if (ans := self.match(m, name, prefix)) is not None:
            (match, full_name) = ans
            if isinstance(m, nn.Linear) or m.__class__ == te.Linear:
                # Will use the `patch_linear_module` function if:
                # - is FSDP v1
                # - is DTensor (has _local_tensor attribute)
                # - has quant_state attribute
                if (
                    self._add_via_setattr
                    or hasattr(m.weight.data, '_local_tensor')
                    or (
                        getattr(m, 'quant_state', None) is not None
                        and m.quant_state.__class__ == bitsandbytes.functional.QuantState
                    )
                ):
                    lora_cls = patch_linear_module
                elif HAVE_TE and m.__class__ == te.Linear:
                    lora_cls = TELinearAdapter
                else:
                    lora_cls = LinearAdapter

                # Construct LoRA module
                return lora_cls(
                    m,
                    dim=self.dim,
                    alpha=self.alpha,
                    dropout=self.dropout,
                    lora_A_init_method=self.lora_A_init_method,
                    lora_dtype=self.lora_dtype,
                )

            input_is_parallel, in_features, out_features, disable_sp_comm, base_linear_is_parallel = (
                get_adapter_attributes_from_linear(m)
            )
            enable_op_fuser = (
                HAVE_TE_FUSED_LORA
                and hasattr(m, "config")
                and getattr(m.config, "use_transformer_engine_op_fuser", False)
                # TP not yet supported
                and parallel_state.get_tensor_model_parallel_world_size() == 1
            )
            logging.info(f"Adding lora to: {full_name}")
            adapter = ParallelLinearAdapter(
                in_features,
                out_features,
                self.dim,
                base_linear_name=full_name,
                activation='identity',
                norm_type=None,
                column_init_method=self.lora_A_init_method,
                row_init_method=self.lora_B_init_method,
                gather_output=False,
                input_is_parallel=input_is_parallel,
                dropout=self.dropout,
                dropout_position=self.dropout_position,
                model_parallel_config=getattr(m, "config", None),
                alpha=self.alpha,
                is_expert=is_expert_linear(full_name),
                a2a_experimental=self.a2a_experimental,
                disable_sequence_parallel_comm=disable_sp_comm,
                dropout_recompute=self.dropout_recompute,
                base_linear_is_parallel=base_linear_is_parallel,
            )
            if enable_op_fuser:
                return TEFusedLoRALinear(m, adapter)
            else:
                return LoRALinear(m, adapter)
        return m


class LoRAMerge(PEFT):
    """
    Implements the LoRA weight merge for parameter-efficient fine-tuning.

    Example:
    --------
        >>> from nemo.collections.llm.peft.lora import LoRAMerge
        >>> lora_merge = LoRAMerge()
        >>> merged_model = lora_merge(trainer.strategy.megatron_parallel)
    """

    @torch.no_grad()
    def transform(self, m: nn.Module, name=None, prefix=None):
        """
        Merges the LoRA adapter with the base model weights.

        Args:
            m (nn.Module): The module to apply LoRA merge to.
            name (str, optional): Name of the module to merge. Defaults to None.
            prefix (str, optional): Prefix for the module name. Defaults to None.

        Returns:
            nn.Module: The modified module with the LoRA adapter merged into the base model weights.
        """
        if not isinstance(m, LoRALinear):
            return m
        logging.info(f'merging {(prefix if prefix else "") + "." + (name if name else "")}')
        lora_weight = m.adapter.alpha / m.adapter.dim * m.adapter.linear_out.weight @ m.adapter.linear_in.weight
        if hasattr(m.to_wrap, "weight"):
            base_weight = m.to_wrap.weight
            merged_weight = base_weight + lora_weight.to(base_weight.device)
            m.to_wrap.weight.data = merged_weight
        else:
            # TE Grouped Linear
            for i in range(m.to_wrap.num_gemms):
                base_weight = getattr(m.to_wrap, f"weight{i}")
                merged_weight = base_weight + lora_weight.to(base_weight.device)
                getattr(m.to_wrap, f"weight{i}").data = merged_weight
        return m