|
20 | 20 | import os
|
21 | 21 | import pickle
|
22 | 22 | import re
|
| 23 | +import signal |
23 | 24 | import time
|
24 | 25 | from functools import lru_cache, partial
|
| 26 | +from queue import Empty |
25 | 27 | from typing import Any, Callable, Optional, Pattern, Type
|
26 | 28 |
|
27 | 29 | import numpy as np
|
@@ -87,6 +89,89 @@ def build_index_from_memdata(fn, newline_int):
|
87 | 89 | return midx
|
88 | 90 |
|
89 | 91 |
|
| 92 | +def safe_map(fn, iterable, workers=1, ctx="fork"): |
| 93 | + """ |
| 94 | + Crash-resilient alternative to multiprocessing.Pool.map() that can handle |
| 95 | + worker process crashes gracefully without hanging the entire operation. |
| 96 | +
|
| 97 | + This function provides robustness when processing large datasets where individual |
| 98 | + workers might crash. Unlike Pool.map(), it won't hang indefinitely if a worker dies. |
| 99 | +
|
| 100 | + Args: |
| 101 | + fn: Function to apply to each item |
| 102 | + iterable: Items to process |
| 103 | + workers: Number of worker processes |
| 104 | + ctx: Multiprocessing context ("fork", "spawn", etc.) |
| 105 | +
|
| 106 | + Returns: |
| 107 | + List of results (same order as iterable). Failed items are None with warning logged. |
| 108 | + """ |
| 109 | + ctx = mp.get_context(ctx) |
| 110 | + input_queue = ctx.Queue() |
| 111 | + output_queue = ctx.Queue() |
| 112 | + indexed_inputs = list(enumerate(iterable)) |
| 113 | + for job in indexed_inputs: |
| 114 | + input_queue.put(job) |
| 115 | + for _ in range(workers): |
| 116 | + input_queue.put(None) # poison pill |
| 117 | + |
| 118 | + def worker_loop(): |
| 119 | + while True: |
| 120 | + job = input_queue.get() |
| 121 | + if job is None: |
| 122 | + break |
| 123 | + i, item = job |
| 124 | + try: |
| 125 | + result = fn(item) |
| 126 | + output_queue.put((i, True, result, None)) |
| 127 | + except Exception as e: |
| 128 | + output_queue.put((i, False, None, str(e))) |
| 129 | + |
| 130 | + processes = [ctx.Process(target=worker_loop) for _ in range(workers)] |
| 131 | + for p in processes: |
| 132 | + p.start() |
| 133 | + |
| 134 | + results = [None] * len(indexed_inputs) |
| 135 | + seen_indices = set() |
| 136 | + expected = len(indexed_inputs) |
| 137 | + received = 0 |
| 138 | + |
| 139 | + # Collect whatever gets returned from live workers |
| 140 | + while received < expected: |
| 141 | + try: |
| 142 | + i, success, result, err = output_queue.get(timeout=0.5) |
| 143 | + seen_indices.add(i) |
| 144 | + results[i] = result if success else None |
| 145 | + if not success: |
| 146 | + logger.warning(f"Item {i}: {err}") |
| 147 | + received += 1 |
| 148 | + except Empty: |
| 149 | + # Check if all workers are dead |
| 150 | + if all(not p.is_alive() for p in processes): |
| 151 | + logger.error("All workers exited before completing all tasks.") |
| 152 | + break |
| 153 | + continue |
| 154 | + |
| 155 | + # Join and check for crashes |
| 156 | + for p in processes: |
| 157 | + p.join() |
| 158 | + if p.exitcode is not None and p.exitcode < 0: |
| 159 | + sig = -p.exitcode |
| 160 | + try: |
| 161 | + sig_name = signal.Signals(sig).name |
| 162 | + except Exception: |
| 163 | + sig_name = f"signal {sig}" |
| 164 | + logger.warning(f"PID {p.pid} died from {sig_name}") |
| 165 | + |
| 166 | + # Patch any missing results from crashed workers |
| 167 | + for i in range(len(results)): |
| 168 | + if i not in seen_indices: |
| 169 | + logger.warning(f"No result for item {i}, likely crash") |
| 170 | + results[i] = None |
| 171 | + |
| 172 | + return results |
| 173 | + |
| 174 | + |
90 | 175 | class _TextMemMapDataset(Dataset):
|
91 | 176 | """
|
92 | 177 | Allow per-line lazy access to multiple text files using numpy memmap.
|
@@ -568,17 +653,16 @@ def build_index_files(
|
568 | 653 | logger.info(f"Processing {len(dataset_paths)} data files using {workers} workers")
|
569 | 654 | # load all files into memmap
|
570 | 655 | start_time = time.time()
|
571 |
| - ctx = mp.get_context("fork") |
572 |
| - with ctx.Pool(workers) as p: |
573 |
| - build_status = p.map( |
574 |
| - partial( |
575 |
| - _build_memmap_index_files, |
576 |
| - newline_int, |
577 |
| - build_index_fn, |
578 |
| - index_mapping_dir=index_mapping_dir, |
579 |
| - ), |
580 |
| - dataset_paths, |
581 |
| - ) |
| 656 | + build_status = safe_map( |
| 657 | + partial( |
| 658 | + _build_memmap_index_files, |
| 659 | + newline_int, |
| 660 | + build_index_fn, |
| 661 | + index_mapping_dir=index_mapping_dir, |
| 662 | + ), |
| 663 | + dataset_paths, |
| 664 | + workers=workers, |
| 665 | + ) |
582 | 666 |
|
583 | 667 | logger.info(
|
584 | 668 | f"Time building {sum(build_status)} / {len(build_status)} "
|
|
0 commit comments