60 changes: 52 additions & 8 deletions clt/training/data/local_activation_store.py
@@ -89,12 +89,41 @@ def _load_manifest(self) -> Optional[np.ndarray]:
logger.error(f"Manifest file not found: {path}")
return None
try:
with open(path, "rb") as f:
data = np.frombuffer(f.read(), dtype=np.uint32).reshape(-1, 2)
logger.info(f"Manifest loaded from {path} ({len(data)} rows).")
file_size_bytes = path.stat().st_size
# Heuristic: older 2-field format is 8 bytes per entry (two uint32),
# newer 3-field format is 16 bytes per entry (int32, int32, int64).
if file_size_bytes % 16 == 0:
# New format with 3 fields (chunk_id, num_tokens, offset)
manifest_dtype = np.dtype([("chunk_id", np.int32), ("num_tokens", np.int32), ("offset", np.int64)])
data_structured = np.fromfile(path, dtype=manifest_dtype)
logger.info(
f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} chunks). Expanding to per-row entries."
)
# Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
chunk_ids = data_structured["chunk_id"].astype(np.uint32)
num_tokens_arr = data_structured["num_tokens"].astype(np.uint32)
# Compute total rows
total_rows = int(num_tokens_arr.sum())
logger.info(f"Expanding manifest: total rows = {total_rows}")
# Pre-allocate array
data = np.empty((total_rows, 2), dtype=np.uint32)
row_ptr = 0
for cid, ntok in zip(chunk_ids, num_tokens_arr):
data[row_ptr : row_ptr + ntok, 0] = cid # chunk_id column
data[row_ptr : row_ptr + ntok, 1] = np.arange(ntok, dtype=np.uint32) # row index within chunk
row_ptr += ntok
elif file_size_bytes % 8 == 0:
# Legacy 2-field format already matches expected shape
data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2)
logger.info(f"Manifest loaded (legacy 2-field format) from {path} ({data.shape[0]} rows).")
else:
logger.error(
f"Manifest file size ({file_size_bytes} bytes) is not compatible with known formats (8 or 16 bytes per row)."
)
return None
return data
except ValueError as e:
logger.error(f"Error reshaping manifest data from {path} (expected Nx2): {e}")
logger.error(f"Error parsing manifest data from {path}: {e}")
return None
except OSError as e:
logger.error(f"Error reading manifest file {path}: {e}")
@@ -117,7 +146,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]:
logger.error(f"Error reading norm_stats file {path}: {e}")
return None

@lru_cache(maxsize=256)
@lru_cache(maxsize=64)
def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
"""Loads entire HDF5 chunk from disk and caches"""

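Side note on the `lru_cache` change: each cache entry appears to hold one full per-chunk dataset for a (chunk_path, layer_key, data_type) key, so `maxsize` bounds resident memory roughly linearly. A rough, illustrative estimate; all sizes below are assumptions for the arithmetic, not values from this repo:

```python
# Illustrative worst-case memory held by the chunk cache (assumed sizes).
rows_per_chunk = 4096          # tokens stored per chunk (assumed)
d_model = 768                  # activation width (assumed)
bytes_per_value = 2            # float16 activations (assumed)

bytes_per_entry = rows_per_chunk * d_model * bytes_per_value  # one (layer, inputs|targets) array
for maxsize in (256, 64):
    print(f"maxsize={maxsize}: ~{maxsize * bytes_per_entry / 2**20:.0f} MiB")
# maxsize=256: ~1536 MiB
# maxsize=64: ~384 MiB
```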
@@ -129,14 +158,29 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
logger.error(f"Chunk file not found for fetch: {chunk_path}")
raise
except KeyError as e:
raise RuntimeError(f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}") from e
raise RuntimeError(
f"Missing 'inputs' or 'targets' dataset in layer group '{layer_key}' of chunk {chunk_path}"
) from e
except Exception as e:
logger.error(f"Failed to open chunk at {chunk_path}: {e}")
raise RuntimeError(f"Failed to access chunk HDF5 file: {chunk_path}") from e

def _fetch_slice(self, chunk_id: int, row_indices: np.ndarray) -> bytes:

chunk_path = self.dataset_path / f"chunk_{chunk_id}.h5"
if not chunk_path.exists():
# Fall back to .hdf5 extension (newer generator default)
alt_path = self.dataset_path / f"chunk_{chunk_id}.hdf5"
if alt_path.exists():
chunk_path = alt_path
else:
# Provide clearer error message before _open_h5 raises
logger.error(
"Chunk file for chunk_id %d not found with either .h5 or .hdf5 extension in %s",
chunk_id,
self.dataset_path,
)

hf = _open_h5(chunk_path)

try:
@@ -162,8 +206,8 @@ def _layer_sort_key(name: str) -> int:
row_indices_h5 = row_indices

for i, lk in enumerate(layer_keys):
input_data = self._load_chunk(chunk_path, lk, 'inputs')[row_indices_h5, :]
target_data = self._load_chunk(chunk_path, lk, 'targets')[row_indices_h5, :]
input_data = self._load_chunk(chunk_path, lk, "inputs")[row_indices_h5, :]
target_data = self._load_chunk(chunk_path, lk, "targets")[row_indices_h5, :]
bufs.append(input_data.tobytes())
bufs.append(target_data.tobytes())
return b"".join(bufs)
2 changes: 1 addition & 1 deletion clt/training/metric_utils.py
@@ -49,7 +49,7 @@ def log_training_step(
self.metrics["train_losses"].append({"step": step, **loss_dict})

if not self.distributed or self.rank == 0:
total_tokens_processed = self.training_config.train_batch_size_tokens * self.world_size * (step + 1)
total_tokens_processed = self.training_config.train_batch_size_tokens * (step + 1)

self.wandb_logger.log_step(
step,
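Worked example of the token counter before and after this change, assuming (the diff itself does not state it) that `train_batch_size_tokens` already refers to the global per-step token count rather than a per-rank count:

```python
# Illustrative numbers only.
train_batch_size_tokens = 4096
world_size = 8
step = 99  # zero-based, so 100 optimizer steps completed

old_count = train_batch_size_tokens * world_size * (step + 1)  # 3,276,800 (counts each step world_size times)
new_count = train_batch_size_tokens * (step + 1)               # 409,600
print(old_count, new_count)
```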