Description
Follow-up to #770
The training framework currently has asymmetric process group management that can lead to resource leakage:
- Automatic Initialization: The training loop automatically initializes the PyTorch distributed process group if it's not already initialized:
Megatron-Bridge/src/megatron/bridge/training/initialize.py
Lines 335 to 367 in d42316e
```python
device_count = torch.cuda.device_count()
if torch.distributed.is_initialized():
    if get_rank_safe() == 0:
        print(
            "torch distributed is already initialized, skipping initialization ...",
            flush=True,
        )
else:
    if get_rank_safe() == 0:
        print("> initializing torch distributed ...", flush=True)

    # Manually set the device ids.
    if device_count > 0:
        if dist_config.external_gpu_device_mapping:
            torch.cuda.set_device(0)
        else:
            torch.cuda.set_device(get_local_rank_preinit())

    # Set to non-default stream for cudagraph capturing.
    if model_config.external_cuda_graph:
        torch.cuda.set_stream(torch.cuda.Stream())

    # Call the init process
    init_process_group_kwargs = {
        "backend": dist_config.distributed_backend,
        "world_size": get_world_size_safe(),
        "rank": get_rank_safe(),
        "store": restart_store,
        "timeout": datetime.timedelta(minutes=dist_config.distributed_timeout_minutes),
    }
    torch.distributed.init_process_group(**init_process_group_kwargs)
```
- However, the framework does not explicitly destroy the process group at the end of training, leaving cleanup to users even when they didn't initialize it themselves:
```python
_finish_train(state)
destroy_global_state()
```
Megatron-Bridge/src/megatron/bridge/training/initialize.py
Lines 259 to 270 in d42316e
```python
def destroy_global_state() -> None:
    """Destroy Megatron global states.

    Cleans up resources used by microbatch calculator, global memory buffer,
    model parallel groups, and the rerun state machine.
    """
    from megatron.core.rerun_state_machine import destroy_rerun_state_machine

    destroy_num_microbatches_calculator()
    parallel_state.destroy_global_memory_buffer()
    parallel_state.destroy_model_parallel()
    destroy_rerun_state_machine()
```
This asymmetry creates hidden behavior where users become responsible for cleaning up resources they didn't explicitly create, leading to the cleanup errors reported in issue #770.
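For illustration, this is roughly the cleanup a user has to perform themselves today; a minimal sketch where `run_training` is a hypothetical stand-in for the framework's training entrypoint, not an actual API:

```python
import torch
import torch.distributed


def run_training(cfg):
    """Hypothetical stand-in for the framework's training entrypoint, which may
    call torch.distributed.init_process_group() internally if no group exists."""
    ...


def main(cfg):
    run_training(cfg)

    # With the current behavior, the caller is left to destroy a process group
    # it may never have created itself; forgetting this leads to the cleanup
    # errors reported in #770.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```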
Proposal:
Principle: The framework should restore the distributed state to how it was before training began.
- If the framework initializes the process group → The framework should destroy it
- If the user initialized the process group → The framework should leave it intact
Implementation sketch:
#795
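For context, the general shape of the ownership-tracking idea looks roughly like the sketch below. This is illustrative only and not the actual code in #795; the names `ensure_process_group` and `cleanup_process_group` are made up for the example:

```python
import torch
import torch.distributed

_OWNS_PROCESS_GROUP = False  # illustrative module-level ownership flag


def ensure_process_group(**init_kwargs) -> None:
    """Initialize the default process group only if nobody else has,
    and remember that the framework owns it."""
    global _OWNS_PROCESS_GROUP
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(**init_kwargs)
        _OWNS_PROCESS_GROUP = True


def cleanup_process_group() -> None:
    """Destroy the default process group only if the framework created it;
    a user-created group is left intact."""
    global _OWNS_PROCESS_GROUP
    if _OWNS_PROCESS_GROUP and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
        _OWNS_PROCESS_GROUP = False
```

Calling `cleanup_process_group()` at the end of training then restores the pre-training state in both cases described above.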
Note: This would be a breaking change for users who don't explicitly create the group themselves but depend on it being available after training completes.
Note: this would not extend to the bridge use cases, as there is no place afterwards to clean up after the potential initialization that happens here:
Megatron-Bridge/src/megatron/bridge/models/model_provider.py
Lines 199 to 215 in d42316e
```python
def initialize_model_parallel(
    self, seed: int | None = None, seed_kwargs: dict | None = None, **model_parallel_kwargs
) -> None:
    """Initializes model parallelism and sets the random seed.

    This is a convenience method that sets up tensor, pipeline, and other
    forms of model parallelism based on the attributes of the provider instance.

    Args:
        seed: The random seed for model parallel RNG.
        seed_kwargs: Additional arguments for `model_parallel_cuda_manual_seed`.
        **model_parallel_kwargs: Additional arguments for `parallel_state.initialize_model_parallel`.
    """
    if not torch.distributed.is_initialized():
        torch.cuda.set_device(get_local_rank_preinit())
        torch.distributed.init_process_group("nccl")
```