Skip to content

Commit d18c3e6

Browse files
ananthsubpaul-gibbons
authored andcommitted
Conditionally destroy process group at end of train (NVIDIA-NeMo#795)
* cleanup process group at end of train Signed-off-by: Ananth Subramaniam <[email protected]> * barrier Signed-off-by: Ananth Subramaniam <[email protected]> * self-contained within pretrain Signed-off-by: Ananth Subramaniam <[email protected]> --------- Signed-off-by: Ananth Subramaniam <[email protected]> Signed-off-by: Paul Gibbons <[email protected]>
1 parent 1080972 commit d18c3e6

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

src/megatron/bridge/training/pretrain.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ def _pretrain(
106106
store: Optional distributed Store used by in-process restart for coordination
107107
inprocess_call_wrapper: Optional wrapper injected by nvrx to expose restart iteration
108108
"""
109+
# Determine whether the training loop will initialize the process group
110+
# If the trainer creates the process group, the trainer should destroy it before returning control back to the user
111+
should_destroy_process_group = not dist.is_initialized()
112+
109113
# Handle in-process restart store prefix
110114
if inprocess_call_wrapper is not None:
111115
restart_attempt = inprocess_call_wrapper.iteration
@@ -183,3 +187,15 @@ def _pretrain(
183187
)
184188

185189
_finish_train(state)
190+
_maybe_destroy_process_group(should_destroy_process_group)
191+
192+
193+
def _maybe_destroy_process_group(should_destroy: bool) -> None:
194+
"""Destroy the process group if it was created by this training session.
195+
196+
Args:
197+
should_destroy: Whether the process group should be destroyed
198+
"""
199+
if should_destroy and dist.is_initialized():
200+
dist.barrier()
201+
dist.destroy_process_group()
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for pretrain module process group cleanup."""
16+
17+
from unittest.mock import patch
18+
19+
from megatron.bridge.training.pretrain import _maybe_destroy_process_group
20+
21+
22+
class TestDestroyProcessGroupIfNeeded:
23+
"""Test process group destruction logic."""
24+
25+
@patch("megatron.bridge.training.pretrain.dist")
26+
def test_destroy_when_should_destroy_and_initialized(self, mock_dist):
27+
"""Test process group is destroyed when both conditions are met."""
28+
mock_dist.is_initialized.return_value = True
29+
30+
_maybe_destroy_process_group(should_destroy=True)
31+
32+
mock_dist.barrier.assert_called_once()
33+
mock_dist.destroy_process_group.assert_called_once()
34+
35+
@patch("megatron.bridge.training.pretrain.dist")
36+
def test_no_destroy_when_should_not_destroy(self, mock_dist):
37+
"""Test no destruction when should_destroy is False."""
38+
mock_dist.is_initialized.return_value = True
39+
40+
_maybe_destroy_process_group(should_destroy=False)
41+
42+
mock_dist.barrier.assert_not_called()
43+
mock_dist.destroy_process_group.assert_not_called()
44+
45+
@patch("megatron.bridge.training.pretrain.dist")
46+
def test_no_destroy_when_not_initialized(self, mock_dist):
47+
"""Test no destruction when process group is not initialized."""
48+
mock_dist.is_initialized.return_value = False
49+
50+
_maybe_destroy_process_group(should_destroy=True)
51+
52+
mock_dist.barrier.assert_not_called()
53+
mock_dist.destroy_process_group.assert_not_called()
54+
55+
@patch("megatron.bridge.training.pretrain.dist")
56+
def test_no_destroy_when_neither_condition_met(self, mock_dist):
57+
"""Test no destruction when both conditions are false."""
58+
mock_dist.is_initialized.return_value = False
59+
60+
_maybe_destroy_process_group(should_destroy=False)
61+
62+
mock_dist.barrier.assert_not_called()
63+
mock_dist.destroy_process_group.assert_not_called()

0 commit comments

Comments
 (0)