# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import abc
import warnings
from itertools import chain
from typing import Optional, Tuple
import torch
from nemo.utils import logging
from nemo.utils.decorators import experimental
__all__ = [
"MegatronPretrainingBatchSampler",
"MegatronPretrainingRandomBatchSampler",
]
class BaseMegatronSampler:
    """Base class for Megatron-style samplers that shard a stream of dataset indices across data parallel ranks."""
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
rampup_batch_size: Optional[list] = None,
        pad_samples_to_global_batch_size: bool = False,
) -> None:
# Sanity checks.
if total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(total_samples))
if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size must be greater than 0, but got {micro_batch_size}")
if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0, but got {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data parallel size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
)
if global_batch_size is not None and rampup_batch_size is None:
if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
raise RuntimeError(
f"`global_batch_size` ({global_batch_size}) is not divisible by "
f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
f"({data_parallel_size})`"
)
if pad_samples_to_global_batch_size and global_batch_size is None:
raise RuntimeError(
"`pad_samples_to_global_batch_size` can be `True` only when "
"`global_batch_size` is set to an integer value"
)
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
self.global_batch_size = global_batch_size
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
        logging.info(
            f'Instantiating {self.__class__.__name__} with total_samples: {total_samples} '
            f'and consumed_samples: {consumed_samples}'
        )
def __len__(self):
num_available_samples: int = self.total_samples - self.consumed_samples
if self.global_batch_size is not None:
if self.drop_last:
num_global_batches = num_available_samples // self.global_batch_size
else:
num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # Return the length in micro batches so that len(dataloader) matches the number of
            # batches actually fetched (the training step fetches in units of micro batches).
return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
else:
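            # Ceiling division: a final partial group of samples still counts as one batch.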
return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
@abc.abstractmethod
def __iter__(self): ...
class MegatronPretrainingSampler(BaseMegatronSampler):
    """Sampler that yields this data parallel rank's contiguous micro-batch slice of each global batch."""
    def get_start_end_idx(self):
        """Return the [start, end) slice of the current batch owned by this data parallel rank.

        For example, rank 1 with micro_batch_size=4 owns the slice [4, 8).
        """
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
    def _get_padding_indices(self, pad_samples_num):
        """Return padding indices for the last incomplete global batch.

        Negative indices wrap around, so the padding re-reads the final dataset samples.
        """
return range(-1, -pad_samples_num - 1, -1)
def __iter__(self):
batch = []
        # The last incomplete batch is dropped unless drop_last is False
indices = range(self.consumed_samples, self.total_samples)
if (not self.drop_last) and self.pad_samples_to_global_batch_size:
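            # -len(indices) % global_batch_size is the smallest non-negative pad count
            # that brings len(indices) up to a multiple of the global batch size.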
pad_samples_num = -len(indices) % self.global_batch_size
pad_indices = self._get_padding_indices(pad_samples_num)
indices = chain(indices, pad_indices)
for idx in indices:
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
assert (
not self.pad_samples_to_global_batch_size
), 'with pad_samples_to_global_batch_size all batches should be complete'
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
    """Variant of MegatronPretrainingSampler that pads with ``None`` placeholders instead of negative indices."""
    def _get_padding_indices(self, pad_samples_num):
        """Return ``None`` placeholders used to pad the last incomplete global batch."""
return [None] * pad_samples_num
class MegatronPretrainingRandomSampler(BaseMegatronSampler):
    """Sampler that shuffles indices with an epoch-seeded generator before sharding them across data parallel ranks."""
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
        pad_samples_to_global_batch_size: bool = False,
seed: int = 0,
) -> None:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
global_batch_size=global_batch_size,
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
assert (
not pad_samples_to_global_batch_size
), "`MegatronPretrainingRandomSampler` does not support sample padding"
if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
            raise RuntimeError(
                "`MegatronPretrainingRandomSampler` does not support drop_last=False when "
                "micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 "
                "if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
self.seed = seed
def __len__(self):
active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
if self.global_batch_size is not None:
if self.drop_last:
num_global_batches = num_available_samples // self.global_batch_size
else:
num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # Return the length in micro batches so that len(dataloader) matches the number of
            # batches actually fetched (the training step fetches in units of micro batches).
return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
else:
if self.drop_last:
return num_available_samples // self.micro_batch_times_data_parallel_size
else:
            return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
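        # Each data parallel rank owns a contiguous bucket of bucket_size permutation slots;
        # bucket_offset skips the slots this rank has already consumed earlier in the epoch.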
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
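        # All ranks seed with the same (seed + epoch), so each epoch's shuffle is
        # deterministic and reproducible across restarts.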
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
        # The last incomplete batch is dropped unless drop_last is False.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
yield batch
class BaseMegatronBatchSampler:
    """Megatron-style BatchSampler.

    Let mbs, gbs, tp, pp, and dp stand for "micro batch size", "global batch size",
    "tensor model parallel world size", "pipeline model parallel world size", and
    "data parallel world size", respectively. The number of micro batches (hereafter, nmb)
    is defined as :math:`nmb = gbs \\div (mbs \\times dp)`.
    See `apex/transformer/microbatches.py#L91-L98 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L91-L98>`_
    for the initial setting of the number of micro batches and
    `apex/transformer/microbatches.py#L160-L177 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L160-L177>`_
    for the warmup of the global batch size.
    For example, if `(mbs, gbs, tp, pp, dp) = (1, 16, 1, 1, 2)`, then the number of micro batches is
    :math:`gbs \\div (mbs \\times dp) = 16 \\div (1 \\times 2) = 8`.
    In this case, an instance of a Megatron batch sampler on each data parallel rank is expected
    to return :math:`nmb \\times mbs = 8` indices.
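
    A quick arithmetic check of the example above (plain Python, independent of any Megatron state)::

        >>> mbs, gbs, dp = 1, 16, 2
        >>> nmb = gbs // (mbs * dp)
        >>> nmb
        8
        >>> nmb * mbs  # indices yielded per data parallel rank per global batch
        8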
"""
_global_batch_size: int
_num_micro_batches: int
_global_batch_size_on_this_data_parallel_rank: int
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
global_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool,
        pad_samples_to_global_batch_size: bool = False,
    ) -> None:
        """Constructor of Megatron-LM style Batch Sampler.

        Args:
            total_samples: The size of the dataset.
            consumed_samples: The number of samples already consumed.
            micro_batch_size: The size of each micro batch.
            global_batch_size: The size of each global batch.
            data_parallel_rank: The value you can obtain via
                `parallel_state.get_data_parallel_rank()` of megatron.core.
            data_parallel_size: The value you can obtain via
                `parallel_state.get_data_parallel_world_size()` of megatron.core.
            drop_last: Whether to drop the last incomplete global batch.
            pad_samples_to_global_batch_size: If True, pad the last incomplete
                global batch up to `global_batch_size`.
        """
# Sanity checks.
if total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(total_samples))
if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size must be greater than 0, but got {micro_batch_size}")
if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0, but got {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data parallel size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
)
# Keep a copy of input params for later use.
self.total_samples: int = total_samples
self.consumed_samples: int = consumed_samples
self.micro_batch_size: int = micro_batch_size
self.data_parallel_rank: int = data_parallel_rank
self.data_parallel_size: int = data_parallel_size
self.drop_last: bool = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
self.update_global_batch_size(global_batch_size)
def update_global_batch_size(self, new_global_batch_size: int) -> None:
"""Update the global batch size."""
self._global_batch_size = new_global_batch_size
if self._global_batch_size % self.micro_batch_times_data_parallel_size != 0:
raise RuntimeError(
f"`global_batch_size` ({self._global_batch_size}) is not divisible by "
f"`micro_batch_size ({self.micro_batch_size}) x data_parallel_size "
f"({self.data_parallel_size})`"
)
self._num_micro_batches = self._global_batch_size // self.micro_batch_times_data_parallel_size
self._global_batch_size_on_this_data_parallel_rank = self._num_micro_batches * self.micro_batch_size
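        # e.g. gbs=16, mbs=1, dp=2 -> _num_micro_batches=8 and 8 samples per rank per global batch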
@property
    def global_batch_size(self) -> int:
        """Return the current global batch size."""
return self._global_batch_size
@global_batch_size.setter
    def global_batch_size(self, new_global_batch_size: int) -> None:
        """Set the global batch size (prefer calling ``update_global_batch_size`` directly)."""
warnings.warn("`self.update_global_batch_size(new_global_batch_size)` is recommended.")
self.update_global_batch_size(new_global_batch_size=new_global_batch_size)
    def __len__(self) -> int:
        """Length of Batch Sampler.

        .. note::
            When `rampup_batch_size` is enabled, the return value may not be exact.
"""
num_available_samples: int = self.total_samples - self.consumed_samples % self.total_samples
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
@abc.abstractmethod
def __iter__(self): ...
class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
    """Batch sampler that yields this data parallel rank's strided share of each global batch."""
    def get_start_end_idx(self) -> Tuple[int, int]:
        """Return the contiguous [start, end) per-rank slice of a global batch.

        Currently unused; ``__iter__`` uses a strided split instead.
        """
start_idx = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
end_idx = start_idx + self._global_batch_size_on_this_data_parallel_rank
return start_idx, end_idx
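
    # Illustration of the strided split used in __iter__ (example values only):
    # with a global batch [0, 1, 2, 3, 4, 5, 6, 7] and data_parallel_size=2,
    # rank 0 yields [0, 2, 4, 6] and rank 1 yields [1, 3, 5, 7].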
def __iter__(self):
batch = []
        # The last incomplete batch is dropped unless drop_last is False
for idx in range(self.consumed_samples % self.total_samples, self.total_samples):
batch.append(idx)
if len(batch) == self._global_batch_size:
                # Strided split: this rank takes every data_parallel_size-th index of the
                # global batch rather than the contiguous slice from get_start_end_idx().
indices = [
batch[i]
for i in range(
self.data_parallel_rank,
self._global_batch_size,
self.data_parallel_size,
)
]
assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
yield indices
batch = []
        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
indices = [batch[i] for i in range(self.data_parallel_rank, len(batch), self.data_parallel_size)]
if self.pad_samples_to_global_batch_size:
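                # Pad with -1; under standard negative indexing, -1 maps to the last dataset sample.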
num_pad = self._global_batch_size // self.data_parallel_size - len(indices)
indices = indices + [-1] * num_pad
yield indices
@experimental
class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
    """Batch sampler that draws a deterministic per-epoch random permutation and shards it across data parallel ranks."""
# NOTE (mkozuki): [[Argument of `dataset` and `data_sharding`]]
# From the commit below, it seems like `dataset` argument and `data_sharding` argument
# are necessary for ViT training. However, to keep this simple,
# I omit those two arguments.
# commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
#
# NOTE (degert): I have re-written this class somewhat to give the length correctly when consumed_samples
# are larger than total_samples, which happens with epochs > 1 training when using this Sampler
# I have also added an explicit seed which allows us to remove Dataset-side shuffling in Nemo-Aligner
#
# This class does not currently work with pad_samples_to_global_batch_size=True
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
global_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool,
pad_samples_to_global_batch_size: bool = False,
seed: int = 0,
) -> None:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
global_batch_size=global_batch_size,
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
assert (
not pad_samples_to_global_batch_size
), "`MegatronPretrainingRandomBatchSampler` does not support sample padding"
if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
            raise RuntimeError(
                "`MegatronPretrainingRandomBatchSampler` does not support drop_last=False "
                "when micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 "
                "if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
self.last_batch_size = self.total_samples % self._global_batch_size
self.seed = seed
    def __len__(self) -> int:
        """Length of Random Batch Sampler.

        .. note::
            When `rampup_batch_size` is enabled, the return value may not be exact.
"""
active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
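        # Same bucket-based sharding as MegatronPretrainingRandomSampler.__iter__ above: each rank
        # shuffles its own contiguous bucket with the shared (seed + epoch) generator.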
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
        # The last incomplete batch is dropped unless drop_last is False.
for idx in idx_range:
batch.append(idx)
if len(batch) == self._global_batch_size_on_this_data_parallel_rank:
self.consumed_samples += self._global_batch_size
yield batch
batch = []
        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
yield batch