import logging
import os
import random

import pyarrow.feather as feather
import torch

logger = logging.getLogger(__name__)


class CyclicalBatchDataset:
    """
    Dataset class that loads saved batches from continuous generation script.
    Maintains a pointer and provides cyclical access to individual samples.
    Includes enhanced logging to track data shard cycling during training.
    Supports per-rank file sharding for large-scale distributed training.
    """

    def __init__(
        self,
        batches_dir: str,
        generator_type: str,
        device: torch.device | None = None,
        prefetch_next: bool = True,
        prefetch_threshold: int = 32,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the cyclical batch dataset.

        Args:
            batches_dir: Directory containing the batch arrow files
            generator_type: Type of generator (for logging)
            device: Device to load tensors to
            prefetch_next: Whether to prefetch the next batch
            prefetch_threshold: Number of remaining samples to trigger prefetching
            rank: Rank of the current process (for file sharding)
            world_size: Total number of processes (for file sharding)
        """
        self.batches_dir = batches_dir
        self.generator_type = generator_type
        self.device = device
        self.prefetch_next = prefetch_next
        self.prefetch_threshold = prefetch_threshold
        self.rank = rank
        self.world_size = world_size

        self.batch_files = self._find_batch_files()
        if not self.batch_files:
            raise ValueError(f"No batch files found in {batches_dir}")

        # --- State tracking ---
        self.current_batch_idx = 0
        self.current_sample_idx = 0
        self.current_batch_data = None
        self.next_batch_data = None
        self.prefetching_in_progress = False

        # --- NEW: Logging and cycle tracking ---
        self.visited_batch_indices = set()
        self.full_cycles_completed = 0

        # Load first batch and update tracking
        self._load_current_batch()
        self.visited_batch_indices.add(self.current_batch_idx)

        logger.info(
            f"Initialized '{self.generator_type}' dataset with {len(self.batch_files)} batches. "
            f"Current batch file: '{os.path.basename(self.batch_files[self.current_batch_idx])}' "
            f"has {len(self.current_batch_data)} samples."
        )

    def _find_batch_files(self) -> list[str]:
        """
        Find and sort batch files with per-rank sharding for distributed training.

        Each rank gets a disjoint subset of files to minimize I/O contention
        when scaling to hundreds of GPUs.
        """
        import glob

        pattern = os.path.join(self.batches_dir, "batch_*.arrow")
        all_files = sorted(glob.glob(pattern))  # Sort for deterministic sharding
        if not all_files:
            return []

        # Shard files across ranks: each rank gets every world_size-th file
        # Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
        rank_files = [f for i, f in enumerate(all_files) if i % self.world_size == self.rank]

        # Shuffle only within this rank's shard for variety
        random.shuffle(rank_files)

        logger.info(
            f"[Rank {self.rank}] '{self.generator_type}': Sharded {len(all_files)} files → "
            f"{len(rank_files)} files for this rank ({len(rank_files) / len(all_files) * 100:.1f}%)"
        )
        return rank_files

    def _load_batch_from_file(self, batch_file: str) -> list[dict]:
        """Load a batch from arrow file."""
        try:
            table = feather.read_table(batch_file)
            has_num_channels = "num_channels" in table.column_names

            batch_data = []
            for i in range(len(table)):
                row = {
                    "series_id": table["series_id"][i].as_py(),
                    "values": table["values"][i].as_py(),
                    "length": table["length"][i].as_py(),
                    "generator_type": table["generator_type"][i].as_py(),
                    "start": table["start"][i].as_py(),
                    "frequency": table["frequency"][i].as_py(),
                    "generation_timestamp": table["generation_timestamp"][i].as_py(),
                }
                if has_num_channels:
                    row["num_channels"] = table["num_channels"][i].as_py()
                else:
                    row["num_channels"] = 1
                batch_data.append(row)
            return batch_data
        except Exception as e:
            logger.error(f"Error loading batch from {batch_file}: {e}")
            raise

    def _load_current_batch(self):
        """Load the current batch into memory."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data

        batch_file = self.batch_files[self.current_batch_idx]
        self.current_batch_data = self._load_batch_from_file(batch_file)
        self.current_sample_idx = 0
        logger.debug(
            f"Loaded batch {self.current_batch_idx} for {self.generator_type} "
            f"with {len(self.current_batch_data)} samples"
        )

    def _trigger_smart_prefetch(self):
        """Trigger prefetching when batch is almost exhausted."""
        if not self.prefetch_next or len(self.batch_files) <= 1:
            return

        remaining_samples = self.get_remaining_samples_in_current_batch()
        should_prefetch = (
            remaining_samples <= self.prefetch_threshold
            and self.next_batch_data is None
            and not self.prefetching_in_progress
        )
        if should_prefetch:
            self._prefetch_next_batch()

    def _prefetch_next_batch(self):
        """Prefetch the next batch."""
        if self.prefetching_in_progress:
            return

        self.prefetching_in_progress = True
        next_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)
        next_batch_file = self.batch_files[next_batch_idx]
        try:
            self.next_batch_data = self._load_batch_from_file(next_batch_file)
            logger.debug(f"Prefetched next batch {next_batch_idx} for {self.generator_type}")
        except Exception as e:
            logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
            self.next_batch_data = None
        finally:
            self.prefetching_in_progress = False

    def _advance_to_next_batch(self):
        """Advance to the next batch and log the transition."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data

        previous_batch_idx = self.current_batch_idx
        self.current_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)

        if hasattr(self, "next_batch_data") and self.next_batch_data is not None:
            self.current_batch_data = self.next_batch_data
            self.next_batch_data = None
        else:
            self._load_current_batch()

        self.current_sample_idx = 0
        self.prefetching_in_progress = False

        # --- NEW: Enhanced Logging Logic ---
        self.visited_batch_indices.add(self.current_batch_idx)

        # Calculate progress
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        progress_percent = (visited_count / total_files) * 100

        # Log the shard cycle event
        logger.info(
            f"\nDATA SHARD CYCLED for '{self.generator_type}': "
            f"Moved from file index {previous_batch_idx} to {self.current_batch_idx}. "
            f"Unique files visited: {visited_count}/{total_files} ({progress_percent:.1f}%)."
        )

        # Check if a full cycle has been completed
        if visited_count == total_files:
            self.full_cycles_completed += 1
            logger.info(
                f"FULL CYCLE #{self.full_cycles_completed} COMPLETED for '{self.generator_type}'! "
                f"All {total_files} data files have been visited at least once. "
                "Resetting visited set to track the next cycle."
            )
            # Reset for the next cycle count
            self.visited_batch_indices.clear()
            self.visited_batch_indices.add(self.current_batch_idx)

    def get_sample(self) -> dict:
        """Get the current sample and advance pointer."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            self._load_current_batch()
        if self.current_batch_data is None:
            raise RuntimeError("No batch data loaded")

        if self.current_sample_idx >= len(self.current_batch_data):
            self._advance_to_next_batch()

        self._trigger_smart_prefetch()

        sample = self.current_batch_data[self.current_sample_idx]
        self.current_sample_idx += 1
        return sample

    def get_samples(self, num_samples: int) -> list[dict]:
        """Get multiple samples."""
        samples = []
        for _ in range(num_samples):
            samples.append(self.get_sample())
        return samples

    def get_total_samples_in_current_batch(self) -> int:
        """Get total samples in current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return len(self.current_batch_data)

    def get_remaining_samples_in_current_batch(self) -> int:
        """Get remaining samples in current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return max(0, len(self.current_batch_data) - self.current_sample_idx)

    def get_info(self) -> dict:
        """Get extended dataset info, including cycle progress."""
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        return {
            "generator_type": self.generator_type,
            "total_batch_files": total_files,
            "current_batch_idx": self.current_batch_idx,
            "current_sample_idx": self.current_sample_idx,
            "current_batch_size": self.get_total_samples_in_current_batch(),
            "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
            "unique_files_visited": visited_count,
            "cycle_progress_percent": (visited_count / total_files) * 100 if total_files > 0 else 0,
            "full_cycles_completed": self.full_cycles_completed,
        }


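# --- Usage sketch (illustrative only) ---
# A minimal example of how this dataset is meant to be driven. The directory path,
# generator type, and rank/world_size values below are hypothetical placeholders;
# in real training they would come from the launcher / distributed setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    dataset = CyclicalBatchDataset(
        batches_dir="/path/to/generated_batches",  # hypothetical directory of batch_*.arrow files
        generator_type="sine_wave",                # hypothetical generator name, used only for logging
        rank=0,
        world_size=1,
    )

    # Pull a handful of samples; the dataset advances its internal pointer and
    # cycles through its shard of files automatically.
    for sample in dataset.get_samples(4):
        print(sample["series_id"], sample["length"], sample["num_channels"])

    print(dataset.get_info())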