import glob
import logging
import os
import random

import pyarrow.feather as feather
import torch

logger = logging.getLogger(__name__)


class CyclicalBatchDataset:
    """
    Dataset class that loads batches saved by the continuous generation script.
    Maintains a read pointer and provides cyclical access to individual samples.
    Includes enhanced logging to track data-shard cycling during training.
    Supports per-rank file sharding for large-scale distributed training.
    """

    def __init__(
        self,
        batches_dir: str,
        generator_type: str,
        device: torch.device | None = None,
        prefetch_next: bool = True,
        prefetch_threshold: int = 32,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the cyclical batch dataset.

        Args:
            batches_dir: Directory containing the batch arrow files
            generator_type: Type of generator (for logging)
            device: Device to load tensors to
            prefetch_next: Whether to prefetch the next batch
            prefetch_threshold: Remaining-sample count at or below which the next batch is prefetched
            rank: Rank of the current process (for file sharding)
            world_size: Total number of processes (for file sharding)
        """
        self.batches_dir = batches_dir
        self.generator_type = generator_type
        self.device = device
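        # Note: `device` is stored for callers but not used inside this class;
        # samples are returned as plain Python dicts (see _load_batch_from_file).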
        self.prefetch_next = prefetch_next
        self.prefetch_threshold = prefetch_threshold
        self.rank = rank
        self.world_size = world_size

        self.batch_files = self._find_batch_files()
        if not self.batch_files:
            raise ValueError(
                f"No batch files found in {batches_dir} for rank {self.rank} "
                f"(world_size={self.world_size})"
            )

        # --- State tracking ---
        self.current_batch_idx = 0
        self.current_sample_idx = 0
        self.current_batch_data = None
        self.next_batch_data = None
        self.prefetching_in_progress = False

        # --- Logging and cycle tracking ---
        self.visited_batch_indices = set()
        self.full_cycles_completed = 0

        # Load first batch and update tracking
        self._load_current_batch()
        self.visited_batch_indices.add(self.current_batch_idx)

        logger.info(
            f"Initialized '{self.generator_type}' dataset with {len(self.batch_files)} batches. "
            f"Current batch file: '{os.path.basename(self.batch_files[self.current_batch_idx])}' "
            f"has {len(self.current_batch_data)} samples."
        )

    def _find_batch_files(self) -> list[str]:
        """
        Find and sort batch files with per-rank sharding for distributed training.

        Each rank gets a disjoint subset of files to minimize I/O contention
        when scaling to hundreds of GPUs.
        """
        pattern = os.path.join(self.batches_dir, "batch_*.arrow")
        all_files = sorted(glob.glob(pattern))  # Sort for deterministic sharding

        if not all_files:
            return []

        # Shard files across ranks: each rank gets every world_size-th file
        # Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
        rank_files = [f for i, f in enumerate(all_files) if i % self.world_size == self.rank]

        # Shuffle only within this rank's shard for variety
        random.shuffle(rank_files)

        logger.info(
            f"[Rank {self.rank}] '{self.generator_type}': Sharded {len(all_files)} files β†’ "
            f"{len(rank_files)} files for this rank ({len(rank_files) / len(all_files) * 100:.1f}%)"
        )

        return rank_files

    def _load_batch_from_file(self, batch_file: str) -> list[dict]:
        """Load a batch from arrow file."""
        try:
            table = feather.read_table(batch_file)
            has_num_channels = "num_channels" in table.column_names
            batch_data = []
            for i in range(len(table)):
                row = {
                    "series_id": table["series_id"][i].as_py(),
                    "values": table["values"][i].as_py(),
                    "length": table["length"][i].as_py(),
                    "generator_type": table["generator_type"][i].as_py(),
                    "start": table["start"][i].as_py(),
                    "frequency": table["frequency"][i].as_py(),
                    "generation_timestamp": table["generation_timestamp"][i].as_py(),
                }
                if has_num_channels:
                    row["num_channels"] = table["num_channels"][i].as_py()
                else:
                    row["num_channels"] = 1
                batch_data.append(row)
            return batch_data
        except Exception as e:
            logger.error(f"Error loading batch from {batch_file}: {e}")
            raise

    def _load_current_batch(self):
        """Load the current batch into memory."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data
        batch_file = self.batch_files[self.current_batch_idx]
        self.current_batch_data = self._load_batch_from_file(batch_file)
        self.current_sample_idx = 0
        logger.debug(
            f"Loaded batch {self.current_batch_idx} for {self.generator_type} "
            f"with {len(self.current_batch_data)} samples"
        )

    def _trigger_smart_prefetch(self):
        """Trigger prefetching when batch is almost exhausted."""
        if not self.prefetch_next or len(self.batch_files) <= 1:
            return
        remaining_samples = self.get_remaining_samples_in_current_batch()
        should_prefetch = (
            remaining_samples <= self.prefetch_threshold
            and self.next_batch_data is None
            and not self.prefetching_in_progress
        )
        if should_prefetch:
            self._prefetch_next_batch()

    def _prefetch_next_batch(self):
        """Prefetch the next batch."""
        if self.prefetching_in_progress:
            return
        self.prefetching_in_progress = True
        next_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)
        next_batch_file = self.batch_files[next_batch_idx]
        try:
            self.next_batch_data = self._load_batch_from_file(next_batch_file)
            logger.debug(f"Prefetched next batch {next_batch_idx} for {self.generator_type}")
        except Exception as e:
            logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
            self.next_batch_data = None
        finally:
            self.prefetching_in_progress = False

    def _advance_to_next_batch(self):
        """Advance to the next batch and log the transition."""
        if hasattr(self, "current_batch_data") and self.current_batch_data is not None:
            del self.current_batch_data

        previous_batch_idx = self.current_batch_idx
        self.current_batch_idx = (self.current_batch_idx + 1) % len(self.batch_files)

        if hasattr(self, "next_batch_data") and self.next_batch_data is not None:
            self.current_batch_data = self.next_batch_data
            self.next_batch_data = None
        else:
            self._load_current_batch()

        self.current_sample_idx = 0
        self.prefetching_in_progress = False

        # --- Enhanced logging: track shard cycling ---
        self.visited_batch_indices.add(self.current_batch_idx)

        # Calculate progress
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        progress_percent = (visited_count / total_files) * 100

        # Log the shard cycle event
        logger.info(
            f"\nDATA SHARD CYCLED for '{self.generator_type}': "
            f"Moved from file index {previous_batch_idx} to {self.current_batch_idx}. "
            f"Unique files visited: {visited_count}/{total_files} ({progress_percent:.1f}%)."
        )

        # Check if a full cycle has been completed
        if visited_count == total_files:
            self.full_cycles_completed += 1
            logger.info(
                f"πŸŽ‰ FULL CYCLE #{self.full_cycles_completed} COMPLETED for '{self.generator_type}'! "
                f"All {total_files} data files have been visited at least once. "
                "Resetting visited set to track the next cycle."
            )
            # Reset for the next cycle count
            self.visited_batch_indices.clear()
            self.visited_batch_indices.add(self.current_batch_idx)

    def get_sample(self) -> dict:
        """Get the current sample and advance pointer."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            self._load_current_batch()
        if self.current_batch_data is None:
            raise RuntimeError("No batch data loaded")
        if self.current_sample_idx >= len(self.current_batch_data):
            self._advance_to_next_batch()
        self._trigger_smart_prefetch()
        sample = self.current_batch_data[self.current_sample_idx]
        self.current_sample_idx += 1
        return sample

    def get_samples(self, num_samples: int) -> list[dict]:
        """Get multiple samples."""
        samples = []
        for _ in range(num_samples):
            samples.append(self.get_sample())
        return samples

    def get_total_samples_in_current_batch(self) -> int:
        """Get total samples in current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return len(self.current_batch_data)

    def get_remaining_samples_in_current_batch(self) -> int:
        """Get remaining samples in current batch."""
        if not hasattr(self, "current_batch_data") or self.current_batch_data is None:
            return 0
        return max(0, len(self.current_batch_data) - self.current_sample_idx)

    def get_info(self) -> dict:
        """Get extended dataset info, including cycle progress."""
        total_files = len(self.batch_files)
        visited_count = len(self.visited_batch_indices)
        return {
            "generator_type": self.generator_type,
            "total_batch_files": total_files,
            "current_batch_idx": self.current_batch_idx,
            "current_sample_idx": self.current_sample_idx,
            "current_batch_size": self.get_total_samples_in_current_batch(),
            "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
            "unique_files_visited": visited_count,
            "cycle_progress_percent": (visited_count / total_files) * 100 if total_files > 0 else 0,
            "full_cycles_completed": self.full_cycles_completed,
        }
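

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions (not part of the original module): the
    # BATCHES_DIR default path and the "sinusoidal" generator_type are
    # hypothetical placeholders, and RANK/WORLD_SIZE follow the usual torchrun
    # environment variables.
    logging.basicConfig(level=logging.INFO)
    dataset = CyclicalBatchDataset(
        batches_dir=os.environ.get("BATCHES_DIR", "./generated_batches"),
        generator_type="sinusoidal",
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )
    batch = dataset.get_samples(8)  # draw a handful of samples cyclically
    logger.info("Fetched %d samples; dataset state: %s", len(batch), dataset.get_info())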