
[DeviceMesh] Add _unflatten_ api for device mesh to support better UX for some use cases like EP and replicate #159482


Open · wants to merge 7 commits into base: gh/fduwjj/175/base

Conversation

@fduwjj (Contributor) commented Jul 30, 2025

Stack from ghstack (oldest at bottom):

After some initial feedback on the implementation of _split, we realized that we can first implement _unflatten for the urgent use cases at hand. We will do more refactoring and iteration based on the discussions in this PR and this RFC: #159013. We will also ensure that none of these changes regress DTensor's CPU overhead.

This PR is trying to:

  1. For unflatten, we don't support flattening an unflattened mesh. This is likely unnecessary: by the time a user decides to flatten an unflattened mesh, they would essentially be redoing the unflatten operations, which would complicate the bookkeeping, and we don't see that use case for now. (We throw a `NotImplementedError` for now.)
  2. The unflatten API needs some extra bookkeeping: we track which sub-mesh contains each unflattened `dim_name` so that users can slice these dim names from the root mesh as well, and we swap in the right mesh when slicing from the root mesh. To make sure we don't unflatten the same `dim_name` into different sizes, we also keep the total accumulated numel for that `dim_name` in the root.
  3. We want to reuse PGs already created for the same `dim_name`. When a different `dim_name` happens to have a different shape, we create a new PG, because with a different name users might want to use that dimension for a different purpose, so we'd rather not reuse. (This assumption can change, so I am open to suggestions.)
  4. Added unit tests for two situations: (1) directly unflatten a 2D device mesh; (2) first create a dummy 1D device mesh and then unflatten it into two 3D device meshes. A sketch of this UX follows below.
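A hedged sketch of the intended UX; the `_unflatten(dim, sizes, mesh_dim_names)` signature follows the example given later in this thread, and the concrete mesh sizes are illustrative assumptions:

from torch.distributed.device_mesh import init_device_mesh

# Case 1: directly unflatten one dim of a 2D mesh into two named dims.
mesh_2d = init_device_mesh("cuda", (8, 4), mesh_dim_names=("dp", "tp"))
mesh_3d = mesh_2d._unflatten(0, (4, 2), mesh_dim_names=("dp_shard", "dp_replicate"))
# mesh_3d.mesh_dim_names -> ("dp_shard", "dp_replicate", "tp")

# Case 2: create a dummy 1D world mesh and unflatten it two different ways,
# e.g. a dense-parallel layout and an EP layout over the same 32 ranks.
world = init_device_mesh("cuda", (32,), mesh_dim_names=("world",))
dense = world._unflatten(0, (4, 2, 4), mesh_dim_names=("dp", "cp", "tp"))
ep = world._unflatten(0, (2, 8, 2), mesh_dim_names=("dp_ep", "ep", "tp_ep"))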

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot added the `oncall: distributed` label Jul 30, 2025
fduwjj added a commit that referenced this pull request Jul 30, 2025
ghstack-source-id: f9eaf2d
Pull Request resolved: #159482
pytorch-bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159482

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (3 Unrelated Failures)

As of commit a87ee60 with merge base aeb5321:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fduwjj marked this pull request as draft July 30, 2025 15:41
fduwjj added a commit that referenced this pull request Jul 31, 2025
ghstack-source-id: fafd9a6
Pull Request resolved: #159482
fduwjj added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: 04b58f7
Pull Request resolved: #159482
fduwjj added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: a6e95eb
Pull Request resolved: #159482
@fduwjj changed the title to [WIP][DeviceMesh] Add _split api for device mesh (Aug 2, 2025)
We decided to implement split based on our initial discussion in #159013. This PR is trying to:
1. Expose `_init_backend` in `init_device_mesh` so that users can create a device mesh without any backend as a default world (see the sketch after this list).
2. For split, we don't support flattening a split mesh, which is likely unnecessary: if users decide to split, they probably start from an already-flattened dim. (There are definitely corner cases, so I throw a `NotImplementedError`.)
3. The split API needs some bookkeeping: we track which sub-mesh contains each split `dim_name` so that users can slice these dim names from the root mesh as well, and we swap in the right mesh when slicing from the root mesh. To make sure we don't split the same `dim_name` into different sizes, we also keep the total accumulated numel for that `dim_name`.
4. We want to reuse PGs already created for the same `dim_name`. When a different `dim_name` happens to have a different shape, we create a new PG, because with a different name users might want to use that dimension for a different purpose, so we'd rather not reuse. (This assumption can change, so I am open to suggestions.)
5. Added unit tests for two situations: (1) directly split a 2D device mesh; (2) first create a dummy 1D device mesh and then split it into two 3D device meshes.
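A hedged sketch of point 1; exposing `_init_backend` in `init_device_mesh` is what this PR proposes, so treat the flag as proposed rather than an existing parameter:

from torch.distributed.device_mesh import init_device_mesh

# Create a default "world" mesh without eagerly creating any backend PG;
# PGs would then be created only for the dims produced by later splits.
world = init_device_mesh(
    "cuda", (32,), mesh_dim_names=("world",), _init_backend=False
)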

There are definitely rough edges, so I just want to send out this PR first and gather early feedback on whether this is a reasonable direction.

fduwjj added a commit that referenced this pull request Aug 2, 2025
ghstack-source-id: 41d37c5
Pull Request resolved: #159482
@fduwjj marked this pull request as ready for review August 2, 2025 00:50
@wanchaol (Collaborator) left a comment:


Some suggestions on the API

@lw (Contributor) commented Aug 4, 2025

I'm not familiar with the code, so I was a bit scared opening the diff of this PR and didn't fully go through it.

However, I'm wondering if we can rework DeviceMesh slightly to simplify all this and handle all corner cases cleanly. Concretely, a "full" device mesh is a tensor, thus I'm tempted to reuse the same tensor-vs-storage separation for DeviceMeshes. What this means is that each sliced/flattened/unflattened "derived" DeviceMesh is basically a view into the "root" original DeviceMesh (i.e., the storage), and such a view can likely be expressed as an offset, a (sub-)size, and a stride (or something like this, I'm sure we can figure it out).

Thus we can define a single generic function that takes this offset/shape/stride combination and creates a new sub-PG for it. This function can then be used both by the _flatten codepath and by the _unflatten/_split one. Then, on the root DeviceMesh, we can keep a dictionary indexed by this offset/shape/stride tuple, so that we can detect whether a compatible PG for a certain flatten/split had already been created in the past and can be reused. And, finally, we can add a map between "names" (for dimensions) and tuples of offset/shape/stride so that we can avoid introducing ambiguous duplicate names and we can retrieve the PG for a named dimension.

Does this make sense or am I unaware of some DeviceMesh internals that would make the above impractical?
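A minimal sketch of the bookkeeping @lw describes, assuming a derived mesh is fully determined by an (offset, shape, stride) view over the root's flat rank list; every name here (`MeshView`, `get_or_create_pg`, the caches) is hypothetical, not an existing DeviceMesh internal:

from dataclasses import dataclass
from itertools import product

import torch.distributed as dist


@dataclass(frozen=True)
class MeshView:
    offset: int
    shape: tuple[int, ...]
    stride: tuple[int, ...]

    def ranks(self) -> list[int]:
        # Enumerate the global ranks this view selects from the root mesh.
        return [
            self.offset + sum(i * s for i, s in zip(idx, self.stride))
            for idx in product(*(range(d) for d in self.shape))
        ]


_pg_cache: dict[MeshView, dist.ProcessGroup] = {}  # view -> reusable PG
_name_to_view: dict[str, MeshView] = {}            # dim name -> its layout


def get_or_create_pg(name: str, view: MeshView) -> dist.ProcessGroup:
    bound = _name_to_view.setdefault(name, view)
    if bound != view:
        raise ValueError(f"dim name {name!r} is already bound to another layout")
    if view not in _pg_cache:
        # dist.new_group must be called collectively; this sketch elides that.
        _pg_cache[view] = dist.new_group(ranks=view.ranks())
    return _pg_cache[view]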

@fduwjj (Contributor, Author) commented Aug 5, 2025

@lw Thanks for your comment; I do think what you said makes sense:

What this means is that each slices/flattened/unflattened "derived" DeviceMesh is basically a view into the "root" original DeviceMesh (i.e., the storage), and such a view can likely be expressed as an offset, a (sub-)size and a stride (or something like this, I'm sure we can figure it out)

And that is how we try to make our API's UX mimic what a tensor does. But one thing here is that we have mesh_name, which makes it not as flexible as a tensor: imagine that instead of a view with a shape, users also need to give us a mesh_name. I asked @wz337, and aside from FSDP1, mesh_name is essentially required for all other scenarios. Today's bookkeeping does similar things to what you said:

And, finally, we can add a map between "names" (for dimensions) and tuples of offset/shape/stride so that we can avoid introducing ambiguous duplicate names and we can retrieve the PG for a named dimension.

The one extra dimension is dim_name; we might consider that as well. I need to think more on this proposal.

fduwjj added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: 2b89c95
Pull Request resolved: #159482
@fduwjj (Contributor, Author) commented Aug 5, 2025

Discussed with @lw offline: we might need more refactoring down the road, but this PR aims at unblocking some urgent use cases from EP, so as long as we get the UX and API signature correct, we will clean up later. This is going to take multiple steps, and we will make sure there is no regression to DTensor's overhead.

@fduwjj changed the title from [WIP][DeviceMesh] Add _split api for device mesh to [WIP][DeviceMesh] Add _unflatten_ api for device mesh (Aug 5, 2025)
@fduwjj (Contributor, Author) commented Aug 5, 2025

Updated the PR based on the feedback from @wanchaol, and I will do more cleanups to the bookkeeping logic so it looks cleaner. All UTs passed. We also need to rebase on top of #159371.

fduwjj added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: ab05390
Pull Request resolved: #159482
@fduwjj changed the title from [WIP][DeviceMesh] Add _unflatten_ api for device mesh to [DeviceMesh] Add _unflatten_ api for device mesh (Aug 5, 2025)
@fduwjj changed the title from [DeviceMesh] Add _unflatten_ api for device mesh to [DeviceMesh] Add _unflatten_ api for device mesh to support better UX for some use cases like EP and replicate (Aug 5, 2025)
@fduwjj requested a review from wanchaol August 5, 2025 22:50
@wanchaol (Collaborator) left a comment:


Trying to review this PR in detail, but the unflatten logic feels pretty complicated to understand. I think we should try to simplify the implementation, i.e.:

  • Make the fewest assumptions, i.e. let's not get into the business of "same unflatten shape + different unflatten mesh dim names" unless there's a clear use case for it. I think simply erroring out in that case makes more sense.
  • For unflatten, I don't quite feel the "implicit mesh dim appears in the root mesh" is good UX; it's quite surprising behavior IMO and not quite necessary. So I suggest not making it part of unflatten, and later considering deprecating this behavior in flatten.

I think for this PR, let's only make the essential feature work; this would help simplify the implementation.

self.assertEqual(
    unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
)
self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
wanchaol (Collaborator) commented:

hmm, I actually feel it's quite weird to let unflatten (or even the existing flatten behavior) implicitly add a flattened/unflattened dim to the original mesh. It was for convenience, but I think it's something quite weird (i.e. the original mesh does not have that dimension but can somehow access it). I think we probably don't want to make this a feature in unflatten, and maybe in the future we deprecate this semantic in flatten. Users can use the returned mesh directly, but not the original mesh.

@fduwjj (Contributor, Author) commented Aug 6, 2025

This actually goes against the UX you agreed on above, where we are trying to slice the unflattened dim_name from the root mesh.

# Create a dummy 1D device mesh first, so no PG will be created under the
# hood. This is important at large scale because no global PG is used then.
device_mesh = init_device_mesh(
    "cuda", (32,), mesh_dim_names=("world",),
)
device_mesh._unflatten(0, (4, 2, 4), mesh_dim_names=("dp", "cp", "tp"))
device_mesh._unflatten(0, (2, 8, 2), mesh_dim_names=("dp_ep", "ep", "tp_ep"))
device_mesh["cp", "tp"]["tp"]

because instead of letting users use the same mesh, we are actually asking users to do the bookkeeping themselves. So users will need to keep track of flattened or unflattened submeshes themselves, which really does not improve the UX for EP here. And if you don't mind me asking: when the logic of _flatten was designed, reviewed, and approved, it looks like you were part of it as well. So what was the assumption then, and what has changed now to make you change your mind? Or, more broadly, how did we make the call in the past? I have not heard any negative feedback regarding the slicing of a flattened mesh_dim from the root mesh.

The reason I am asking is that, per @ezyang's suggestion, we are trying to understand your design principles and mental models so that we can keep the same design philosophy as much as possible. This is not to say this bookkeeping logic is for or against that design principle, but I find it confusing why we decided to make a shift there. Furthermore, IIUC, here is what I have heard or observed so far:

  1. We are trying to mimic the behaviors of torch.tensor and c10d as much as possible. Most of the mesh API tries to have the same names and similar signatures as torch.tensor, because DeviceMesh is essentially a device representation backed by a mesh tensor. This makes sense, but it also somewhat justifies slicing flattened/unflattened dim_names, since even after flatten and unflatten they are still one mesh (tensor). (See the tensor sketch after this list.)

  2. We also borrowed concepts from XLA/JAX's device mesh for its design. And IIUC, under the hood it is an N-D array, which aligns with why we mimic tensor behavior in point 1 as well.

  3. We believe the DeviceMesh abstractions inside Monarch and PyTorch are at two different levels, as mentioned here: [RFC] Support slicing of submesh and more flexible operations to one device mesh such as reshape/split, etc. #159013 (comment). So we are not trying to converge it with Monarch.
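As a hedged illustration of the tensor analogy in point 1 (the sizes here are made up):

import torch

# DeviceMesh unflatten mirrors torch.Tensor.unflatten: the same underlying
# "storage" (here, 32 elements standing in for 32 devices), a new n-D view.
t = torch.arange(32)             # "root": 1-D
t3d = t.unflatten(0, (4, 2, 4))  # same storage, viewed as a 3-D array
print(t3d.shape)                 # torch.Size([4, 2, 4])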

Anything more you want to add? @wanchaol

wanchaol (Collaborator) replied:

The unflatten API itself is the UX I agreed on, but the way to access the unflattened device meshes from the root mesh is not. So basically the last line is the discrepancy: device_mesh["cp", "tp"]["tp"]; instead, the user should use the unflattened result to access the device mesh, I think.

The reason I am against this behavior is that accessing from the root mesh is quite implicit. This is not really my design principle but the Python/PyTorch design principle: explicit is better than implicit!

To answer your question: when the flatten API was designed, I was also NOT a fan of this behavior, but since flatten only adds one dimension to the root mesh, it did not contaminate this principle too much. Concerns about this were also the main reason I wanted to keep flatten private at first. But now, with the design of this unflatten API, I think the contamination becomes much worse. Why? Ideally users shouldn't be aware of the root mesh concept, as that is mainly an implementation detail; the "root mesh" and the unflattened mesh should both simply be device meshes that represent an n-d array of devices. But with the unflatten + root mesh behavior, a 1-D root device mesh suddenly has a bunch of additional dimensions! Taking the example you gave:

  • device_mesh is a 1-D device mesh, but suddenly after the unflattens it gains 6 more dimensions! 6 more dimensions on a 1-D device mesh do not make sense as a 1-D device array abstraction.
  • What's worse, it might not have only 6; it could have 8, 10, or even more dimensions from calling unflatten twice or more!

The implicit addition of device mesh dimensions makes the abstraction unclean, and IMO we should avoid this implicit behavior.

@@ -77,6 +77,10 @@ def __init__(self) -> None:
        self.flatten_name_to_root_dims: dict[
            DeviceMesh, dict[str, tuple[int, ...]]
        ] = {}
        self.root_to_unflatten_mesh: dict[
            DeviceMesh, dict[str, set[DeviceMesh]]
        ] = {}
wanchaol (Collaborator) commented:

Can you bookkeep only the mesh_dim instead of the actual device mesh?

@fduwjj (Contributor, Author) replied:

If we only bookkeep the mesh_dim, it is not helpful, because unflatten creates new dims; it is not like flatten, where we consolidate existing dims inside the root mesh. That's why we need to swap the mesh for slicing and keep this map. But I am open to better suggestions if any.

# We need to check whether one dim_name is unflattened into more than one shape/dim.
# For example, if we have already unflattened into a 3D mesh with mesh_shape (2, 2, 2)
# and mesh_dim_names ("dp", "cp", "tp"), then unflattening into a 2D mesh with
# mesh_shape (4, 2) and mesh_dim_names ("dp", "cp") is wrong.
# But unflattening into another 3D mesh with mesh_shape (2, 2, 2) and
# mesh_dim_names ("dp", "ep", "ep_tp") is legit.
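A hedged sketch of what the quoted check allows and rejects, assuming an 8-rank world:

from torch.distributed.device_mesh import init_device_mesh

world = init_device_mesh("cuda", (8,), mesh_dim_names=("world",))
world._unflatten(0, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))     # ok
world._unflatten(0, (2, 2, 2), mesh_dim_names=("dp", "ep", "ep_tp"))  # legit: "dp" keeps size 2
world._unflatten(0, (4, 2), mesh_dim_names=("dp", "cp"))              # wrong: "dp" changes 2 -> 4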
wanchaol (Collaborator) commented:

hmmm, if we unflatten to the same shape, I guess it would just return the "cached" unflattened results? Could you explain why it needs all those checks here and only reuses the dp dimension PG, instead of reusing all dimension PGs?

@fduwjj (Contributor, Author) replied:

ok, I guess we need to define the behavior here.

  1. If we unflatten into the same shape and the same dim_names, we just return the cached unflatten results. This is clear.
  2. If we unflatten into the same shape but different dim_names (sometimes the shape values are dynamic, so this can happen incidentally), my current thinking is that we should create a new PG for the dim where the dim_name differs, because users might need a different stream for that dim (same PG means same stream). But as I mentioned in the comment, I am open to discussion on this as well; no strong opinion. This is also where I am curious about your comment.

wanchaol (Collaborator) replied:

If we unflatten into the same shape but different dim_names (sometimes the shape values are dynamic, so this can happen incidentally), my current thinking is that we should create a new PG for the dim where the dim_name differs, because users might need a different stream for that dim?

I think if the set of ranks is the same, they should be the same communicator. Maybe they could accidentally have the same shape because of dynamism, but IMO if they have the same shape, they should just reuse the same communicators; if users really want a different stream/group with the exact same ranks, I guess they should do it by hand to make it explicit. I think this is also the behavior of the process group API.

    # Reuse the existing pg
    unchanged_dim_numel *= mesh_sizes[idx]
else:
    # This dim name has never been unflattened into, and we need to create a new PG.
wanchaol (Collaborator) commented:

Looking through the logic of unflatten, one thing I observed is that the logic is quite complicated. I wonder if we need to cut the assumptions or features of this PR and make them as few as possible.

@fduwjj (Contributor, Author) replied:

I agree the logic is complicated, and we can definitely try to simplify it. But before we do that, we need to at least agree on what UX we are providing. The current logic/checks support:

  1. unflatten from a root mesh, with slicing from a submesh or the root mesh
  2. unflatten from a submesh, with slicing from a submesh or the root mesh

We could allow unflatten only from the root mesh, which would simplify the logic a bit, but not by much, unless we want to cut the UX further.

@wanchaol (Collaborator) commented Aug 6, 2025

However, I'm wondering if we can rework DeviceMesh slightly to simplify all this and handle all corner cases cleanly. Concretely, a "full" device mesh is a tensor, thus I'm tempted to reuse the same separation as tensor-vs-storage also for DeviceMeshes. What this means is that each slices/flattened/unflattened "derived" DeviceMesh is basically a view into the "root" original DeviceMesh (i.e., the storage), and such a view can likely be expressed as an offset, a (sub-)size and a stride (or something like this, I'm sure we can figure it out).

Thus we can define a single generic function that takes this offset/shape/stride combination and creates a new sub-PG for it. This function can then be used both by the _flatten codepath and by the _unflatten/_split one. Then, on the root DeviceMesh, we can keep a dictionary indexed by this offset/shape/stride tuple, so that we can detect whether a compatible PG for a certain flatten/split had already been created in the past and can be reused. And, finally, we can add a map between "names" (for dimensions) and tuples of offset/shape/stride so that we can avoid introducing ambiguous duplicate names and we can retrieve the PG for a named dimension.

@lw I think overall this makes sense to me! We probably need to get into the details to figure out how exactly the tensor <-> storage mapping would work, but as long as it can simplify the existing implementation, I'm all for it :)

@fduwjj (Contributor, Author) commented Aug 6, 2025

Thanks for your comment and input again; I really appreciate it @wanchaol.

To your comment, I do have some follow-up questions here:

Make fewest assumptions, i.e. let's not get into the business of "same unflatten shape + different unflatten mesh dim names", unless there's a clear use case for it. I think simply error out in that case make more sense.

However, in EP's use case, because the mesh shape is dynamic, we will run into this case, and this is exactly the case @tianyu-l gave me as well.

For unflatten I don't quite feel the "implicit mesh dim appear in the root mesh" is a good UX, it's quite surprising behavior IMO and not quite necessary. So I suggest to not make it part of feature of unflatten and later consider deprecate this behavior in flatten.

Can you kindly answer my question in #159482 (comment)? It sounds like we have not yet agreed on the final UX, so I want to make sure we are on the same page from the very beginning. Also, if you can share why flatten had such "surprising behavior" in the first place, that would be helpful as well.

I think for this PR, let's only make the essential feature to work, this would help simplify the implementation.

I totally agree that we shouldn't make things complicated, and this is something I try to avoid as well. However, we do need to support the UX we promised, unless we walk back those promises. I hope this makes sense, and I am looking forward to your comments.
