
Commit 9c9aa38

Merge pull request #35 from curt-tigges/fix/repair-manifest
Fix/repair manifest
2 parents 311c9b8 + d2ede36 · commit 9c9aa38

File tree: 2 files changed (+53 -9 lines changed)

clt/training/data/local_activation_store.py

Lines changed: 52 additions & 8 deletions
@@ -89,12 +89,41 @@ def _load_manifest(self) -> Optional[np.ndarray]:
             logger.error(f"Manifest file not found: {path}")
             return None
         try:
-            with open(path, "rb") as f:
-                data = np.frombuffer(f.read(), dtype=np.uint32).reshape(-1, 2)
-            logger.info(f"Manifest loaded from {path} ({len(data)} rows).")
+            file_size_bytes = path.stat().st_size
+            # Heuristic: older 2-field format is 8 bytes per entry (two uint32),
+            # newer 3-field format is 16 bytes per entry (int32, int32, int64).
+            if file_size_bytes % 16 == 0:
+                # New format with 3 fields (chunk_id, num_tokens, offset)
+                manifest_dtype = np.dtype([("chunk_id", np.int32), ("num_tokens", np.int32), ("offset", np.int64)])
+                data_structured = np.fromfile(path, dtype=manifest_dtype)
+                logger.info(
+                    f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} chunks). Expanding to per-row entries."
+                )
+                # Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
+                chunk_ids = data_structured["chunk_id"].astype(np.uint32)
+                num_tokens_arr = data_structured["num_tokens"].astype(np.uint32)
+                # Compute total rows
+                total_rows = int(num_tokens_arr.sum())
+                logger.info(f"Expanding manifest: total rows = {total_rows}")
+                # Pre-allocate array
+                data = np.empty((total_rows, 2), dtype=np.uint32)
+                row_ptr = 0
+                for cid, ntok in zip(chunk_ids, num_tokens_arr):
+                    data[row_ptr : row_ptr + ntok, 0] = cid  # chunk_id column
+                    data[row_ptr : row_ptr + ntok, 1] = np.arange(ntok, dtype=np.uint32)  # row index within chunk
+                    row_ptr += ntok
+            elif file_size_bytes % 8 == 0:
+                # Legacy 2-field format already matches expected shape
+                data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2)
+                logger.info(f"Manifest loaded (legacy 2-field format) from {path} ({data.shape[0]} rows).")
+            else:
+                logger.error(
+                    f"Manifest file size ({file_size_bytes} bytes) is not compatible with known formats (8 or 16 bytes per row)."
+                )
+                return None
             return data
         except ValueError as e:
-            logger.error(f"Error reshaping manifest data from {path} (expected Nx2): {e}")
+            logger.error(f"Error parsing manifest data from {path}: {e}")
             return None
         except OSError as e:
             logger.error(f"Error reading manifest file {path}: {e}")
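Note on the expansion loop above: the same (chunk_id, row_in_chunk) table can also be built without a Python-level loop. A minimal vectorized sketch, assuming the same `chunk_ids` and `num_tokens_arr` arrays as in the diff (`expand_manifest` is a hypothetical helper, not code from this PR):

```python
import numpy as np

def expand_manifest(chunk_ids: np.ndarray, num_tokens_arr: np.ndarray) -> np.ndarray:
    """Expand (chunk_id, num_tokens) pairs into per-row (chunk_id, row_in_chunk) entries."""
    total_rows = int(num_tokens_arr.sum())
    data = np.empty((total_rows, 2), dtype=np.uint32)
    # Repeat each chunk_id once per token it contains.
    data[:, 0] = np.repeat(chunk_ids, num_tokens_arr)
    # Per-chunk 0..n-1 counters: a global arange minus each chunk's start offset.
    counts = num_tokens_arr.astype(np.int64)
    starts = np.repeat(np.cumsum(counts) - counts, counts)
    data[:, 1] = np.arange(total_rows, dtype=np.int64) - starts
    return data
```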
@@ -117,7 +146,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]:
             logger.error(f"Error reading norm_stats file {path}: {e}")
             return None
 
-    @lru_cache(maxsize=256)
+    @lru_cache(maxsize=64)
     def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
         """Loads entire HDF5 chunk from disk and caches"""
 
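On the cache-size change: `functools.lru_cache` keys `_load_chunk` on `(self, chunk_path, layer_key, data_type)`, so `maxsize=64` bounds how many distinct per-layer datasets stay resident at once, trading some hit rate for memory. A hedged inspection sketch (`store` stands for a hypothetical instance of the store class in this file):

```python
# `store` is a hypothetical instance; attribute access on the bound method
# forwards to the lru_cache wrapper, so these calls work on instances.
info = store._load_chunk.cache_info()
print(info.hits, info.misses, info.currsize, info.maxsize)  # maxsize == 64 here
store._load_chunk.cache_clear()  # evict all cached chunks if memory is tight
```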

@@ -129,14 +158,29 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
             logger.error(f"Chunk file not found for fetch: {chunk_path}")
             raise
         except KeyError as e:
-            raise RuntimeError(f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}") from e
+            raise RuntimeError(
+                f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}"
+            ) from e
         except Exception as e:
             logger.error(f"Failed to open chunk at {chunk_path}: {e}")
             raise RuntimeError(f"Failed to access chunk HDF5 file: {chunk_path}") from e
 
     def _fetch_slice(self, chunk_id: int, row_indices: np.ndarray) -> bytes:
 
         chunk_path = self.dataset_path / f"chunk_{chunk_id}.h5"
+        if not chunk_path.exists():
+            # Fall back to .hdf5 extension (newer generator default)
+            alt_path = self.dataset_path / f"chunk_{chunk_id}.hdf5"
+            if alt_path.exists():
+                chunk_path = alt_path
+            else:
+                # Provide clearer error message before _open_h5 raises
+                logger.error(
+                    "Chunk file for chunk_id %d not found with either .h5 or .hdf5 extension in %s",
+                    chunk_id,
+                    self.dataset_path,
+                )
+
         hf = _open_h5(chunk_path)
 
         try:
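The `.h5`/`.hdf5` fallback could also be factored into a small path resolver. A sketch with a hypothetical name, mirroring the diff's behavior of still handing the original `.h5` path to `_open_h5` when neither file exists:

```python
from pathlib import Path

def resolve_chunk_path(dataset_path: Path, chunk_id: int) -> Path:
    """Prefer chunk_<id>.h5; fall back to chunk_<id>.hdf5 (newer generator default)."""
    for ext in (".h5", ".hdf5"):
        candidate = dataset_path / f"chunk_{chunk_id}{ext}"
        if candidate.exists():
            return candidate
    # Neither exists: return the .h5 path so the caller's open fails loudly,
    # matching the diff, which logs an error and then lets _open_h5 raise.
    return dataset_path / f"chunk_{chunk_id}.h5"
```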
@@ -162,8 +206,8 @@ def _layer_sort_key(name: str) -> int:
         row_indices_h5 = row_indices
 
         for i, lk in enumerate(layer_keys):
-            input_data = self._load_chunk(chunk_path, lk, 'inputs')[row_indices_h5, :]
-            target_data = self._load_chunk(chunk_path, lk, 'targets')[row_indices_h5, :]
+            input_data = self._load_chunk(chunk_path, lk, "inputs")[row_indices_h5, :]
+            target_data = self._load_chunk(chunk_path, lk, "targets")[row_indices_h5, :]
             bufs.append(input_data.tobytes())
             bufs.append(target_data.tobytes())
         return b"".join(bufs)
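For orientation, `_fetch_slice` returns the selected rows as one flat byte string: for each layer in order, the inputs buffer followed by the targets buffer. A consumer-side sketch of splitting that buffer back into arrays, assuming every array shares one dtype and feature width (illustrative assumptions; `split_slice` is hypothetical):

```python
import numpy as np

def split_slice(buf: bytes, n_layers: int, n_rows: int, d_model: int, dtype=np.float32):
    """Inverse of the concatenation in _fetch_slice, under the stated assumptions."""
    itemsize = np.dtype(dtype).itemsize
    block = n_rows * d_model * itemsize  # bytes per (n_rows x d_model) array
    pairs = []
    for layer in range(n_layers):
        off = 2 * layer * block
        inputs = np.frombuffer(buf, dtype=dtype, count=n_rows * d_model, offset=off)
        targets = np.frombuffer(buf, dtype=dtype, count=n_rows * d_model, offset=off + block)
        pairs.append((inputs.reshape(n_rows, d_model), targets.reshape(n_rows, d_model)))
    return pairs
```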

clt/training/metric_utils.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def log_training_step(
         self.metrics["train_losses"].append({"step": step, **loss_dict})
 
         if not self.distributed or self.rank == 0:
-            total_tokens_processed = self.training_config.train_batch_size_tokens * self.world_size * (step + 1)
+            total_tokens_processed = self.training_config.train_batch_size_tokens * (step + 1)
 
             self.wandb_logger.log_step(
                 step,
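The `metric_utils.py` fix drops the `world_size` factor from the logged token count, which suggests `train_batch_size_tokens` is already the global per-step token count rather than a per-rank one; multiplying by `world_size` therefore overcounted by the number of ranks. Illustrative arithmetic under that assumption:

```python
train_batch_size_tokens = 4096  # assumed global tokens per optimizer step
world_size = 8
step = 99
old = train_batch_size_tokens * world_size * (step + 1)  # 3,276,800 (8x too high)
new = train_batch_size_tokens * (step + 1)               # 409,600
```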
