@@ -99,16 +99,51 @@ def _load_manifest(self) -> Optional[np.ndarray]:
99
99
logger .info (
100
100
f"Manifest loaded (3-field format) from { path } ({ data_structured .shape [0 ]} chunks). Expanding to per-row entries."
101
101
)
102
+
103
+ # --- Consistency check against metadata --- #
104
+ expected_total_tokens : Optional [int ] = None
105
+ if hasattr (self , "_meta" ) and isinstance (self ._meta , dict ):
106
+ try :
107
+ expected_total_tokens = int (self ._meta .get ("total_tokens" , - 1 ))
108
+ except (ValueError , TypeError ):
109
+ expected_total_tokens = None
110
+
102
111
# Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
103
112
chunk_ids = data_structured ["chunk_id" ].astype (np .uint32 )
104
113
num_tokens_arr = data_structured ["num_tokens" ].astype (np .uint32 )
114
+
115
+ # If the sum of num_tokens does not match metadata total_tokens (when available),
116
+ # this file is very likely a *legacy* per-row manifest whose byte-length happens to be
117
+ # divisible by 16 (e.g. an even number of rows). In that case we fall back to the
118
+ # legacy 2-field parsing logic.
119
+ if expected_total_tokens is not None and expected_total_tokens > 0 :
120
+ parsed_total = int (num_tokens_arr .sum ())
121
+ if parsed_total != expected_total_tokens :
122
+ logger .warning (
123
+ "3-field manifest parse produced total_rows=%d but metadata reports %d tokens. "
124
+ "Falling back to legacy 2-field manifest parsing." ,
125
+ parsed_total ,
126
+ expected_total_tokens ,
127
+ )
128
+ # Legacy 2-field format already matches expected shape
129
+ data = np .fromfile (path , dtype = np .uint32 ).reshape (- 1 , 2 )
130
+ logger .info (
131
+ "Manifest re-loaded (legacy 2-field format) from %s (%d rows)." ,
132
+ path ,
133
+ data .shape [0 ],
134
+ )
135
+ return data
136
+
137
+ # --- Proceed with 3-field expansion --- #
105
138
# Compute total rows
106
139
total_rows = int (num_tokens_arr .sum ())
107
140
logger .info (f"Expanding manifest: total rows = { total_rows } " )
108
141
# Pre-allocate array
109
142
data = np .empty ((total_rows , 2 ), dtype = np .uint32 )
110
143
row_ptr = 0
111
144
for cid , ntok in zip (chunk_ids , num_tokens_arr ):
145
+ if ntok == 0 :
146
+ continue # Skip empty chunks to avoid broadcast errors
112
147
data [row_ptr : row_ptr + ntok , 0 ] = cid # chunk_id column
113
148
data [row_ptr : row_ptr + ntok , 1 ] = np .arange (ntok , dtype = np .uint32 ) # row index within chunk
114
149
row_ptr += ntok
0 commit comments