|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import os |
| 8 | +import unittest |
| 9 | +from unittest.mock import MagicMock, patch |
| 10 | + |
| 11 | +import torch |
| 12 | +from torchtitan.config import Debug as DebugConfig |
| 13 | +from torchtitan.distributed.utils import set_determinism |
| 14 | + |
| 15 | + |
| 16 | +class FakeDeviceMesh: |
| 17 | + """Fake DeviceMesh for testing seed uniqueness. |
| 18 | +
|
| 19 | + Args: |
| 20 | + mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"]) |
| 21 | + mesh_sizes: List of sizes for each dimension (e.g., [4, 2, 8]) |
| 22 | + rank_coords: Tuple of coordinates for this rank (e.g., (2, 1, 5)) |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, mesh_dim_names, mesh_sizes, rank_coords): |
| 26 | + self.mesh_dim_names = mesh_dim_names |
| 27 | + self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes)) |
| 28 | + self.rank_coords = dict(zip(mesh_dim_names, rank_coords)) |
| 29 | + |
| 30 | + def __getitem__(self, key): |
| 31 | + """Return a submesh for the given dimension(s).""" |
| 32 | + if isinstance(key, str): |
| 33 | + # Single dimension |
| 34 | + submesh = MagicMock() |
| 35 | + submesh.get_local_rank.return_value = self.rank_coords[key] |
| 36 | + submesh.size.return_value = self.mesh_sizes[key] |
| 37 | + submesh.get_coordinate.return_value = self.rank_coords[key] |
| 38 | + return submesh |
| 39 | + elif isinstance(key, list): |
| 40 | + # Multiple dimensions |
| 41 | + submesh = MagicMock() |
| 42 | + # For multiple dimensions, get_coordinate should return None |
| 43 | + # since we're not testing this path |
| 44 | + submesh.get_coordinate.return_value = None |
| 45 | + return submesh |
| 46 | + else: |
| 47 | + raise ValueError(f"Unsupported key type: {type(key)}") |
| 48 | + |
| 49 | + def get_coordinate(self): |
| 50 | + """Return the coordinate tuple for this rank.""" |
| 51 | + return tuple(self.rank_coords[dim] for dim in self.mesh_dim_names) |
| 52 | + |
| 53 | + |
| 54 | +class TestSetDeterminismWithFakeMesh(unittest.TestCase): |
| 55 | + """Test set_determinism with fake mesh to verify seed uniqueness.""" |
| 56 | + |
| 57 | + def setUp(self): |
| 58 | + """Set up test fixtures.""" |
| 59 | + self.device = torch.device("cpu") |
| 60 | + |
| 61 | + def tearDown(self): |
| 62 | + """Clean up after tests.""" |
| 63 | + torch.use_deterministic_algorithms(False) |
| 64 | + if "PYTHONHASHSEED" in os.environ: |
| 65 | + del os.environ["PYTHONHASHSEED"] |
| 66 | + if "CUBLAS_WORKSPACE_CONFIG" in os.environ: |
| 67 | + del os.environ["CUBLAS_WORKSPACE_CONFIG"] |
| 68 | + |
| 69 | + @patch("torch.distributed.distributed_c10d.get_world_size") |
| 70 | + @patch("torch.distributed.distributed_c10d.get_rank") |
| 71 | + def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): |
| 72 | + """Test that different PP ranks get unique seeds, same DP ranks share seeds.""" |
| 73 | + mock_get_world_size.return_value = 8 # 4 * 2 |
| 74 | + |
| 75 | + mesh_dim_names = ["dp", "pp"] |
| 76 | + mesh_sizes = [4, 2] |
| 77 | + base_seed = 1000 |
| 78 | + |
| 79 | + seeds_by_coord = {} |
| 80 | + |
| 81 | + # Test all possible rank coordinates |
| 82 | + for dp_rank in range(mesh_sizes[0]): |
| 83 | + for pp_rank in range(mesh_sizes[1]): |
| 84 | + mock_get_rank.return_value = dp_rank * mesh_sizes[1] + pp_rank |
| 85 | + |
| 86 | + # Create fake mesh for this rank |
| 87 | + rank_coords = (dp_rank, pp_rank) |
| 88 | + fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) |
| 89 | + |
| 90 | + # Call set_determinism with distinct seeds only on PP dimension |
| 91 | + debug_config = DebugConfig(seed=base_seed, deterministic=False) |
| 92 | + set_determinism( |
| 93 | + world_mesh=fake_mesh, |
| 94 | + device=self.device, |
| 95 | + debug_config=debug_config, |
| 96 | + distinct_seed_mesh_dims=["pp"], |
| 97 | + ) |
| 98 | + |
| 99 | + # Capture the seed that was set |
| 100 | + rng_state = torch.get_rng_state() |
| 101 | + actual_seed = rng_state[:8].view(torch.int64).item() |
| 102 | + |
| 103 | + # Store for verification |
| 104 | + coord_key = (dp_rank, pp_rank) |
| 105 | + seeds_by_coord[coord_key] = actual_seed |
| 106 | + |
| 107 | + # Verify that coordinates with same PP but different DP have same seed |
| 108 | + for pp_rank in range(mesh_sizes[1]): |
| 109 | + # All DP ranks should have same seed for this PP rank |
| 110 | + seeds_for_this_pp = [ |
| 111 | + seeds_by_coord[(dp_rank, pp_rank)] for dp_rank in range(mesh_sizes[0]) |
| 112 | + ] |
| 113 | + self.assertEqual( |
| 114 | + len(set(seeds_for_this_pp)), |
| 115 | + 1, |
| 116 | + f"Different DP ranks at pp={pp_rank} should have same seed, " |
| 117 | + f"got {seeds_for_this_pp}", |
| 118 | + ) |
| 119 | + |
| 120 | + # Verify that different PP ranks have different seeds |
| 121 | + unique_pp_seeds = set() |
| 122 | + for pp_rank in range(mesh_sizes[1]): |
| 123 | + seed = seeds_by_coord[(0, pp_rank)] # Just check first DP rank |
| 124 | + self.assertNotIn(seed, unique_pp_seeds, f"Duplicate seed for pp={pp_rank}") |
| 125 | + unique_pp_seeds.add(seed) |
| 126 | + |
| 127 | + self.assertEqual( |
| 128 | + len(unique_pp_seeds), |
| 129 | + mesh_sizes[1], |
| 130 | + f"Expected {mesh_sizes[1]} unique seeds for PP dimension", |
| 131 | + ) |
| 132 | + |
| 133 | + @patch("torch.distributed.distributed_c10d.get_world_size") |
| 134 | + @patch("torch.distributed.distributed_c10d.get_rank") |
| 135 | + def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): |
| 136 | + """Test that different dp_shard and dp_replicate get unique seeds, TP shares seeds.""" |
| 137 | + mesh_dim_names = ["dp_shard", "dp_replicate", "tp"] |
| 138 | + mesh_sizes = [3, 2, 4] |
| 139 | + mock_get_world_size.return_value = 3 * 2 * 4 |
| 140 | + base_seed = 2000 |
| 141 | + |
| 142 | + seeds_by_coord = {} |
| 143 | + |
| 144 | + # Test all possible rank coordinates |
| 145 | + for dp_shard_rank in range(mesh_sizes[0]): |
| 146 | + for dp_replicate_rank in range(mesh_sizes[1]): |
| 147 | + for tp_rank in range(mesh_sizes[2]): |
| 148 | + global_rank = ( |
| 149 | + dp_shard_rank * (mesh_sizes[1] * mesh_sizes[2]) |
| 150 | + + dp_replicate_rank * mesh_sizes[2] |
| 151 | + + tp_rank |
| 152 | + ) |
| 153 | + mock_get_rank.return_value = global_rank |
| 154 | + |
| 155 | + # Create fake mesh for this rank |
| 156 | + rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) |
| 157 | + fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) |
| 158 | + |
| 159 | + # Call set_determinism with distinct seeds on dp_shard and dp_replicate only |
| 160 | + debug_config = DebugConfig(seed=base_seed, deterministic=False) |
| 161 | + set_determinism( |
| 162 | + world_mesh=fake_mesh, |
| 163 | + device=self.device, |
| 164 | + debug_config=debug_config, |
| 165 | + distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], |
| 166 | + ) |
| 167 | + |
| 168 | + # Capture the seed that was set |
| 169 | + rng_state = torch.get_rng_state() |
| 170 | + actual_seed = rng_state[:8].view(torch.int64).item() |
| 171 | + |
| 172 | + # Store for verification |
| 173 | + coord_key = (dp_shard_rank, dp_replicate_rank, tp_rank) |
| 174 | + seeds_by_coord[coord_key] = actual_seed |
| 175 | + |
| 176 | + # Verify that coordinates with same (dp_shard, dp_replicate) but different TP have same seed |
| 177 | + for dp_shard_rank in range(mesh_sizes[0]): |
| 178 | + for dp_replicate_rank in range(mesh_sizes[1]): |
| 179 | + # All TP ranks should have same seed for this (dp_shard, dp_replicate) |
| 180 | + seeds_for_this_dp = [ |
| 181 | + seeds_by_coord[(dp_shard_rank, dp_replicate_rank, tp_rank)] |
| 182 | + for tp_rank in range(mesh_sizes[2]) |
| 183 | + ] |
| 184 | + self.assertEqual( |
| 185 | + len(set(seeds_for_this_dp)), |
| 186 | + 1, |
| 187 | + f"Different TP ranks at (dp_shard={dp_shard_rank}, dp_replicate={dp_replicate_rank}) " |
| 188 | + f"should have same seed, got {seeds_for_this_dp}", |
| 189 | + ) |
| 190 | + |
| 191 | + # Verify that different (dp_shard, dp_replicate) combinations have different seeds |
| 192 | + unique_dp_seeds = set() |
| 193 | + for dp_shard_rank in range(mesh_sizes[0]): |
| 194 | + for dp_replicate_rank in range(mesh_sizes[1]): |
| 195 | + seed = seeds_by_coord[ |
| 196 | + (dp_shard_rank, dp_replicate_rank, 0) |
| 197 | + ] # Just check first TP rank |
| 198 | + self.assertNotIn( |
| 199 | + seed, |
| 200 | + unique_dp_seeds, |
| 201 | + f"Duplicate seed for (dp_shard={dp_shard_rank}, dp_replicate={dp_replicate_rank})", |
| 202 | + ) |
| 203 | + unique_dp_seeds.add(seed) |
| 204 | + |
| 205 | + self.assertEqual( |
| 206 | + len(unique_dp_seeds), |
| 207 | + mesh_sizes[0] * mesh_sizes[1], |
| 208 | + f"Expected {mesh_sizes[0] * mesh_sizes[1]} unique seeds for (dp_shard, dp_replicate) combinations", |
| 209 | + ) |
| 210 | + |
| 211 | + |
| 212 | +if __name__ == "__main__": |
| 213 | + unittest.main() |
0 commit comments