Conversation

@fegin (Contributor) commented Aug 29, 2025

Summary
This PR utilizes the latest APIs provided by DeviceMesh to simplify the creation of all the different meshes.

The design philosophy is as follows:

  1. Create one world mesh with shape [world_size,].
  2. Create all 1-D submeshes either by 1) unflattening the world mesh, or 2) slicing and flattening other derived meshes.
  3. ParallelDims now provides an API, get_mesh(), which accepts a str or a list[str]. When the argument is a str, the API directly returns the corresponding 1-D submesh. If the argument is a list[str], the dim names are concatenated to form an n-D device mesh; see the sketch after this list.
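
A minimal sketch of the resulting flow (init_device_mesh is the real PyTorch API; world_size, parallel_dims, and the dim names here are illustrative placeholders, and get_mesh is the new API this PR adds):

    from torch.distributed.device_mesh import init_device_mesh

    # 1. One world mesh with shape [world_size,].
    world_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("world",))

    # 3. A str argument returns the corresponding 1-D submesh; a list[str]
    #    argument concatenates the named dims into an n-D mesh.
    tp_mesh = parallel_dims.get_mesh("tp")
    edp_mesh = parallel_dims.get_mesh(["dp_replicate", "efsdp"])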

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 29, 2025
@fegin fegin force-pushed the chienchin/new_device_mesh branch 7 times, most recently from 12eca61 to 19e4a23 on October 15, 2025 at 20:39

    return mesh
    if self._meshes[dim].size() == 1:
        return None

Contributor:

Not sure if this will break user expectations. We have gotten asks that DTensor redistribute running on a mesh of size 1 should be a no-op.

Contributor Author:

But even in current TorchTitan, we won't create any DeviceMesh if the parallelism degree is 1. So it is unclear to me how a DeviceMesh with size 1 would exist.

Contributor:

Not in torchtitan; in internal code.

Contributor Author:

In PyTorch? Then it is okay, right? DeviceMesh still supports that case, but TorchTitan makes a stronger assumption for our use case.
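
For reference, a sketch of the size-1 behavior under discussion, reconstructed from the diff fragment above (the enclosing class and the full method body are assumed):

    def get_mesh(self, dim: str):
        # TorchTitan treats degree-1 parallelism as disabled, so a size-1
        # submesh is reported to the caller as None rather than returned.
        if self._meshes[dim].size() == 1:
            return None
        return self._meshes[dim]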

Contributor:

six

@fegin fegin force-pushed the chienchin/new_device_mesh branch from 19e4a23 to 178bc11 on October 28, 2025 at 20:34
@fegin fegin marked this pull request as ready for review October 28, 2025 21:01
@fegin fegin requested a review from wwwjn as a code owner October 28, 2025 21:01
    fsdp = self.dp_shard * self.cp
    efsdp = fsdp * self.tp // (self.etp * self.ep)

    self._world_mesh = init_device_mesh(

Contributor:

Does this initialize a world PG?

It may be fine to just ignore this for now in torchtitan, but I am wondering: if users want control over world group creation, what would that look like?

Contributor Author:

cc @fduwjj: are we able to disable the global PG initialization?

Contributor:

I think so. Right now we don't use split, so we can make it a fake PG. But if split is needed, then we need to materialize the world PG anyway.
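
A minimal sketch of the fake-PG idea, assuming the "fake" test backend that ships with PyTorch's test utilities (importing FakeStore registers it); the rank and world_size values are illustrative:

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.testing._internal.distributed.fake_pg import FakeStore

    # Back the default group with a fake PG: no real communicator is created.
    dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=8)
    world_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))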

@tianyu-l (Contributor) left a comment:

We should modify FLUX train.py as it's in core now.

@ruisizhang123 let's adapt SimpleFSDP after this PR is merged.
Oh, it seems this is being fixed in #1959.



    self._world_mesh = init_device_mesh(
        device_type, (self.world_size,), mesh_dim_names=("world",)
    )
    dataloading_mesh = unflatten_mesh(

Contributor:

Curious what happens if self.pp * batch * self.cp * self.tp != world_size. Will _unflatten() fail?

Contributor Author:

Yes, it will fail.
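
For illustration, the constraint is that the product of the target dim sizes must equal the size of the mesh being unflattened. A guard like the following (hypothetical, not part of this PR) would surface the failure earlier:

    import math

    sizes = (self.pp, batch, self.cp, self.tp)
    if math.prod(sizes) != self.world_size:
        raise ValueError(
            f"cannot unflatten a world mesh of size {self.world_size} "
            f"into dims of sizes {sizes}"
        )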

@fegin fegin force-pushed the chienchin/new_device_mesh branch from 99c46dc to d6eae58 on November 7, 2025 at 07:20
      # reductions are performed during backward.
      routed_input = DTensor.from_local(
    -     routed_input, device_mesh["tp"], (Replicate(),)
    +     routed_input, device_mesh["etp"], (Replicate(),)

Contributor:

This should be "tp".

    # IF PP is also used, this seed is unique per PP rank.
    if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
        torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)
    # TODO: remove the need of duplicate_seed_meshes once torch.distributed.tensor._random.manual_seed

Contributor:

Is this on your list? @wconstab

Contributor:

I'm not quite sure what 'duplicate_seed_meshes' is supposed to mean. Anyway, the plan is to change manual_seed from taking a mesh to taking a device. The device (id) is the only thing the DTensor OffsetBasedRNGTracker needs from the mesh; it doesn't use the mesh for other purposes. It would look worse, but it would be functionally fine to just pass 'world_mesh' to manual_seed in all cases, since all meshes on one process share the same device. As for making the change, I'm happy to do it, but @fegin said he was going to do it, so I did not.

Contributor:

Thinking back on this, probably what happened is that I wrote that duplicate_seed_mesh code because I didn't understand the manual_seed API at the time, and I assumed it was important to pass the correct mesh into each API call. I think we just need to keep track of the distinct_mesh_dims part, and we can simplify this code now.
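
Per the comments above, a sketch of the simplified seeding (manual_seed is the private DTensor API already referenced in the diff; passing the world mesh is claimed to be functionally equivalent because all meshes on one process share the same device):

    import torch.distributed.tensor._random as dtensor_random

    # Every mesh on this process maps to the same local device, which is all
    # that the OffsetBasedRNGTracker consumes from the mesh argument.
    dtensor_random.manual_seed(seed, world_mesh)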

      assert hasattr(transformer_block, "moe")
      if (
    -     dp_mod_ep_mesh.size() * parallel_dims.ep
    +     edp_mesh.size() * parallel_dims.ep

Contributor:

The logic seems wrong even before this PR.
Here it should be efsdp_mesh, because we only do sharding on efsdp, not on dp_replicate.

Contributor:

Similar for the FSDP2 application in llama4.

        dp_mesh_dim_names = ["dp_replicate", "efsdp"]
    else:
        dp_mesh_dim_names = ["efsdp"]
    edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names)

Contributor:

edp sounds like the right name here.

Comment on lines +166 to +168
"dp_replicate_fsdp": hsdp_mesh,
"dp_replicate_efsdp": ehsdp_mesh,
"ep_etp": ep_etp_mesh,

Contributor:

It seems only "ep_etp" is used explicitly; the other two are not. I think we should be consistent, e.g. we could disallow 2-D slicing via get_mesh. Every time we would use "dp_replicate_fsdp", we ask the user to pass in "hsdp" instead of ["dp_replicate", "fsdp"]. This aligns with the requirement of predefining everything in parallel_dims.py.

My guess at the motivation for erroring out when it's not a pre-created n-D sliced mesh:

  • We didn't keep references to all the global meshes (e.g. dense_mesh, sparse_mesh) from which we are going to slice submeshes, so without concatenate we don't know where to slice from. I think this can be worked around by keeping references to all global meshes.
  • I had commented that concatenate is too powerful.

A concern with the current approach is that users may not be able to extend to other fancy nD submesh use cases without modifying parallel_dims.py, which they probably shouldn't need to do in most cases. If they are developing outside torchtitan, they may use concatenate anyway.

In short, I think

  • if we are going with the pre-defining-everything approach, we can ban get_mesh with list inputs (see the sketch below);
  • if we are going with the flexible approach, we need to keep references to global meshes to look up from.

Either is fine to me.
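
A sketch of the first option (hypothetical; it reuses the internal _meshes dict visible in the diff fragments above), where get_mesh rejects list inputs so that only pre-registered names such as "hsdp" work:

    def get_mesh(self, dims):
        if not isinstance(dims, str):
            raise TypeError(
                "get_mesh only accepts predefined mesh names; pass 'hsdp' "
                "rather than ['dp_replicate', 'fsdp']"
            )
        return self._meshes[dims]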


    # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
    dp_mod_ep_mesh_dim_names = []
    dp_mod_ep_mesh = None

Contributor:

We should change the name of this parameter to edp?


        if parallel_dims.ep_enabled
        else None
    ),
    dp_mod_ep_mesh=dp_mod_ep_mesh,

Contributor:

Also here, and at all other occurrences.


