Skip to content

Commit 86242b2

Browse files
committed
safe map
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 076b33c commit 86242b2

File tree

2 files changed

+161
-11
lines changed

2 files changed

+161
-11
lines changed

src/megatron/bridge/data/datasets/utils.py

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
import os
2121
import pickle
2222
import re
23+
import signal
2324
import time
2425
from functools import lru_cache, partial
26+
from queue import Empty
2527
from typing import Any, Callable, Optional, Pattern, Type
2628

2729
import numpy as np
@@ -87,6 +89,89 @@ def build_index_from_memdata(fn, newline_int):
8789
return midx
8890

8991

92+
def safe_map(fn, iterable, workers=1, ctx="fork"):
    """
    Crash-resilient alternative to ``multiprocessing.Pool.map()``.

    Unlike ``Pool.map()``, this will not hang indefinitely if a worker process
    dies mid-task (e.g. is OOM-killed): surviving workers keep draining the
    task queue, and any item whose worker crashed simply yields ``None``.

    Args:
        fn: Function to apply to each item. Exceptions raised by ``fn`` are
            caught in the worker and logged; the item's result becomes ``None``.
        iterable: Items to process.
        workers: Number of worker processes.
        ctx: Multiprocessing start method ("fork", "spawn", etc.).
            NOTE(review): the worker target is a closure, so start methods that
            must pickle the target (e.g. "spawn") will fail — "fork" is the
            expected mode here; confirm before using another method.

    Returns:
        List of results in the same order as ``iterable``. Failed or crashed
        items are ``None``, with a warning logged.
    """
    # Don't shadow the `ctx` parameter with the context object.
    mp_ctx = mp.get_context(ctx)
    input_queue = mp_ctx.Queue()
    output_queue = mp_ctx.Queue()

    indexed_inputs = list(enumerate(iterable))
    for job in indexed_inputs:
        input_queue.put(job)
    # One poison pill per worker so every worker loop terminates cleanly.
    for _ in range(workers):
        input_queue.put(None)

    def worker_loop():
        # Runs in the child: pull (index, item) jobs until a poison pill.
        while True:
            job = input_queue.get()
            if job is None:
                break
            i, item = job
            try:
                output_queue.put((i, True, fn(item), None))
            except Exception as e:
                output_queue.put((i, False, None, str(e)))

    processes = [mp_ctx.Process(target=worker_loop) for _ in range(workers)]
    for p in processes:
        p.start()

    expected = len(indexed_inputs)
    results = [None] * expected
    seen_indices = set()
    received = 0

    def record(i, success, result, err):
        # Fold one worker message into the result table.
        nonlocal received
        seen_indices.add(i)
        results[i] = result if success else None
        if not success:
            logger.warning(f"Item {i}: {err}")
        received += 1

    # Collect whatever gets returned from live workers.
    while received < expected:
        try:
            msg = output_queue.get(timeout=0.5)
        except Empty:
            if all(not p.is_alive() for p in processes):
                # Workers are gone. Drain anything still buffered in the
                # output queue first, so completed work that arrived between
                # the timeout and the liveness check is not discarded.
                while True:
                    try:
                        msg = output_queue.get_nowait()
                    except Empty:
                        break
                    record(*msg)
                if received < expected:
                    logger.error("All workers exited before completing all tasks.")
                break
            continue
        record(*msg)

    # Join and report workers that died from a signal (negative exitcode).
    for p in processes:
        p.join()
        if p.exitcode is not None and p.exitcode < 0:
            sig = -p.exitcode
            try:
                sig_name = signal.Signals(sig).name
            except Exception:
                sig_name = f"signal {sig}"
            logger.warning(f"PID {p.pid} died from {sig_name}")

    # Patch any missing results from crashed workers.
    for i in range(len(results)):
        if i not in seen_indices:
            logger.warning(f"No result for item {i}, likely crash")
            results[i] = None

    return results
173+
174+
90175
class _TextMemMapDataset(Dataset):
91176
"""
92177
Allow per-line lazy access to multiple text files using numpy memmap.
@@ -568,17 +653,16 @@ def build_index_files(
568653
logger.info(f"Processing {len(dataset_paths)} data files using {workers} workers")
569654
# load all files into memmap
570655
start_time = time.time()
571-
ctx = mp.get_context("fork")
572-
with ctx.Pool(workers) as p:
573-
build_status = p.map(
574-
partial(
575-
_build_memmap_index_files,
576-
newline_int,
577-
build_index_fn,
578-
index_mapping_dir=index_mapping_dir,
579-
),
580-
dataset_paths,
581-
)
656+
build_status = safe_map(
657+
partial(
658+
_build_memmap_index_files,
659+
newline_int,
660+
build_index_fn,
661+
index_mapping_dir=index_mapping_dir,
662+
),
663+
dataset_paths,
664+
workers=workers,
665+
)
582666

583667
logger.info(
584668
f"Time building {sum(build_status)} / {len(build_status)} "

tests/unit_tests/data/datasets/test_utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_response_value_formater,
3636
build_index_from_memdata,
3737
handle_index,
38+
safe_map,
3839
)
3940

4041

@@ -308,3 +309,68 @@ def test_build_memmap_index_files_with_msc_url(self):
308309

309310
assert msc.Path(f"msc://default{temp_dir}/training.jsonl.idx.npy")
310311
assert msc.Path(f"msc://default{temp_dir}/training.jsonl.idx.info")
312+
313+
314+
class TestSafeMap:
    """Unit tests for the crash-resilient safe_map helper."""

    def test_safe_map_basic_functionality(self):
        """safe_map should behave like plain map when no item fails."""

        def square(x):
            return x * x

        assert safe_map(square, [1, 2, 3, 4, 5], workers=2) == [1, 4, 9, 16, 25]

    def test_safe_map_handles_exceptions(self):
        """An item that raises becomes None; all other items still succeed."""

        def process_with_error(x):
            if x == 3:
                raise ValueError("Simulated error")
            return x * 2

        outcome = safe_map(process_with_error, [1, 2, 3, 4, 5], workers=2)

        # Position 2 (input value 3) failed; everything else doubled.
        for got, want in zip(outcome, [2, 4, None, 8, 10]):
            assert got == want

    def test_safe_map_preserves_order(self):
        """Results come back in input order regardless of worker scheduling."""

        def identity(x):
            return x

        values = list(range(100))
        assert safe_map(identity, values, workers=4) == values

    def test_safe_map_with_single_worker(self):
        """workers=1 degenerates to sequential processing."""

        def double(x):
            return x * 2

        assert safe_map(double, [1, 2, 3], workers=1) == [2, 4, 6]

    def test_safe_map_empty_iterable(self):
        """An empty input yields an empty result list."""

        def identity(x):
            return x

        assert safe_map(identity, [], workers=2) == []

0 commit comments

Comments
 (0)