
Commit ff07852

[Flux] Enable unique random seed for multiple ranks (#1946)
Currently, when enabling HSDP, all ranks along the `dp_replicate` dim share the same seed, which is bad for noise generation in Flux training. Fix this by changing set_determinism.
1 parent 29eb910 commit ff07852
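The core of the change is how the per-rank seed offset is computed. Previously `set_determinism` accepted a single `distinct_seed_mesh_dim`, so with HSDP every rank along `dp_replicate` received the same seed and therefore generated identical noise. The new `distinct_seed_mesh_dims` list combines the local ranks of several mesh dimensions into one offset, mixed-radix style. A minimal standalone sketch of that arithmetic (plain Python with illustrative mesh sizes, not the torchtitan code itself):

def seed_offset(local_ranks, dim_sizes):
    # local_ranks[i] is this rank's coordinate on the i-th distinct dim,
    # dim_sizes[i] is that dimension's size
    offset, cumulative_size = 0, 1
    for rank, size in zip(local_ranks, dim_sizes):
        offset += rank * cumulative_size   # contribution of this dimension
        cumulative_size *= size            # radix for the next dimension
    return offset

# dp_shard=3, dp_replicate=2: every (shard, replicate) pair gets a distinct
# offset in 0..5, while ranks that differ only along dims not listed as
# distinct (e.g. tp) share the same offset and hence the same seed.
assert {seed_offset([s, r], [3, 2]) for s in range(3) for r in range(2)} == set(range(6))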

File tree

5 files changed: +260 −19 lines


scripts/generate/test_generate.py

Lines changed: 6 additions & 1 deletion
@@ -134,7 +134,12 @@ def test_generate(
         apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])

     debug_config = DebugConfig(seed=seed, deterministic=deterministic)
-    dist_utils.set_determinism(world_mesh, device, debug_config)
+    dist_utils.set_determinism(
+        world_mesh=world_mesh,
+        device=device,
+        debug_config=debug_config,
+        distinct_seed_mesh_dims=["pp"],
+    )

     # materalize model
     model.to_empty(device=device_type)
Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
from unittest.mock import MagicMock, patch

import torch
from torchtitan.config import Debug as DebugConfig
from torchtitan.distributed.utils import set_determinism


class FakeDeviceMesh:
    """Fake DeviceMesh for testing seed uniqueness.

    Args:
        mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"])
        mesh_sizes: List of sizes for each dimension (e.g., [4, 2, 8])
        rank_coords: Tuple of coordinates for this rank (e.g., (2, 1, 5))
    """

    def __init__(self, mesh_dim_names, mesh_sizes, rank_coords):
        self.mesh_dim_names = mesh_dim_names
        self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes))
        self.rank_coords = dict(zip(mesh_dim_names, rank_coords))

    def __getitem__(self, key):
        """Return a submesh for the given dimension(s)."""
        if isinstance(key, str):
            # Single dimension
            submesh = MagicMock()
            submesh.get_local_rank.return_value = self.rank_coords[key]
            submesh.size.return_value = self.mesh_sizes[key]
            submesh.get_coordinate.return_value = self.rank_coords[key]
            return submesh
        elif isinstance(key, list):
            # Multiple dimensions
            submesh = MagicMock()
            # For multiple dimensions, get_coordinate should return None
            # since we're not testing this path
            submesh.get_coordinate.return_value = None
            return submesh
        else:
            raise ValueError(f"Unsupported key type: {type(key)}")

    def get_coordinate(self):
        """Return the coordinate tuple for this rank."""
        return tuple(self.rank_coords[dim] for dim in self.mesh_dim_names)


class TestSetDeterminismWithFakeMesh(unittest.TestCase):
    """Test set_determinism with fake mesh to verify seed uniqueness."""

    def setUp(self):
        """Set up test fixtures."""
        self.device = torch.device("cpu")

    def tearDown(self):
        """Clean up after tests."""
        torch.use_deterministic_algorithms(False)
        if "PYTHONHASHSEED" in os.environ:
            del os.environ["PYTHONHASHSEED"]
        if "CUBLAS_WORKSPACE_CONFIG" in os.environ:
            del os.environ["CUBLAS_WORKSPACE_CONFIG"]

    @patch("torch.distributed.distributed_c10d.get_world_size")
    @patch("torch.distributed.distributed_c10d.get_rank")
    def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size):
        """Test that different PP ranks get unique seeds, same DP ranks share seeds."""
        mock_get_world_size.return_value = 8  # 4 * 2

        mesh_dim_names = ["dp", "pp"]
        mesh_sizes = [4, 2]
        base_seed = 1000

        seeds_by_coord = {}

        # Test all possible rank coordinates
        for dp_rank in range(mesh_sizes[0]):
            for pp_rank in range(mesh_sizes[1]):
                mock_get_rank.return_value = dp_rank * mesh_sizes[1] + pp_rank

                # Create fake mesh for this rank
                rank_coords = (dp_rank, pp_rank)
                fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords)

                # Call set_determinism with distinct seeds only on PP dimension
                debug_config = DebugConfig(seed=base_seed, deterministic=False)
                set_determinism(
                    world_mesh=fake_mesh,
                    device=self.device,
                    debug_config=debug_config,
                    distinct_seed_mesh_dims=["pp"],
                )

                # Capture the seed that was set
                rng_state = torch.get_rng_state()
                actual_seed = rng_state[:8].view(torch.int64).item()

                # Store for verification
                coord_key = (dp_rank, pp_rank)
                seeds_by_coord[coord_key] = actual_seed

        # Verify that coordinates with same PP but different DP have the same seed
        for pp_rank in range(mesh_sizes[1]):
            # All DP ranks should have the same seed for this PP rank
            seeds_for_this_pp = [
                seeds_by_coord[(dp_rank, pp_rank)] for dp_rank in range(mesh_sizes[0])
            ]
            self.assertEqual(
                len(set(seeds_for_this_pp)),
                1,
                f"Different DP ranks at pp={pp_rank} should have same seed, "
                f"got {seeds_for_this_pp}",
            )

        # Verify that different PP ranks have different seeds
        unique_pp_seeds = set()
        for pp_rank in range(mesh_sizes[1]):
            seed = seeds_by_coord[(0, pp_rank)]  # Just check first DP rank
            self.assertNotIn(seed, unique_pp_seeds, f"Duplicate seed for pp={pp_rank}")
            unique_pp_seeds.add(seed)

        self.assertEqual(
            len(unique_pp_seeds),
            mesh_sizes[1],
            f"Expected {mesh_sizes[1]} unique seeds for PP dimension",
        )

    @patch("torch.distributed.distributed_c10d.get_world_size")
    @patch("torch.distributed.distributed_c10d.get_rank")
    def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size):
        """Test that different dp_shard and dp_replicate get unique seeds, TP shares seeds."""
        mesh_dim_names = ["dp_shard", "dp_replicate", "tp"]
        mesh_sizes = [3, 2, 4]
        mock_get_world_size.return_value = 3 * 2 * 4
        base_seed = 2000

        seeds_by_coord = {}

        # Test all possible rank coordinates
        for dp_shard_rank in range(mesh_sizes[0]):
            for dp_replicate_rank in range(mesh_sizes[1]):
                for tp_rank in range(mesh_sizes[2]):
                    global_rank = (
                        dp_shard_rank * (mesh_sizes[1] * mesh_sizes[2])
                        + dp_replicate_rank * mesh_sizes[2]
                        + tp_rank
                    )
                    mock_get_rank.return_value = global_rank

                    # Create fake mesh for this rank
                    rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank)
                    fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords)

                    # Call set_determinism with distinct seeds on dp_shard and dp_replicate only
                    debug_config = DebugConfig(seed=base_seed, deterministic=False)
                    set_determinism(
                        world_mesh=fake_mesh,
                        device=self.device,
                        debug_config=debug_config,
                        distinct_seed_mesh_dims=["dp_shard", "dp_replicate"],
                    )

                    # Capture the seed that was set
                    rng_state = torch.get_rng_state()
                    actual_seed = rng_state[:8].view(torch.int64).item()

                    # Store for verification
                    coord_key = (dp_shard_rank, dp_replicate_rank, tp_rank)
                    seeds_by_coord[coord_key] = actual_seed

        # Verify that coordinates with same (dp_shard, dp_replicate) but different TP have the same seed
        for dp_shard_rank in range(mesh_sizes[0]):
            for dp_replicate_rank in range(mesh_sizes[1]):
                # All TP ranks should have the same seed for this (dp_shard, dp_replicate)
                seeds_for_this_dp = [
                    seeds_by_coord[(dp_shard_rank, dp_replicate_rank, tp_rank)]
                    for tp_rank in range(mesh_sizes[2])
                ]
                self.assertEqual(
                    len(set(seeds_for_this_dp)),
                    1,
                    f"Different TP ranks at (dp_shard={dp_shard_rank}, dp_replicate={dp_replicate_rank}) "
                    f"should have same seed, got {seeds_for_this_dp}",
                )

        # Verify that different (dp_shard, dp_replicate) combinations have different seeds
        unique_dp_seeds = set()
        for dp_shard_rank in range(mesh_sizes[0]):
            for dp_replicate_rank in range(mesh_sizes[1]):
                seed = seeds_by_coord[
                    (dp_shard_rank, dp_replicate_rank, 0)
                ]  # Just check first TP rank
                self.assertNotIn(
                    seed,
                    unique_dp_seeds,
                    f"Duplicate seed for (dp_shard={dp_shard_rank}, dp_replicate={dp_replicate_rank})",
                )
                unique_dp_seeds.add(seed)

        self.assertEqual(
            len(unique_dp_seeds),
            mesh_sizes[0] * mesh_sizes[1],
            f"Expected {mesh_sizes[0] * mesh_sizes[1]} unique seeds for (dp_shard, dp_replicate) combinations",
        )


if __name__ == "__main__":
    unittest.main()
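The tests above recover the seed that `set_determinism` applied by reinterpreting the first eight bytes of `torch.get_rng_state()` as an int64. An equivalent check (a sketch, not part of this commit) could use `torch.initial_seed()`, which reports the seed most recently passed to `torch.manual_seed` on the default CPU generator:

import torch

torch.manual_seed(1000 + 1)          # stand-in for what set_determinism ends up doing
assert torch.initial_seed() == 1001  # base seed plus this rank's offset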

torchtitan/distributed/utils.py

Lines changed: 39 additions & 17 deletions
@@ -84,17 +84,24 @@ def set_determinism(
     world_mesh: DeviceMesh | None,
     device: torch.device,
     debug_config: DebugConfig,
-    distinct_seed_mesh_dim: str = "pp",
+    distinct_seed_mesh_dims: list[str],
 ) -> None:
     """
     Set the same DTensor manual seed for all dimensions in world mesh, but only different seeds
-    across dimension denoted by `distinct_seed_mesh_dim`. An example use case is pipeline parallelism,
+    across dimensions denoted by `distinct_seed_mesh_dims`. An example use case is pipeline parallelism,
     where we want to have the same seed across SPMD groups, but different seeds across PP groups.

     Currently, does not set seeds for the CUDA RNG since TorchTitan always uses DTensor for SPMD parallelisms,
     and DTensor manages its own RNG tracker, but we could extend to support both if needed.

     Set Determinism flags for increased reproducibility with loss of performance.
+
+    Args:
+        world_mesh: Device mesh for distributed training
+        device: Device to use
+        distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across.
+        seed: Base seed value (if None, will be determined automatically)
+        deterministic: Whether to enable deterministic algorithms
     """
     if debug_config.deterministic:
         logger.info("Deterministic algorithm enabled (expect perf degradation).")
@@ -133,28 +140,43 @@ def set_determinism(
         torch.distributed.broadcast(seed_tensor, src=0)
         seed = seed_tensor.to("cpu").view(torch.uint64).item()

-    # Set distinct seed for each rank in mesh dimensions, with dimension name provided by `distinct_seed_mesh_dim`
+    # Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims`
     # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
     # and choose a unique seed for each rank on the PP mesh.
-    # TODO(jianiw): We could further extend this to support multiple distinct dimensions instead of just one.
-    if (
-        c10d.get_world_size() > 1
-        and distinct_seed_mesh_dim in world_mesh.mesh_dim_names
-    ):
-        distinct_mesh = world_mesh[distinct_seed_mesh_dim]
-        seed += distinct_mesh.get_local_rank()
+    # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed.
+    distinct_dims_in_mesh = [
+        dim for dim in distinct_seed_mesh_dims if dim in world_mesh.mesh_dim_names
+    ]
+
+    if c10d.get_world_size() > 1 and distinct_dims_in_mesh:
+        # Each dimension contributes: local_rank * (product of all previous dimension sizes)
+        # This guarantees uniqueness like multi-dimensional array indexing
+        seed_offset = 0
+        cumulative_size = 1
+
+        for dim in distinct_dims_in_mesh:
+            distinct_mesh = world_mesh[dim]
+            local_rank = distinct_mesh.get_local_rank()
+            # Add contribution from this dimension
+            seed_offset += local_rank * cumulative_size
+            # Update cumulative size for next dimension
+            cumulative_size *= distinct_mesh.size()
+
+        seed += seed_offset
         seed %= 2**64

         logger.debug(
-            f"{distinct_seed_mesh_dim} rank {distinct_mesh.get_local_rank()}, Global rank {c10d.get_rank()} using seed: {seed}"
-        )
-        duplicate_seed_mesh = list(
-            filter(
-                lambda name: name != distinct_seed_mesh_dim, world_mesh.mesh_dim_names
-            )
+            f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}"
         )
+
+        # Filter out all distinct dimensions to get duplicate_seed_mesh
+        duplicate_seed_mesh_dims = [
+            name
+            for name in world_mesh.mesh_dim_names
+            if name not in distinct_dims_in_mesh
+        ]
         duplicate_seed_mesh = (
-            world_mesh[duplicate_seed_mesh] if len(duplicate_seed_mesh) else None
+            world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None
         )
     else:
         duplicate_seed_mesh = world_mesh
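With this in place, a trainer that wants distinct noise both across FSDP shards and across HSDP replicas passes both dimension names, which is exactly what the Flux trainer below now does. A minimal call sketch, assuming `parallel_dims`, `device`, and `job_config` from the surrounding trainer and a world mesh that actually contains those dims (dims missing from the mesh are simply filtered out):

from torchtitan.distributed.utils import set_determinism

# Seeds differ across dp_shard and dp_replicate; ranks that differ only on
# the remaining mesh dims (e.g. tp) keep the same seed within their group.
set_determinism(
    world_mesh=parallel_dims.world_mesh,
    device=device,
    debug_config=job_config.debug,
    distinct_seed_mesh_dims=["dp_shard", "dp_replicate"],
)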

torchtitan/models/flux/train.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(self, job_config: JobConfig):
             self.parallel_dims.world_mesh,
             self.device,
             job_config.debug,
-            distinct_seed_mesh_dim="dp_shard",
+            distinct_seed_mesh_dims=["dp_shard", "dp_replicate"],
         )

         # NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder).

torchtitan/train.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def __init__(self, job_config: JobConfig):
             world_mesh,
             self.device,
             job_config.debug,
+            distinct_seed_mesh_dims=["pp"],
         )
         self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

