# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Literal, Tuple import torch import torch.nn.functional as F from torch import Tensor, nn from torch.distributed import all_gather as all_gather_no_backprop from torch.distributed.nn.functional import all_gather as all_gather_with_backprop from nemo.lightning.megatron_parallel import MaskedTokenLossReduction, MegatronLossReduction class BERTLossReduction(MegatronLossReduction): """Bert Loss Function. when add_sop_loss = False, only calculate Masked token loss. """ def __init__(self, validation_step: bool = False, val_drop_last: bool = True, add_sop_loss: bool = True) -> None: super().__init__() self.validation_step = validation_step self.val_drop_last = val_drop_last self.add_sop_loss = add_sop_loss if not add_sop_loss: # BERTLoss would act like MaskedTokenLossReduction when only use MLM loss self.mlm = MaskedTokenLossReduction(validation_step, val_drop_last) def forward( self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Perform Loss calculation on batch. Currently, Context parallelism is not supported for SOP loss. """ # Update loss_mask to batch. # Model forward did no update to loss_mask, but for unknown reason loss_mask can get lost (to None) # in 'batch' during update. We use the original loss_mask in the dataloader as the ground truth. batch['loss_mask'] = forward_out['loss_mask'] if not self.add_sop_loss: return self.mlm.forward(batch, forward_out['lm_loss']) from megatron.core import parallel_state lm_loss_, sop_logits = forward_out['lm_loss'], forward_out['binary_logits'] assert sop_logits is not None, ( 'Attempting to calculate Sentence Order Prediction Loss but SOP logits ' 'are not provideds, Please Make sure you have added binary head.' ) cp_size = parallel_state.get_context_parallel_world_size() if cp_size == 1: sop_loss_for_ub = sentence_order_prediction_loss(sop_logits, batch["is_random"]) lm_loss_for_ub = masked_token_with_zero(lm_loss_, batch["loss_mask"]) else: raise NotImplementedError('CP is not supported for SOP loss yet') loss_for_ub = sop_loss_for_ub + lm_loss_for_ub reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) return loss_for_ub, {"avg": reduced_loss} def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: """Taken from: https://github.com/NVIDIA/NeMo/blob/main /nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" if losses_reduced_per_micro_batch: if "avg" in losses_reduced_per_micro_batch[0]: # legacy behavior, average over the number of microbatches avg = [x["avg"] for x in losses_reduced_per_micro_batch] loss = torch.cat(avg).mean() return loss from megatron.core import parallel_state loss_sum_and_ub_size = [ x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0 ] loss = ( torch.vstack(loss_sum_and_ub_size).sum(dim=0) if len(loss_sum_and_ub_size) > 0 else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) ) torch.distributed.all_reduce( loss, group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) # average over the total number of tokens across the global batch. loss = loss[0] / loss[1] return loss return torch.tensor(0.0, device=torch.cuda.current_device()) class HardNegativeRankingLoss(MegatronLossReduction): """ This loss uses hard-negative samples. The difference of this loss to the default MultipleNegativesRankingLoss from Sentence Transformers is that the latter shares the hard negatives as negatives for all examples, whereas this loss uses hard negatives exclusively for the example they are associated. """ def __init__( self, validation_step: bool = False, val_drop_last: bool = True, num_hard_negatives: int = 1, scale: float = 50, label_smoothing: float = 0.0, ) -> None: super().__init__() self.validation_step = validation_step self.val_drop_last = val_drop_last self.num_hard_negatives = num_hard_negatives self.scale = scale self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing) def forward( self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: from megatron.core import parallel_state cp_size = parallel_state.get_context_parallel_world_size() if cp_size != 1: raise NotImplementedError(f'CP is not supported for {self.__class__} yet.') num_tensors_per_example = 2 + self.num_hard_negatives # 1 query, 1 pos, num_hard_negatives negs current_train_n_passages = 1 + self.num_hard_negatives batch_size = forward_out.shape[0] // num_tensors_per_example # Get Query, Key (Positives, Negatives) # forward_out was chunked [(q1, k1), (q2, k2), ...] chunks = forward_out.chunk(batch_size) query = torch.stack([item[0] for item in chunks]) key = torch.cat([item[1:] for item in chunks]) assert key.shape[0] % query.shape[0] == 0, '{} % {} > 0'.format(key.shape[0], query.shape[0]) assert key.shape[0] / query.shape[0] == current_train_n_passages, '{} / {} != {}'.format( key.shape[0], query.shape[0], current_train_n_passages ) query_shape = query.shape repeated_query = query.repeat(1, 1, current_train_n_passages).reshape( query_shape[0] * current_train_n_passages, query_shape[1] ) scores = torch.sum(repeated_query * key, dim=-1).reshape(query_shape[0], current_train_n_passages) labels = torch.zeros(query_shape[0], dtype=torch.long, device=query.device) scores *= self.scale ce_loss = self.cross_entropy_loss(scores, labels) reduced_loss = average_losses_across_data_parallel_group([ce_loss]) return ce_loss, {"avg": reduced_loss} def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: """Taken from: https://github.com/NVIDIA/NeMo/blob/main /nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" if losses_reduced_per_micro_batch: if "avg" in losses_reduced_per_micro_batch[0]: # legacy behavior, average over the number of microbatches avg = [x["avg"] for x in losses_reduced_per_micro_batch] loss = torch.cat(avg).mean() return loss from megatron.core import parallel_state loss_sum_and_ub_size = [ x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0 ] loss = ( torch.vstack(loss_sum_and_ub_size).sum(dim=0) if len(loss_sum_and_ub_size) > 0 else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) ) torch.distributed.all_reduce( loss, group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) # average over the total number of tokens across the global batch. loss = loss[0] / loss[1] return loss return torch.tensor(0.0, device=torch.cuda.current_device()) class BERTInBatchExclusiveHardNegativesRankingLoss(MegatronLossReduction): """ This loss uses in-batch negative samples + hard-negative samples. The difference of this loss to the default MultipleNegativesRankingLoss from Sentence Transformers is that the latter shares the hard negatives as negatives for all examples, whereas this loss uses hard negatives exclusively for the example they are associated. This loss is also capable of using in-batch negatives from all ranks during training. """ def __init__( self, validation_step: bool = False, val_drop_last: bool = True, num_hard_negatives: int = 1, scale: float = 20, label_smoothing: float = 0.0, global_in_batch_negatives: bool = False, backprop_type: Literal["local", "global"] = 'local', ) -> None: super().__init__() self.validation_step = validation_step self.val_drop_last = val_drop_last self.num_hard_negatives = num_hard_negatives self.scale = scale self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing) self.global_in_batch_negatives = global_in_batch_negatives self.backprop_type = backprop_type def _gather_global_in_batch_representations(self, local_tensor): from megatron.core import parallel_state local_tensor = local_tensor.contiguous() if self.backprop_type == 'local': global_tensors = [ torch.zeros_like(local_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) ] all_gather_no_backprop(global_tensors, local_tensor, group=parallel_state.get_data_parallel_group()) global_tensors[parallel_state.get_data_parallel_rank()] = local_tensor global_tensors = torch.cat(global_tensors, dim=0) else: global_tensors = all_gather_with_backprop(local_tensor) global_tensors = torch.cat(global_tensors, dim=0) return global_tensors def forward( self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: from megatron.core import parallel_state cp_size = parallel_state.get_context_parallel_world_size() if cp_size != 1: raise NotImplementedError(f'CP is not supported for {self.__class__} yet.') if self.global_in_batch_negatives and not self.validation_step: forward_out = self._gather_global_in_batch_representations(forward_out) num_tensors_per_example = 2 + self.num_hard_negatives batch_size = forward_out.shape[0] // num_tensors_per_example chunks = forward_out.chunk(batch_size) # Get Queries, Positives, Negatives queries = torch.stack([item[0] for item in chunks]) positives = torch.stack([item[1] for item in chunks]) hard_negs = [ torch.stack([item[i + 2] for item in chunks]) for i in range(self.num_hard_negatives) ] # List of length "num_negatives", each tensor of shape (bs, embedding_dim) # Calculate scores pos_in_batch_negs_scores = torch.mm( queries, positives.transpose(0, 1) # shape (bs, bs); each positive is negative for other queries. ) hard_negs_scores = ( torch.multiply( queries.unsqueeze(0).repeat(len(hard_negs), 1, 1), torch.stack(hard_negs), ) .sum(axis=-1) .T ) # shape = (bs, num_negatives); Hard negatives are not shared between queries. scores = torch.cat([pos_in_batch_negs_scores, hard_negs_scores], axis=1) scores = scores.clamp(-1.0, 1.0) scores *= self.scale labels = torch.tensor( range(len(scores)), dtype=torch.long, device=scores.device ) # Indices of the (query, positive) pairs ce_loss = self.cross_entropy_loss(scores, labels) reduced_loss = average_losses_across_data_parallel_group([ce_loss]) return ce_loss, {"avg": reduced_loss} def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: """Taken from: https://github.com/NVIDIA/NeMo/blob/main /nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" if losses_reduced_per_micro_batch: if "avg" in losses_reduced_per_micro_batch[0]: # legacy behavior, average over the number of microbatches avg = [x["avg"] for x in losses_reduced_per_micro_batch] loss = torch.cat(avg).mean() return loss from megatron.core import parallel_state loss_sum_and_ub_size = [ x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0 ] loss = ( torch.vstack(loss_sum_and_ub_size).sum(dim=0) if len(loss_sum_and_ub_size) > 0 else torch.tensor([0.0, 0.0], device=torch.cuda.current_device()) ) torch.distributed.all_reduce( loss, group=parallel_state.get_data_parallel_group(with_context_parallel=True), ) # average over the total number of tokens across the global batch. loss = loss[0] / loss[1] return loss return torch.tensor(0.0, device=torch.cuda.current_device()) def masked_token_with_zero(tensor: Tensor, mask: Tensor): """Calculate masked token loss with consideration of possible NaN. Sometimes when the number of tokens is very small, none of the tokens get masked for prediction. In that case loss mask is all zeros i.e Happens when the entire batch is masked out (Practically when MBS=1 or 2, and the number of tokens in each batch is < 7 ) """ losses = tensor.float() loss_mask = mask.float() if loss_mask.sum() == 0: loss = torch.sum(losses.view(-1)) * 0.0 else: loss = torch.sum(losses.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() return loss def sentence_order_prediction_loss(tensor: Tensor, sentence_order: Tensor): """Calculate sentence order prediction loss.""" losses = tensor.view(-1, 2).float() sentence_order = sentence_order.view(-1) loss = F.cross_entropy(losses, sentence_order, ignore_index=-1) return loss def average_losses_across_data_parallel_group(losses): """Reduce a tensor of losses across all GPUs.""" from megatron.core import parallel_state averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group()) averaged_losses = averaged_losses / torch.distributed.get_world_size( group=parallel_state.get_data_parallel_group() ) return averaged_losses