# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import os
import subprocess
import time
import numpy as np
import torch
from nemo.utils import logging
from nemo.utils.app_state import AppState
class BlendableDataset(torch.utils.data.Dataset):
""" """
def __init__(self, datasets, weights, size):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = size
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
        # Build indices.
start_time = time.time()
        # Dataset indices are stored as uint8 below, which bounds the number of datasets.
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
app_state = AppState()
        try:
            # Compile the C++ helpers once per node (local_rank == 0), then
            # synchronize so every rank sees the compiled module before importing it.
            if app_state.local_rank == 0:
                compile_helper()
            torch.distributed.barrier()
            from nemo.collections.common.data import helpers
        except ImportError:
            raise ImportError(
                'Could not compile megatron dataset C++ helper functions and therefore '
                'cannot import the helpers python file.'
            )
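
        # For reference, ``build_blending_indices`` greedily assigns each blended
        # sample to the dataset whose realized sample count most lags its target
        # weight. A rough pure-Python sketch of that scheme (our reconstruction,
        # not the exact C++ source):
        #
        #   current_samples = np.zeros(num_datasets, dtype=np.int64)
        #   for i in range(size):
        #       errors = weights * max(i, 1) - current_samples
        #       d = int(np.argmax(errors))
        #       dataset_index[i] = d
        #       dataset_sample_index[i] = current_samples[d]
        #       current_samples[d] += 1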
helpers.build_blending_indices(
self.dataset_index,
self.dataset_sample_index,
weights,
num_datasets,
self.size,
torch.distributed.get_rank() == 0,
)
        logging.info(
            f'> elapsed time for building blendable dataset indices: {time.time() - start_time:.2f} (sec)'
        )

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
dataset_size = len(self.datasets[dataset_idx])
        # Wrap around instead of failing if the precomputed sample index
        # exceeds the current dataset size, reusing existing examples.
        if sample_idx >= dataset_size:
            sample_idx = sample_idx % dataset_size
            logging.warning(
                f"Index out of bounds for dataset {dataset_idx} (size {dataset_size}); "
                f"reusing sample index {sample_idx}."
            )
return self.datasets[dataset_idx][sample_idx]

    def create_data_mmap(self):
        """Create memory maps for all underlying datasets."""
for dataset in self.datasets:
dataset.create_data_mmap()
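
# Example usage (illustrative only; ``ds_a`` and ``ds_b`` stand for any map-style
# datasets, and torch.distributed must be initialized since the constructor
# synchronizes ranks):
#
#   blend = BlendableDataset(datasets=[ds_a, ds_b], weights=[0.7, 0.3], size=10000)
#   sample = blend[5]  # served from ds_a or ds_b according to the blended indices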


class MemoryEfficientBlendableDataset(torch.utils.data.Dataset):
    """
    A BlendableDataset implementation that uses less memory than the original implementation.
    Indices are computed algorithmically instead of being stored in memory, with weights
    quantized into ``weight_bins`` bins, so blending accuracy depends on the bin count.

    To test, call: MemoryEfficientBlendableDataset.test_index_blending()
    """

    def __init__(self, datasets, weights, size, weight_bins=100):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
weight_bins = min(weight_bins, size)
self.size = size
self.weight_bins = weight_bins
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
assert (weights > 0.0).all()
sum_weights = np.sum(weights)
assert sum_weights > 0.0
self.weights = weights / sum_weights
        # Create the per-bin dataset index based on the weights.
        ds_index = []
        ds_bias = []
        for i, w in enumerate(self.weights):
            n = int(w * weight_bins)
            ds_index.extend([i] * n)
            ds_bias.extend(range(n))
        # Pad with the last dataset so both arrays have exactly weight_bins entries
        # (the rounding in ``int(w * weight_bins)`` can leave them short).
        n = weight_bins - len(ds_index)
        ds_index.extend([i] * n)
        ds_bias.extend(range(ds_bias[-1], ds_bias[-1] + n))
self.ds_index = np.array(ds_index, dtype=np.uint32)
self.ds_index_size = np.array([(self.ds_index == i).sum() for i in range(num_datasets)], dtype=np.uint32)
assert (self.ds_index_size > 0).all(), (
"Some datasets have no samples in the blendable dataset, "
"increase weight_bins or the offending weight. "
f"ds_index_size = {self.ds_index_size}"
)
self.ds_bias = np.array(ds_bias, dtype=np.uint32)
self.ds_size = np.array([len(ds) for ds in datasets], dtype=np.uint32)

    def get_ds_sample_idx(self, idx):
        """Return the dataset index and the sample index (within that dataset)
        for the given index into the blendable dataset."""
        bin_idx = idx % self.weight_bins
        ds_idx = self.ds_index[bin_idx]
        stride = self.ds_index_size[ds_idx]
        sample_idx = (self.ds_bias[bin_idx] + (idx // self.weight_bins) * stride) % self.ds_size[ds_idx]
        return ds_idx, sample_idx
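
    # Worked example (our illustration, not from the original source): with
    # weights [0.5, 0.3, 0.2] and weight_bins=10, the constructor builds
    # ds_index = [0,0,0,0,0,1,1,1,2,2], ds_bias = [0,1,2,3,4,0,1,2,0,1], and
    # ds_index_size = [5, 3, 2]. For idx=13: bin_idx = 13 % 10 = 3, so ds_idx = 0
    # and sample_idx = (3 + (13 // 10) * 5) % len(ds0) = 8 % len(ds0).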

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
ds_idx, sample_idx = self.get_ds_sample_idx(idx)
return self.datasets[ds_idx][sample_idx]

    @classmethod
    def test_index_blending(cls):
        """Visualize the indices of the blended dataset."""
        import matplotlib.pyplot as plt

        plt.ion()

        class DS(torch.utils.data.Dataset):
            """Toy map-style dataset used to visualize index blending."""

            def __init__(self, size, data):
                self.size = size
                self.data = data

            def __len__(self):
                return self.size

            def __getitem__(self, idx):
                return self.data[idx]

        for weight_bins in [10, 100]:
            blend_ds = MemoryEfficientBlendableDataset(
                [DS(10, "a"), DS(10, "b"), DS(10, "c")], [0.5, 0.3, 0.2], 50, weight_bins=weight_bins
            )
            ds_sample_idx_list = [blend_ds.get_ds_sample_idx(i) for i in range(50)]
            ds_list, sample_list = zip(*ds_sample_idx_list)
plt.figure()
plt.plot(ds_list, label="ds idx")
plt.plot(sample_list, label="sample")
plt.legend()
plt.grid()
plt.title(f"weight_bins={weight_bins}")
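
# Example usage (illustrative; ``ds_a`` and ``ds_b`` stand for any map-style datasets):
#
#   blend = MemoryEfficientBlendableDataset([ds_a, ds_b], weights=[0.6, 0.4], size=1000)
#   ds_idx, sample_idx = blend.get_ds_sample_idx(7)  # computed on the fly, no index arrays
#   item = blend[7]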


def compile_helper():
    """Compile the C++ helper functions at runtime. Make sure this
    is invoked by a single process per node."""
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        logging.error("Making C++ dataset helpers module failed, exiting.")
        sys.exit(1)