# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| """Dataloaders.""" | |
| import abc | |
| import warnings | |
| from itertools import chain | |
| from typing import Optional, Tuple | |
| import torch | |
| from nemo.utils import logging | |
| from nemo.utils.decorators import experimental | |
| __all__ = [ | |
| "MegatronPretrainingBatchSampler", | |
| "MegatronPretrainingRandomBatchSampler", | |
| ] | |


class BaseMegatronSampler:
    """Common bookkeeping and sanity checks for Megatron-style samplers that shard a dataset across data parallel ranks."""

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        rampup_batch_size: Optional[list] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
    ) -> None:
        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size must be greater than 0, but got {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data_parallel_size must be greater than 0, but got {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data_parallel_size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )
        if global_batch_size is not None and rampup_batch_size is None:
            if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
                raise RuntimeError(
                    f"`global_batch_size` ({global_batch_size}) is not divisible by "
                    f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
                    f"({data_parallel_size})`"
                )
        if pad_samples_to_global_batch_size and global_batch_size is None:
            raise RuntimeError(
                "`pad_samples_to_global_batch_size` can be `True` only when "
                "`global_batch_size` is set to an integer value"
            )

        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last
        self.global_batch_size = global_batch_size
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

        logging.info(
            f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} '
            f'and consumed_samples: {consumed_samples}'
        )

    def __len__(self):
        num_available_samples: int = self.total_samples - self.consumed_samples
        if self.global_batch_size is not None:
            if self.drop_last:
                num_global_batches = num_available_samples // self.global_batch_size
            else:
                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader
            # and num of batches fetched (as training step fetches in terms of micro batches)
            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
        else:
            return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

    @abc.abstractmethod
    def __iter__(self): ...


class MegatronPretrainingSampler(BaseMegatronSampler):
    """Sampler that walks the dataset sequentially and yields one micro batch of indices per data parallel rank."""

    def get_start_end_idx(self):
        """Return the [start, end) slice of the current global-batch chunk that belongs to this data parallel rank."""
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def _get_padding_indices(self, pad_samples_num):
        """Return indices used to pad the last incomplete global batch (negative indices wrap to the dataset tail)."""
        return range(-1, -pad_samples_num - 1, -1)

    def __iter__(self):
        batch = []
        # The last (incomplete) batch will be dropped unless drop_last is set to False.
        indices = range(self.consumed_samples, self.total_samples)
        if (not self.drop_last) and self.pad_samples_to_global_batch_size:
            pad_samples_num = -len(indices) % self.global_batch_size
            pad_indices = self._get_padding_indices(pad_samples_num)
            indices = chain(indices, pad_indices)
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Yield the last partial batch if drop_last is not set.
        if len(batch) > 0 and not self.drop_last:
            assert (
                not self.pad_samples_to_global_batch_size
            ), 'with pad_samples_to_global_batch_size all batches should be complete'
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
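

# Illustrative sketch (not part of the original module): how the sampler above shards a tiny
# dataset across two data parallel ranks. The sample counts and batch sizes below are assumptions
# chosen purely for demonstration.
def _example_micro_batch_sharding():
    """Return the micro batches of indices that each of two data parallel ranks would see."""
    per_rank_batches = {}
    for rank in range(2):
        sampler = MegatronPretrainingSampler(
            total_samples=8,
            consumed_samples=0,
            micro_batch_size=2,
            data_parallel_rank=rank,
            data_parallel_size=2,
        )
        # With micro_batch_size=2 and data_parallel_size=2, indices are consumed four at a time and
        # each rank keeps its contiguous slice of two, e.g. rank 0 -> [0, 1], rank 1 -> [2, 3].
        per_rank_batches[rank] = list(sampler)
    return per_rank_batches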


class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
    """Variant of ``MegatronPretrainingSampler`` that pads the last global batch with ``None`` placeholders."""

    def _get_padding_indices(self, pad_samples_num):
        """Return ``None`` placeholders for padded positions (instead of negative dataset indices)."""
        return [None] * pad_samples_num


class MegatronPretrainingRandomSampler(BaseMegatronSampler):
    """Sampler that shuffles each data parallel rank's shard of the dataset using a per-epoch seeded permutation."""

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
        seed: int = 0,
    ) -> None:
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
            global_batch_size=global_batch_size,
            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
        )
        assert (
            not pad_samples_to_global_batch_size
        ), "`MegatronPretrainingRandomSampler` does not support sample padding"
        if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
            raise RuntimeError(
                "`MegatronPretrainingRandomSampler` does not support drop_last=False when "
                "micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 "
                "if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
        self.seed = seed

    def __len__(self):
        active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
        num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
        if self.global_batch_size is not None:
            if self.drop_last:
                num_global_batches = num_available_samples // self.global_batch_size
            else:
                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader
            # and num of batches fetched (as training step fetches in terms of micro batches)
            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
        else:
            if self.drop_last:
                return num_available_samples // self.micro_batch_times_data_parallel_size
            else:
                return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each data parallel rank shuffles its own contiguous bucket
        # with a generator seeded by (seed + epoch), so the order is deterministic across restarts.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch will be dropped if it is incomplete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

        # Yield the last partial batch if drop_last is not set.
        if len(batch) > 0 and not self.drop_last:
            yield batch
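

# Illustrative sketch (not part of the original module): the bucket arithmetic used by
# `MegatronPretrainingRandomSampler.__iter__`. The sizes below are assumptions chosen for
# demonstration; they show that each data parallel rank permutes only its own contiguous bucket.
def _example_random_sampler_buckets():
    """Return the shuffled micro batches of indices that each of two data parallel ranks would see."""
    per_rank_batches = {}
    for rank in range(2):
        sampler = MegatronPretrainingRandomSampler(
            total_samples=8,
            consumed_samples=0,
            micro_batch_size=2,
            data_parallel_rank=rank,
            data_parallel_size=2,
            seed=0,
        )
        # bucket_size = (8 // 4) * 2 = 4, so rank 0 draws from indices 0..3 and rank 1 from 4..7,
        # each in an order given by torch.randperm seeded with (seed + epoch).
        per_rank_batches[rank] = list(sampler)
    return per_rank_batches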


class BaseMegatronBatchSampler:
    """Megatron-style batch sampler.

    Let mbs, gbs, tp, pp, and dp stand for "micro batch size", "global batch size",
    "tensor model parallel world size", "pipeline model parallel world size", and
    "data parallel world size"; then the number of micro batches (hereafter, nmb) is defined as
    :math:`nmb = gbs \\div (mbs \\times dp)`.

    See `apex/transformer/microbatches.py#L91-L98 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L91-L98>`_
    for the initial settings of the number of micro batches and
    `apex/transformer/microbatches.py#L160-L177 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L160-L177>`_
    for the warm-up of the global batch size.

    For example, with `(mbs, gbs, tp, pp, dp) = (1, 16, 1, 1, 2)`, the number of micro batches is
    :math:`gbs \\div (mbs \\times dp) = 16 \\div (1 \\times 2) = 8`.
    In this case, an instance of this batch sampler on each data parallel rank is expected
    to return :math:`nmb \\times mbs = 8` indices (see the illustrative sketch after this class).
    """

    _global_batch_size: int
    _num_micro_batches: int
    _global_batch_size_on_this_data_parallel_rank: int

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        global_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool,
        pad_samples_to_global_batch_size=False,
    ) -> None:
        """Constructor of Megatron-LM style Batch Sampler.

        Args:
            total_samples: The size of the dataset.
            consumed_samples: The number of samples that have already been consumed.
            micro_batch_size: The size of each micro batch.
            global_batch_size: The size of the global batch.
            data_parallel_rank: The value obtained via
                `parallel_state.get_data_parallel_rank()` of megatron.core.
            data_parallel_size: The value obtained via
                `parallel_state.get_data_parallel_world_size()` of megatron.core.
        """
        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size must be greater than 0, but got {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data_parallel_size must be greater than 0, but got {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data_parallel_size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )
        # Keep a copy of input params for later use.
        self.total_samples: int = total_samples
        self.consumed_samples: int = consumed_samples
        self.micro_batch_size: int = micro_batch_size
        self.data_parallel_rank: int = data_parallel_rank
        self.data_parallel_size: int = data_parallel_size
        self.drop_last: bool = drop_last
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
        self.update_global_batch_size(global_batch_size)

    def update_global_batch_size(self, new_global_batch_size: int) -> None:
        """Update the global batch size."""
        self._global_batch_size = new_global_batch_size
        if self._global_batch_size % self.micro_batch_times_data_parallel_size != 0:
            raise RuntimeError(
                f"`global_batch_size` ({self._global_batch_size}) is not divisible by "
                f"`micro_batch_size ({self.micro_batch_size}) x data_parallel_size "
                f"({self.data_parallel_size})`"
            )
        self._num_micro_batches = self._global_batch_size // self.micro_batch_times_data_parallel_size
        self._global_batch_size_on_this_data_parallel_rank = self._num_micro_batches * self.micro_batch_size

    @property
    def global_batch_size(self) -> int:
        """The current global batch size."""
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, new_global_batch_size: int) -> None:
        """Set the global batch size (prefer `update_global_batch_size`)."""
        warnings.warn("`self.update_global_batch_size(new_global_batch_size)` is recommended.")
        self.update_global_batch_size(new_global_batch_size=new_global_batch_size)

    def __len__(self) -> int:
        """Length of the batch sampler.

        .. note::
            When `rampup_batch_size` is enabled, the return value may not be exact.
        """
        num_available_samples: int = self.total_samples - self.consumed_samples % self.total_samples
        if self.drop_last:
            return num_available_samples // self.global_batch_size
        else:
            return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size

    @abc.abstractmethod
    def __iter__(self): ...
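

# Illustrative sketch (not part of the original module): the micro-batch arithmetic described in
# the `BaseMegatronBatchSampler` docstring, worked for the (mbs, gbs, dp) = (1, 16, 2) example.
# The parameter values are assumptions used only for demonstration.
def _example_num_micro_batches(micro_batch_size: int = 1, global_batch_size: int = 16, data_parallel_size: int = 2):
    """Return (nmb, indices_per_rank) where nmb = gbs // (mbs * dp)."""
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    num_micro_batches = global_batch_size // (micro_batch_size * data_parallel_size)  # 16 // (1 * 2) = 8
    # Each data parallel rank is expected to return nmb * mbs indices per global batch.
    indices_per_rank = num_micro_batches * micro_batch_size  # 8 * 1 = 8
    return num_micro_batches, indices_per_rank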


class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
    """Batch sampler that walks the dataset sequentially and yields this rank's share of each global batch."""

    def get_start_end_idx(self) -> Tuple[int, int]:
        """Return the [start, end) slice of a global batch that would belong to this data parallel rank."""
        start_idx = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
        end_idx = start_idx + self._global_batch_size_on_this_data_parallel_rank
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # The last (incomplete) batch will be dropped unless drop_last is set to False.
        for idx in range(self.consumed_samples % self.total_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self._global_batch_size:
                # Indices are distributed round-robin across data parallel ranks rather than as
                # contiguous slices (the get_start_end_idx alternative is kept below for reference).
                # start_idx, end_idx = self.get_start_end_idx()
                indices = [
                    batch[i]
                    for i in range(
                        self.data_parallel_rank,
                        self._global_batch_size,
                        self.data_parallel_size,
                    )
                ]
                assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
                yield indices
                # yield batch[start_idx:end_idx]
                batch = []

        # Yield the last partial batch if drop_last is not set.
        if len(batch) > 0 and not self.drop_last:
            # start_idx, end_idx = self.get_start_end_idx()
            indices = [batch[i] for i in range(self.data_parallel_rank, len(batch), self.data_parallel_size)]
            if self.pad_samples_to_global_batch_size:
                num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
                indices = indices + [-1] * num_pad
            yield indices
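

# Illustrative sketch (not part of the original module): the round-robin index selection used by
# `MegatronPretrainingBatchSampler.__iter__`. The sizes below are assumptions for demonstration only.
def _example_round_robin_selection():
    """Return the per-rank index lists for two data parallel ranks over a tiny dataset."""
    per_rank_batches = {}
    for rank in range(2):
        sampler = MegatronPretrainingBatchSampler(
            total_samples=8,
            consumed_samples=0,
            micro_batch_size=1,
            global_batch_size=4,
            data_parallel_rank=rank,
            data_parallel_size=2,
            drop_last=True,
        )
        # Each global batch of 4 indices is split with stride data_parallel_size:
        # rank 0 takes positions 0 and 2, rank 1 takes positions 1 and 3.
        per_rank_batches[rank] = list(sampler)
    return per_rank_batches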


class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
    """Batch sampler that shuffles each data parallel rank's shard with a per-epoch seeded permutation."""

    # NOTE (mkozuki): [[Argument of `dataset` and `data_sharding`]]
    # From the commit below, it seems like the `dataset` and `data_sharding` arguments
    # are necessary for ViT training. However, to keep this simple, I omit those two arguments.
    # commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
    #
    # NOTE (degert): I have re-written this class somewhat to report the length correctly when
    # consumed_samples is larger than total_samples, which happens when training for more than one
    # epoch with this sampler. I have also added an explicit seed, which allows us to remove
    # Dataset-side shuffling in NeMo-Aligner.
    #
    # This class does not currently work with pad_samples_to_global_batch_size=True.
    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        global_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool,
        pad_samples_to_global_batch_size: bool = False,
        seed: int = 0,
    ) -> None:
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
            global_batch_size=global_batch_size,
            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
        )
        assert (
            not pad_samples_to_global_batch_size
        ), "`MegatronPretrainingRandomBatchSampler` does not support sample padding"
        if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
            raise RuntimeError(
                "`MegatronPretrainingRandomBatchSampler` does not support drop_last=False when "
                "micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 "
                "if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
        self.last_batch_size = self.total_samples % self._global_batch_size
        self.seed = seed

    def __len__(self) -> int:
        """Length of the random batch sampler.

        .. note::
            When `rampup_batch_size` is enabled, the return value may not be exact.
        """
        active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
        num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
        if self.drop_last:
            return num_available_samples // self.global_batch_size
        else:
            return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each data parallel rank shuffles its own contiguous bucket
        # with a generator seeded by (seed + epoch), so the order is deterministic across restarts.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch will be dropped if it is incomplete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self._global_batch_size_on_this_data_parallel_rank:
                self.consumed_samples += self._global_batch_size
                yield batch
                batch = []

        # Yield the last partial batch if drop_last is not set.
        if len(batch) > 0 and not self.drop_last:
            yield batch
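

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of wiring one of the
# samplers above into a `torch.utils.data.DataLoader`. The toy dataset, the batch sizes, and the
# single-rank setup below are assumptions chosen only for demonstration.
def _example_dataloader_usage():
    """Build a DataLoader whose batch_sampler yields one micro batch of indices per step."""
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        # A trivial in-memory dataset whose sample at position ``idx`` is simply ``idx``.
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            return idx

    dataset = _ToyDataset()
    # Pretend this process is rank 0 of a data parallel group of size 2; in real training these
    # values would come from megatron.core's parallel_state helpers.
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=2,
        data_parallel_rank=0,
        data_parallel_size=2,
    )
    loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
    # Each element produced by the loader is one micro batch of samples for this rank.
    return [batch.tolist() for batch in loader]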