@@ -89,12 +89,41 @@ def _load_manifest(self) -> Optional[np.ndarray]:
89
89
logger .error (f"Manifest file not found: { path } " )
90
90
return None
91
91
try :
92
- with open (path , "rb" ) as f :
93
- data = np .frombuffer (f .read (), dtype = np .uint32 ).reshape (- 1 , 2 )
94
- logger .info (f"Manifest loaded from { path } ({ len (data )} rows)." )
92
+ file_size_bytes = path .stat ().st_size
93
+ # Heuristic: older 2-field format is 8 bytes per entry (two uint32),
94
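+ # newer 3-field format is 16 bytes per entry (int32, int32, int64). NOTE(review): size alone is ambiguous — a legacy file with an even number of 8-byte entries is also a multiple of 16 bytes and will take the 3-field branch below.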
95
+ if file_size_bytes % 16 == 0 :
96
+ # New format with 3 fields (chunk_id, num_tokens, offset)
97
+ manifest_dtype = np .dtype ([("chunk_id" , np .int32 ), ("num_tokens" , np .int32 ), ("offset" , np .int64 )])
98
+ data_structured = np .fromfile (path , dtype = manifest_dtype )
99
+ logger .info (
100
+ f"Manifest loaded (3-field format) from { path } ({ data_structured .shape [0 ]} chunks). Expanding to per-row entries."
101
+ )
102
+ # Expand into per-row entries expected by downstream (chunk_id, row_in_chunk)
103
+ chunk_ids = data_structured ["chunk_id" ].astype (np .uint32 )
104
+ num_tokens_arr = data_structured ["num_tokens" ].astype (np .uint32 )
105
+ # Compute total rows
106
+ total_rows = int (num_tokens_arr .sum ())
107
+ logger .info (f"Expanding manifest: total rows = { total_rows } " )
108
+ # Pre-allocate array
109
+ data = np .empty ((total_rows , 2 ), dtype = np .uint32 )
110
+ row_ptr = 0
111
+ for cid , ntok in zip (chunk_ids , num_tokens_arr ):
112
+ data [row_ptr : row_ptr + ntok , 0 ] = cid # chunk_id column
113
+ data [row_ptr : row_ptr + ntok , 1 ] = np .arange (ntok , dtype = np .uint32 ) # row index within chunk
114
+ row_ptr += ntok
115
+ elif file_size_bytes % 8 == 0 :
116
+ # Legacy 2-field format already matches expected shape
117
+ data = np .fromfile (path , dtype = np .uint32 ).reshape (- 1 , 2 )
118
+ logger .info (f"Manifest loaded (legacy 2-field format) from { path } ({ data .shape [0 ]} rows)." )
119
+ else :
120
+ logger .error (
121
+ f"Manifest file size ({ file_size_bytes } bytes) is not compatible with known formats (8 or 16 bytes per row)."
122
+ )
123
+ return None
95
124
return data
96
125
except ValueError as e :
97
- logger .error (f"Error reshaping manifest data from { path } (expected Nx2) : { e } " )
126
+ logger .error (f"Error parsing manifest data from { path } : { e } " )
98
127
return None
99
128
except OSError as e :
100
129
logger .error (f"Error reading manifest file { path } : { e } " )
@@ -117,7 +146,7 @@ def _load_norm_stats(self) -> Optional[Dict[str, Any]]:
117
146
logger .error (f"Error reading norm_stats file { path } : { e } " )
118
147
return None
119
148
120
- @lru_cache (maxsize = 256 )
149
+ @lru_cache (maxsize = 64 )
121
150
def _load_chunk (self , chunk_path : str , layer_key : str , data_type : str ):
122
151
"""Loads entire HDF5 chunk from disk and caches"""
123
152
@@ -129,14 +158,29 @@ def _load_chunk(self, chunk_path: str, layer_key: str, data_type: str):
129
158
logger .error (f"Chunk file not found for fetch: { chunk_path } " )
130
159
raise
131
160
except KeyError as e :
132
- raise RuntimeError (f"Missing 'inputs' or 'targets' dataset in layer group '{ layer_key } ' of chunk { chunk_path } " ) from e
161
+ raise RuntimeError (
162
+ f"Missing 'inputs' or 'targets' dataset in layer group '{ layer_key } ' of chunk { chunk_path } "
163
+ ) from e
133
164
except Exception as e :
134
165
logger .error (f"Failed to open chunk at { chunk_path } : { e } " )
135
166
raise RuntimeError (f"Failed to access chunk HDF5 file: { chunk_path } " ) from e
136
167
137
168
def _fetch_slice (self , chunk_id : int , row_indices : np .ndarray ) -> bytes :
138
169
139
170
chunk_path = self .dataset_path / f"chunk_{ chunk_id } .h5"
171
+ if not chunk_path .exists ():
172
+ # Fall back to .hdf5 extension (newer generator default)
173
+ alt_path = self .dataset_path / f"chunk_{ chunk_id } .hdf5"
174
+ if alt_path .exists ():
175
+ chunk_path = alt_path
176
+ else :
177
+ # Provide clearer error message before _open_h5 raises
178
+ logger .error (
179
+ "Chunk file for chunk_id %d not found with either .h5 or .hdf5 extension in %s" ,
180
+ chunk_id ,
181
+ self .dataset_path ,
182
+ )
183
+
140
184
hf = _open_h5 (chunk_path )
141
185
142
186
try :
@@ -162,8 +206,8 @@ def _layer_sort_key(name: str) -> int:
162
206
row_indices_h5 = row_indices
163
207
164
208
for i , lk in enumerate (layer_keys ):
165
- input_data = self ._load_chunk (chunk_path , lk , ' inputs' )[row_indices_h5 , :]
166
- target_data = self ._load_chunk (chunk_path , lk , ' targets' )[row_indices_h5 , :]
209
+ input_data = self ._load_chunk (chunk_path , lk , " inputs" )[row_indices_h5 , :]
210
+ target_data = self ._load_chunk (chunk_path , lk , " targets" )[row_indices_h5 , :]
167
211
bufs .append (input_data .tobytes ())
168
212
bufs .append (target_data .tobytes ())
169
213
return b"" .join (bufs )
0 commit comments