
Automatic global process group cleanup post training #794

@ananthsub

Description


Follow-up to #770

The training framework currently has asymmetric process group management that can lead to resource leakage:

  • Automatic Initialization: The training loop automatically initializes the PyTorch distributed process group if it's not already initialized:
    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        if get_rank_safe() == 0:
            print(
                "torch distributed is already initialized, skipping initialization ...",
                flush=True,
            )
    else:
        if get_rank_safe() == 0:
            print("> initializing torch distributed ...", flush=True)

        # Manually set the device ids.
        if device_count > 0:
            if dist_config.external_gpu_device_mapping:
                torch.cuda.set_device(0)
            else:
                torch.cuda.set_device(get_local_rank_preinit())

        # Set to non-default stream for cudagraph capturing.
        if model_config.external_cuda_graph:
            torch.cuda.set_stream(torch.cuda.Stream())

        # Call the init process
        init_process_group_kwargs = {
            "backend": dist_config.distributed_backend,
            "world_size": get_world_size_safe(),
            "rank": get_rank_safe(),
            "store": restart_store,
            "timeout": datetime.timedelta(minutes=dist_config.distributed_timeout_minutes),
        }
        torch.distributed.init_process_group(**init_process_group_kwargs)
  • No Automatic Cleanup: The framework does not explicitly destroy the process group at the end of training, leaving cleanup to users even when they did not initialize it themselves:


    def destroy_global_state() -> None:
        """Destroy Megatron global states.

        Cleans up resources used by microbatch calculator, global memory buffer,
        model parallel groups, and the rerun state machine.
        """
        from megatron.core.rerun_state_machine import destroy_rerun_state_machine

        destroy_num_microbatches_calculator()
        parallel_state.destroy_global_memory_buffer()
        parallel_state.destroy_model_parallel()
        destroy_rerun_state_machine()

This asymmetry creates hidden behavior where users become responsible for cleaning up resources they didn't explicitly create, leading to the cleanup errors reported in issue #770.
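
For context, this is the teardown users currently have to perform themselves after training finishes, even when the framework created the group. A minimal sketch of the workaround (the `run_training` entry point is a placeholder, not an actual API):

    import torch
    import torch.distributed

    def run_training() -> None:
        """Placeholder for the framework's training entry point (illustrative only)."""
        ...

    run_training()

    # Manual teardown the user is currently responsible for, even if the framework
    # (not the user) initialized the process group:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()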

Proposal:

Principle: The framework should restore the distributed state to how it was before training began.

  • If the framework initializes the process group → The framework should destroy it
  • If the user initialized the process group → The framework should leave it intact

Implementation sketch: #795
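
As a rough illustration of the ownership principle (hypothetical names; see #795 for the actual proposed change), the framework could record whether it created the group and only tear down what it owns:

    import torch
    import torch.distributed

    _framework_initialized_pg = False

    def maybe_init_process_group(**init_process_group_kwargs) -> None:
        """Initialize the process group only if the user has not already done so."""
        global _framework_initialized_pg
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(**init_process_group_kwargs)
            _framework_initialized_pg = True  # the framework owns the group

    def maybe_destroy_process_group() -> None:
        """Destroy the process group only if the framework created it."""
        global _framework_initialized_pg
        if _framework_initialized_pg and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            _framework_initialized_pg = False

This keeps user-created groups untouched while making the framework's own initialization symmetric.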

Note: This would be a breaking change for users who don't explicitly create the group themselves but depend on it being available after training completes.

Note: this would not extend to the bridge use cases, since there is no later point at which to clean up after the potential initialization that happens here:

def initialize_model_parallel(
    self, seed: int | None = None, seed_kwargs: dict | None = None, **model_parallel_kwargs
) -> None:
    """Initializes model parallelism and sets the random seed.

    This is a convenience method that sets up tensor, pipeline, and other
    forms of model parallelism based on the attributes of the provider instance.

    Args:
        seed: The random seed for model parallel RNG.
        seed_kwargs: Additional arguments for `model_parallel_cuda_manual_seed`.
        **model_parallel_kwargs: Additional arguments for `parallel_state.initialize_model_parallel`.
    """
    if not torch.distributed.is_initialized():
        torch.cuda.set_device(get_local_rank_preinit())
        torch.distributed.init_process_group("nccl")
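
Until the bridge has a natural teardown point, callers who want symmetric behavior could wrap the call themselves. A hedged sketch (the `provider` argument stands in for an object exposing the method above; the wrapper itself is illustrative, not an existing API):

    import contextlib
    import torch
    import torch.distributed

    @contextlib.contextmanager
    def bridge_process_group_scope(provider, **init_kwargs):
        """Restore the caller's distributed state once the bridge work is done."""
        pg_existed_before = torch.distributed.is_initialized()
        provider.initialize_model_parallel(**init_kwargs)
        try:
            yield
        finally:
            # Destroy the group only if the bridge (not the caller) created it.
            if not pg_existed_before and torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()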
