Skip to content

Commit 4f4642d

Browse files
Curt Tigges
authored and committed
improved backwards compatibility with past activation stores
1 parent 9c9aa38 commit 4f4642d

File tree

3 files changed

+35
-350
lines changed

3 files changed

+35
-350
lines changed

clt/training/data/local_activation_store.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,16 +99,51 @@ def _load_manifest(self) -> Optional[np.ndarray]:
9999
logger.info(
100100
f"Manifest loaded (3-field format) from {path} ({data_structured.shape[0]} chunks). Expanding to per-row entries."
101101
)
102+
103+
# --- Consistency check against metadata --- #
104+
expected_total_tokens: Optional[int] = None
105+
if hasattr(self, "_meta") and isinstance(self._meta, dict):
106+
try:
107+
expected_total_tokens = int(self._meta.get("total_tokens", -1))
108+
except (ValueError, TypeError):
109+
expected_total_tokens = None
110+
102111
# Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
103112
chunk_ids = data_structured["chunk_id"].astype(np.uint32)
104113
num_tokens_arr = data_structured["num_tokens"].astype(np.uint32)
114+
115+
# If the sum of num_tokens does not match metadata total_tokens (when available),
116+
# this file is very likely a *legacy* per-row manifest whose byte-length happens to be
117+
# divisible by 16 (e.g. an even number of rows). In that case we fall back to the
118+
# legacy 2-field parsing logic.
119+
if expected_total_tokens is not None and expected_total_tokens > 0:
120+
parsed_total = int(num_tokens_arr.sum())
121+
if parsed_total != expected_total_tokens:
122+
logger.warning(
123+
"3-field manifest parse produced total_rows=%d but metadata reports %d tokens. "
124+
"Falling back to legacy 2-field manifest parsing.",
125+
parsed_total,
126+
expected_total_tokens,
127+
)
128+
# Legacy 2-field format already matches expected shape
129+
data = np.fromfile(path, dtype=np.uint32).reshape(-1, 2)
130+
logger.info(
131+
"Manifest re-loaded (legacy 2-field format) from %s (%d rows).",
132+
path,
133+
data.shape[0],
134+
)
135+
return data
136+
137+
# --- Proceed with 3-field expansion --- #
105138
# Compute total rows
106139
total_rows = int(num_tokens_arr.sum())
107140
logger.info(f"Expanding manifest: total rows = {total_rows}")
108141
# Pre-allocate array
109142
data = np.empty((total_rows, 2), dtype=np.uint32)
110143
row_ptr = 0
111144
for cid, ntok in zip(chunk_ids, num_tokens_arr):
145+
if ntok == 0:
146+
continue # Skip empty chunks to avoid broadcast errors
112147
data[row_ptr : row_ptr + ntok, 0] = cid # chunk_id column
113148
data[row_ptr : row_ptr + ntok, 1] = np.arange(ntok, dtype=np.uint32) # row index within chunk
114149
row_ptr += ntok

scripts/analysis/analyze_theta.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)