import glob
import logging
import os
import random

import pyarrow.feather as feather
import torch

logger = logging.getLogger(__name__)


class CyclicalBatchDataset:
    """
    Dataset class that loads saved batches from the continuous generation script.

    Maintains a pointer and provides cyclical access to individual samples.
    Includes enhanced logging to track data shard cycling during training.
    Supports per-rank file sharding for large-scale distributed training.
    """

    def __init__(
        self,
        batches_dir: str,
        generator_type: str,
        device: torch.device | None = None,
        prefetch_next: bool = True,
        prefetch_threshold: int = 32,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the cyclical batch dataset.

        Args:
            batches_dir: Directory containing the batch arrow files
            generator_type: Type of generator (for logging)
            device: Device to load tensors to
            prefetch_next: Whether to prefetch the next batch
            prefetch_threshold: Number of remaining samples that triggers prefetching
            rank: Rank of the current process (for file sharding)
            world_size: Total number of processes (for file sharding)
        """
        self.batches_dir = batches_dir
        self.generator_type = generator_type
        self.device = device
        self.prefetch_next = prefetch_next
        self.prefetch_threshold = prefetch_threshold
        self.rank = rank
        self.world_size = world_size

        self.batch_files = self._find_batch_files()
        if not self.batch_files:
            raise ValueError(f"No batch files found in {batches_dir}")

        # Pointers into the current batch and the optional prefetched one.
        self.current_batch_idx = 0
        self.current_sample_idx = 0
        self.current_batch_data = None
        self.next_batch_data = None
        self.prefetching_in_progress = False

        # Cycle tracking: which file indices have been served in the current cycle.
        self.visited_batch_indices = set()
        self.full_cycles_completed = 0

        self._load_current_batch()
        self.visited_batch_indices.add(self.current_batch_idx)

        logger.info(
            f"Initialized '{self.generator_type}' dataset with {len(self.batch_files)} batches. "
            f"Current batch file: '{os.path.basename(self.batch_files[self.current_batch_idx])}' "
            f"has {len(self.current_batch_data)} samples."
        )

    def _find_batch_files(self) -> list[str]:
        """
        Find and sort batch files with per-rank sharding for distributed training.

        Each rank gets a disjoint subset of files to minimize I/O contention
        when scaling to hundreds of GPUs.
        """
        pattern = os.path.join(self.batches_dir, "batch_*.arrow")
        all_files = sorted(glob.glob(pattern))

        if not all_files:
            return []

        # Round-robin assignment: file i goes to rank (i % world_size).
        rank_files = [f for i, f in enumerate(all_files) if i % self.world_size == self.rank]

        # Randomize traversal order within this rank's shard.
        random.shuffle(rank_files)

        logger.info(
            f"[Rank {self.rank}] '{self.generator_type}': Sharded {len(all_files)} files → "
            f"{len(rank_files)} files for this rank ({len(rank_files) / len(all_files) * 100:.1f}%)"
        )

        return rank_files

    def _load_batch_from_file(self, batch_file: str) -> list[dict]:
        """Load a batch from an arrow file."""
        try:
            table = feather.read_table(batch_file)
            has_num_channels = "num_channels" in table.column_names
            batch_data = []
            for i in range(len(table)):
                row = {
                    "series_id": table["series_id"][i].as_py(),
                    "values": table["values"][i].as_py(),
                    "length": table["length"][i].as_py(),
                    "generator_type": table["generator_type"][i].as_py(),
                    "start": table["start"][i].as_py(),
                    "frequency": table["frequency"][i].as_py(),
                    "generation_timestamp": table["generation_timestamp"][i].as_py(),
                }
                if has_num_channels:
                    row["num_channels"] = table["num_channels"][i].as_py()
                else:
                    # Default to a single channel when the column is absent.
                    row["num_channels"] = 1
                batch_data.append(row)
            return batch_data
        except Exception as e:
            logger.error(f"Error loading batch from {batch_file}: {e}")
            raise

    def _load_current_batch(self):
        """Load the current batch into memory."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data
        batch_file = self.batch_files[self.current_batch_idx]
        self.current_batch_data = self._load_batch_from_file(batch_file)
        self.current_sample_idx = 0
        logger.debug(
            f"Loaded batch {self.current_batch_idx} for {self.generator_type} "
            f"with {len(self.current_batch_data)} samples"
        )

    def _trigger_smart_prefetch(self):
        """Trigger prefetching when the current batch is almost exhausted."""
        if not self.prefetch_next or len(self.batch_files) <= 1:
            return
        remaining_samples = self.get_remaining_samples_in_current_batch()
        should_prefetch = (
            remaining_samples <= self.prefetch_threshold
            and self.next_batch_data is None
            and not self.prefetching_in_progress
        )
        if should_prefetch:
            self._prefetch_next_batch()

    def _prefetch_next_batch(self):
        """Prefetch the next batch (loaded synchronously, ahead of need)."""
        if self.prefetching_in_progress:
            return
        self.prefetching_in_progress = True
        next_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)
        next_batch_file = self.batch_files[next_batch_idx]
        try:
            self.next_batch_data = self._load_batch_from_file(next_batch_file)
            logger.debug(f"Prefetched next batch {next_batch_idx} for {self.generator_type}")
        except Exception as e:
            logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
            self.next_batch_data = None
        finally:
            self.prefetching_in_progress = False

    def _advance_to_next_batch(self):
        """Advance to the next batch and log the transition."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data

        previous_batch_idx = self.current_batch_idx
        self.current_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)

        # Use the prefetched batch if one is available; otherwise load from disk.
        if hasattr(self, "next_batch_data") and self.next_batch_data is not None:
            self.current_batch_data = self.next_batch_data
            self.next_batch_data = None
        else:
            self._load_current_batch()

        self.current_sample_idx = 0
        self.prefetching_in_progress = False

        self.visited_batch_indices.add(self.current_batch_idx)

        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        progress_percent = (visited_count / total_files) * 100

        logger.info(
            f"\nDATA SHARD CYCLED for '{self.generator_type}': "
            f"Moved from file index {previous_batch_idx} to {self.current_batch_idx}. "
            f"Unique files visited: {visited_count}/{total_files} ({progress_percent:.1f}%)."
        )

        if visited_count == total_files:
            self.full_cycles_completed += 1
            logger.info(
                f"FULL CYCLE #{self.full_cycles_completed} COMPLETED for '{self.generator_type}'! "
                f"All {total_files} data files have been visited at least once. "
                "Resetting visited set to track the next cycle."
            )
            self.visited_batch_indices.clear()
            self.visited_batch_indices.add(self.current_batch_idx)

    def get_sample(self) -> dict:
        """Get the current sample and advance the pointer."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            self._load_current_batch()
        if self.current_batch_data is None:
            raise RuntimeError("No batch data loaded")
        # Roll over to the next batch file once the current one is exhausted.
        if self.current_sample_idx >= len(self.current_batch_data):
            self._advance_to_next_batch()
        self._trigger_smart_prefetch()
        sample = self.current_batch_data[self.current_sample_idx]
        self.current_sample_idx += 1
        return sample

    def get_samples(self, num_samples: int) -> list[dict]:
        """Get multiple samples."""
        return [self.get_sample() for _ in range(num_samples)]

    def get_total_samples_in_current_batch(self) -> int:
        """Get total samples in the current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return len(self.current_batch_data)

    def get_remaining_samples_in_current_batch(self) -> int:
        """Get remaining samples in the current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return max(0, len(self.current_batch_data) - self.current_sample_idx)

    def get_info(self) -> dict:
        """Get extended dataset info, including cycle progress."""
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        return {
            "generator_type": self.generator_type,
            "total_batch_files": total_files,
            "current_batch_idx": self.current_batch_idx,
            "current_sample_idx": self.current_sample_idx,
            "current_batch_size": self.get_total_samples_in_current_batch(),
            "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
            "unique_files_visited": visited_count,
            "cycle_progress_percent": (visited_count / total_files) * 100 if total_files > 0 else 0,
            "full_cycles_completed": self.full_cycles_completed,
        }
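

# Minimal usage sketch (not part of the original module): shows how one dataset
# instance per rank would be constructed and consumed. The directory path and
# generator name below are hypothetical placeholders, and the rank/world_size
# lookup assumes a torch.distributed process group may have been initialized
# elsewhere; with no process group, it falls back to single-process values.
if __name__ == "__main__":
    import torch.distributed as dist

    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    dataset = CyclicalBatchDataset(
        batches_dir="/data/generated_batches",  # hypothetical path
        generator_type="synthetic_sine",  # hypothetical generator name
        rank=rank,
        world_size=world_size,
    )

    # Draw a few samples; the dataset cycles through its file shard indefinitely.
    for sample in dataset.get_samples(4):
        print(sample["series_id"], sample["length"], sample["num_channels"])
    print(dataset.get_info())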