Use new DeviceMesh unflatten to rewrite parallel_dims #1660
base: main
Conversation
Force-pushed from 12eca61 to 19e4a23
return mesh
if self._meshes[dim].size() == 1:
    return None
Not sure if this will break user expectations. We've received asks that DTensor redistribute running on a mesh of size 1 should be a no-op.
But even in current TorchTitan, we won't create any DeviceMesh if the parallelism degree is 1. So it is unclear to me how a DeviceMesh with size 1 would exist.
Not in torchtitan, in internal code.
PyTorch? Then it is okay, right? DeviceMesh still supports the case, but TorchTitan makes a stronger assumption in our use case.
six
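Returning to the size-1 discussion above: a minimal sketch of the guard under debate (not the actual torchtitan code; the method and attribute names are assumed from the quoted diff).

# Hypothetical sketch of a get_mesh with the size-1 guard from the diff above.
# Returning None for degree-1 dims forces callers to special-case them; the
# alternative is to return the size-1 mesh and rely on DTensor treating
# redistribute on it as a no-op, which is the expectation mentioned earlier.
def get_mesh(self, dim: str):
    mesh = self._meshes[dim]
    if mesh.size() == 1:
        return None
    return mesh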
Force-pushed from 19e4a23 to 178bc11
fsdp = self.dp_shard * self.cp
efsdp = fsdp * self.tp // (self.etp * self.ep)
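A quick worked example of the formula above, with hypothetical parallelism degrees (not taken from this PR):

# Hypothetical degrees, only to illustrate the efsdp formula quoted above.
dp_shard, cp, tp, etp, ep = 4, 2, 4, 2, 4
fsdp = dp_shard * cp              # 4 * 2 = 8
efsdp = fsdp * tp // (etp * ep)   # 8 * 4 // (2 * 4) = 4
assert (fsdp, efsdp) == (8, 4)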
self._world_mesh = init_device_mesh(
Does this initialize a world PG?
It may be fine to just ignore this for now in torchtitan, but I am wondering: if users want control over world group creation, what would that look like?
cc @fduwjj, are we able to disable the global PG initialization?
I think so. Right now we don't use split, so we can make it a fake PG. But if split is needed, then we need to materialize the world PG anyway.
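For context, a minimal sketch of what running on a fake PG could look like, a pattern commonly used for single-process dry runs. It assumes the "fake" backend and FakeStore registered by torch.testing._internal.distributed.fake_pg, which is a private testing utility.

# Sketch: initialize a fake process group so DeviceMesh/DTensor code paths can run
# in a single process without real collectives. The world_size below is hypothetical.
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

world_size = 8
dist.init_process_group("fake", rank=0, world_size=world_size, store=FakeStore())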
We should modify FLUX train.py as it's in core now.
@ruisizhang123 let's adapt SimpleFSDP after this PR is merged.
Oh, it seems this is being fixed in #1959.
self._world_mesh = init_device_mesh(
    device_type, (self.world_size,), mesh_dim_names=("world",)
)
dataloading_mesh = unflatten_mesh(
Curious what will happen if self.pp * batch * self.cp * self.tp != world_size. Will _unflatten() fail?
Yes, it will fail
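A small sketch of the constraint implied by the answer above. The helper name is hypothetical; it only mirrors the check that unflattening enforces, namely that the target dim sizes must multiply out to the size of the dim being unflattened (here, world_size).

import math

# Hypothetical validation mirroring the failure mode discussed above: unflattening the
# 1-D world mesh into (pp, batch, cp, tp) only works if the sizes multiply to world_size.
def check_unflatten_sizes(world_size: int, dim_sizes: dict[str, int]) -> None:
    product = math.prod(dim_sizes.values())
    if product != world_size:
        raise ValueError(
            f"cannot unflatten mesh of size {world_size} into {dim_sizes}: "
            f"product is {product}"
        )

check_unflatten_sizes(8, {"pp": 2, "batch": 2, "cp": 1, "tp": 2})    # ok: 2*2*1*2 == 8
# check_unflatten_sizes(8, {"pp": 2, "batch": 2, "cp": 2, "tp": 2})  # raises: product is 16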
Force-pushed from 20910ef to a67e87a
Force-pushed from 99c46dc to d6eae58
  # reductions are performed during backward.
  routed_input = DTensor.from_local(
-     routed_input, device_mesh["tp"], (Replicate(),)
+     routed_input, device_mesh["etp"], (Replicate(),)
This should be "tp".
# IF PP is also used, this seed is unique per PP rank.
if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
    torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)
# TODO: remove the need of duplicate_seed_meshes once torch.distributed.tensor._random.manual_seed
Is this on your list? @wconstab
I'm not quite sure what 'duplicate_seed_meshes' is supposed to mean. Anyway, the plan is to change manual_seed from taking a mesh to taking a device. The device (id) is the only thing the DTensor OffsetBasedRNGTracker needs from the mesh; it doesn't use the mesh for other purposes. It would look worse, but it would be functionally fine to just pass 'world_mesh' to manual_seed in all cases, since all meshes on one process share the same device. As for making the change, I'm happy to do it, but @fegin said he was going to do it, so I did not.
Thinking back on this, probably what happened is that I wrote that duplicate_seed_mesh code because I didn't understand the manual_seed API at the time, and I assumed that it was important to pass the correct mesh into each API call. I think we just need to keep track of the distinct_mesh_dims part and can simplify this code now.
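A sketch of the simplification suggested above, i.e. seeding once with the world mesh since the RNG tracker only needs the device. This is a hypothetical illustration, not what the PR currently does, and torch.distributed.tensor._random is a private API.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import _random as dtensor_random

# Assumes the process group is already initialized and there is one GPU per rank.
world_size = dist.get_world_size()
world_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("world",))

seed = 1234  # hypothetical base seed
# Per the comment above, passing the world mesh is functionally equivalent to passing
# any sub-mesh, because all meshes on this process share the same device.
dtensor_random.manual_seed(seed, world_mesh)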
  assert hasattr(transformer_block, "moe")
  if (
-     dp_mod_ep_mesh.size() * parallel_dims.ep
+     edp_mesh.size() * parallel_dims.ep
The logic seems wrong even before this PR.
Here it should be efsdp_mesh, because we only do sharding on efsdp, not dp_replicate.
Similar for the FSDP2 application in llama4.
    dp_mesh_dim_names = ["dp_replicate", "efsdp"]
else:
    dp_mesh_dim_names = ["efsdp"]
edp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names)
edp sounds like the right name here.
| "dp_replicate_fsdp": hsdp_mesh, | ||
| "dp_replicate_efsdp": ehsdp_mesh, | ||
| "ep_etp": ep_etp_mesh, |
It seems only "ep_etp" is used explicitly; the other two are not. I think we should be consistent -- e.g. we can disallow 2D slicing using get_mesh. Every time we use "dp_replicate_fsdp", we ask the user to send in "hsdp" instead of ["dp_replicate", "fsdp"]. This aligns with the requirement of predefining everything in parallel_dims.py.
My guess at the motivation for erroring out if it's not a pre-created n-D sliced mesh is:
- We didn't keep references to all the global meshes (e.g. dense_mesh, sparse_mesh) which we are going to slice submeshes from, so without concatenate we don't know where to slice from. I think this is workaroundable by keeping references to all global meshes.
- I had the comment that concatenate is too powerful.
I think a concern with the current approach is that users may not be able to extend to other fancy nD submesh use cases without modifying parallel_dims.py, which they probably don't need to do in most cases. If they are developing outside torchtitan, they may use concatenate anyway.
In short, I think:
- If we are going with the pre-defining-everything approach, we can ban get_mesh with list inputs.
- If we are going with the flexible approach, we need to keep references to global meshes to look up from.
Either is fine with me.
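A rough sketch of the two options above; all names and the dict-based structure are hypothetical, not the actual ParallelDims API.

from torch.distributed.device_mesh import DeviceMesh

# Option 1: pre-defined names only. Every composite mesh must be registered up front in
# parallel_dims.py, and get_mesh rejects anything else (no list inputs).
PREDEFINED = {"hsdp", "ehsdp", "ep_etp"}

def get_mesh_strict(meshes: dict[str, DeviceMesh], name: str) -> DeviceMesh:
    if name not in PREDEFINED:
        raise ValueError(f"unknown mesh {name!r}; predefine it in parallel_dims.py")
    return meshes[name]

# Option 2: flexible slicing. Keep references to the global meshes and slice arbitrary
# dim-name combinations on demand, e.g. ["dp_replicate", "fsdp"] instead of "hsdp".
def get_mesh_flexible(global_meshes: dict[str, DeviceMesh], dims: list[str]) -> DeviceMesh:
    for root in global_meshes.values():
        if root.mesh_dim_names and all(d in root.mesh_dim_names for d in dims):
            return root[tuple(dims)]
    raise ValueError(f"no global mesh contains dims {dims}")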
# the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
dp_mod_ep_mesh_dim_names = []
dp_mod_ep_mesh = None
We should change the name of this parameter to edp?
Also, the code here seems wrong (https://github.com/pytorch/torchtitan/pull/1660/files#diff-2656e3a28f6b9141967a7e6ce9552879b330db03043333302081e9b8800a6a75R329), as it didn't consider dp_replicate in edp.
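A sketch of what the suggested fix could look like, mirroring the dp_mesh_dim_names pattern quoted earlier in this conversation; the property and method names are assumed from this PR, not confirmed.

# Hypothetical: build the expert-data-parallel (edp) mesh so that dp_replicate is
# included when replication is enabled, instead of using efsdp alone.
if parallel_dims.dp_replicate_enabled:
    edp_mesh_dim_names = ["dp_replicate", "efsdp"]
else:
    edp_mesh_dim_names = ["efsdp"]
edp_mesh = parallel_dims.get_mesh(edp_mesh_dim_names)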
    if parallel_dims.ep_enabled
    else None
),
dp_mod_ep_mesh=dp_mod_ep_mesh,
Also here, and in all other occurrences.
Summary
This PR utilizes the latest APIs provided by DeviceMesh to simplify the creation of all the different meshes.
The design philosophy is as follows: