[DeviceMesh] Add `_unflatten` api for device mesh to support better UX for some use cases like EP and replicate #159482
base: gh/fduwjj/175/base
Conversation
PR title changed to: "`_split` api for device mesh"
We decided to implement split based on our initial discussion in #159013. This PR tries to:

1. Expose `_init_backend` in `init_device_mesh` so that users can create a device mesh without any backend as a default world.
2. For split, we don't support flattening a split mesh. This might not be necessary, since users who decide to split have likely already started from a flattened dim. (There are definitely corner cases, so we throw a NotImplementedError.)
3. We need some bookkeeping for the split API. The thing we want to keep track of is which sub-mesh contains the split dim_name, so that users can slice these dim_names from the root mesh as well; we then need to swap in the right mesh when slicing from the root mesh. Also, to make sure we don't split the same dim_name into different sizes, we keep the total accumulated numel for that dim_name.
4. We want to reuse PGs already created for the same dim_name. When a different dim_name happens to have a different shape, we create a new PG, because with a different name users might want to use that dimension for a different purpose, so we'd better not reuse. (This assumption can be changed, so I am open to suggestions.)
5. Added unit tests for two situations: (1) we directly split one 2D device mesh; (2) we first create a dummy 1D device mesh and then split it into two 3D device meshes.

There are definitely rough edges, so I just want to send this PR out first and gather early feedback on whether this is a reasonable direction.

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k
Some suggestions on the API
I'm not familiar with the code, so I got a bit scared opening the diff of this PR and didn't fully go through it. However, I'm wondering if we can rework DeviceMesh slightly to simplify all this and handle all corner cases cleanly. Concretely, a "full" device mesh is a tensor, thus I'm tempted to reuse the same tensor-vs-storage separation for DeviceMeshes. What this means is that each sliced/flattened/unflattened "derived" DeviceMesh is basically a view into the "root" original DeviceMesh (i.e., the storage), and such a view can likely be expressed as an offset, a (sub-)size, and a stride (or something like this; I'm sure we can figure it out). Thus we can define a single generic function that takes this offset/shape/stride combination and creates a new sub-PG for it. This function could then be used by both the slicing and the flatten/unflatten operations.

Does this make sense, or am I unaware of some DeviceMesh internals that would make the above impractical?
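To make the proposal concrete, here is a minimal sketch of that idea. Nothing below is DeviceMesh's real internals; `MeshView` and `ranks_along_dim` are hypothetical names used purely to illustrate how an offset/shape/stride view over a flat list of ranks could drive a single generic sub-PG-creation function.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class MeshView:
    offset: int              # start position into the flat "storage" of ranks
    shape: tuple[int, ...]   # logical sizes of this view
    stride: tuple[int, ...]  # step in the flat rank list per logical dim

    def rank_at(self, coord: tuple[int, ...]) -> int:
        # Map a logical coordinate to a global rank, exactly like a
        # strided tensor maps an index into its storage.
        return self.offset + sum(c * s for c, s in zip(coord, self.stride))

    def ranks_along_dim(self, dim: int) -> list[list[int]]:
        # Enumerate the rank groups a sub-PG along `dim` would need:
        # fix every other coordinate and sweep `dim`.
        other = [range(n) for i, n in enumerate(self.shape) if i != dim]
        groups = []
        for fixed in product(*other):
            coords = list(fixed)
            group = []
            for c in range(self.shape[dim]):
                coords.insert(dim, c)
                group.append(self.rank_at(tuple(coords)))
                coords.pop(dim)
            groups.append(group)
        return groups


# An 8-rank world viewed as 2 x 4; the groups along dim 1 are the
# groups a generic new-sub-PG function would create for that dim.
view = MeshView(offset=0, shape=(2, 4), stride=(4, 1))
assert view.ranks_along_dim(1) == [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Under this scheme a derived mesh never copies anything; slicing, flatten, and unflatten would each just compute a new (offset, shape, stride) triple over the same storage.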
@lw Thanks for your comment; I do think what you said makes sense:

And that is how we try to model our API's UX after what a tensor does. But one thing here is that we have mesh_name, which makes it not as flexible as a tensor: imagine that instead of a view with just a shape, users also need to give us a mesh_name. I asked @wz337: aside from FSDP1, mesh_name is essentially required for all other scenarios. Today's bookkeeping does something similar to what you said:

The one extra dimension is dim_name; we might consider that as well. I need to think more about this proposal.
Discussed with @lw offline: we might need to do more refactors down the road, but this PR aims at unblocking some urgent use cases from EP, so as long as we make the UX and API signature correct, we will perform the cleanup later. This will take multiple steps, and we will make sure there is no regression to DTensor's overhead as well.
PR title changed: "`_split` api for device mesh" → "`_unflatten` api for device mesh"
After some initial feedback on the implementation of `_split`, we realized that we can first implement `_unflatten` for the urgent use cases at hand. We will do more refactoring and iteration based on the discussions in this PR and this RFC: #159013. We will also ensure that the changes cause no regression to DTensor's CPU overhead.

This PR tries to:

1. For unflatten, we don't support flattening an unflattened mesh. This might not be necessary, because by the time users decide to flatten an unflattened mesh, they would essentially be redoing the unflatten operations, which would make the bookkeeping complicated to handle, and we don't see that use case for now. (We throw a NotImplementedError for now.)
2. We need some extra bookkeeping for the unflatten API. What we want to keep track of is which sub-mesh contains the unflattened `dim_name`, so that users can slice these dim_names from the root mesh as well; we then need to swap in the right mesh when slicing from the root mesh. Also, to make sure we don't unflatten the same `dim_name` into different sizes, we keep the total accumulated numel in the root for that dim_name. (A rough sketch of this bookkeeping follows this description.)
3. We want to reuse PGs already created for the same `dim_name`. When a different `dim_name` happens to have a different shape, we create a new PG, because with a different name users might want to use that dimension for a different purpose, so we'd better not reuse. (This assumption can be changed, so I am open to suggestions.)
4. Added unit tests for two situations: (1) we directly unflatten one 2D device mesh; (2) we first create a dummy 1D device mesh and then unflatten it into two 3D device meshes.

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta
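A rough sketch of the bookkeeping described in point 2 above, under stated assumptions: this is not PyTorch's actual implementation, `root_to_unflatten_mesh` and `dim_name_to_numel` are illustrative names, and meshes are keyed by `id()` only to keep the example self-contained.

```python
# root mesh id -> dim_name -> sub-meshes that contain that unflattened dim
root_to_unflatten_mesh: dict[int, dict[str, set]] = {}
# (root mesh id, dim_name) -> numel that dim_name was unflattened into
dim_name_to_numel: dict[tuple[int, str], int] = {}


def record_unflatten(root, sub_mesh, mesh_dim_names, mesh_sizes):
    for name, size in zip(mesh_dim_names, mesh_sizes):
        key = (id(root), name)
        # Reject unflattening the same dim_name into a different size.
        if dim_name_to_numel.get(key, size) != size:
            raise RuntimeError(
                f"dim_name {name!r} was already unflattened with size "
                f"{dim_name_to_numel[key]}, got {size}"
            )
        dim_name_to_numel[key] = size
        # Remember which sub-meshes contain this dim_name, so that slicing
        # the root mesh by that name knows which mesh to swap in.
        root_to_unflatten_mesh.setdefault(id(root), {}).setdefault(
            name, set()
        ).add(sub_mesh)
```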
PR title changed: "`_unflatten` api for device mesh" → "`_unflatten` api for device mesh to support better UX for some use cases like EP and replicate"
Trying to review this PR in detail, but the unflatten logic feels pretty complicated to understand. I think we should try to simplify the implementation, i.e.:

- Make the fewest assumptions, i.e. let's not get into the business of "same unflatten shape + different unflatten mesh dim names" unless there's a clear use case for it. I think simply erroring out in that case makes more sense.
- For unflatten, I don't quite feel the "implicit mesh dim appears in the root mesh" behavior is good UX; it's quite surprising IMO and not really necessary. So I suggest not making it part of `unflatten`, and later considering deprecating this behavior in `flatten`.

I think for this PR, let's only make the essential features work; this would help simplify the implementation.
self.assertEqual(
    unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"]
)
self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh)
hmm, I actually feel it's quite weird to let `unflatten` (or even the existing `flatten` behavior) implicitly add a flattened/unflattened dim to the original mesh. It was done for convenience, but I think it's quite weird (i.e., the original mesh does not have that dimension but can somehow access it). I think we probably don't want to make this a feature of `unflatten`, and maybe in the future we should deprecate this semantic in `flatten`. Users can use the returned mesh directly, but not access it from the original mesh.
This actually goes against the UX you agreed on above, where we are trying to slice the unflattened `dim_name` from the root mesh.
# Create a dummy 1D device mesh first, so no PG will be created under the hood.
# This is important at large scale because no global PG will be used in that case.
device_mesh = init_device_mesh(
    "cuda", (32,), mesh_dim_names=["world"],
)
device_mesh._unflatten(0, (4, 2, 4), mesh_dim_names=["dp", "cp", "tp"])
device_mesh._unflatten(0, (2, 8, 2), mesh_dim_names=["dp_ep", "ep", "tp_ep"])
device_mesh["cp", "tp"]["tp"]
because instead of letting users use the same mesh, we are actually asking users to do the bookkeeping themselves. So users would need to keep track of flattened or unflattened submeshes themselves, which really does not improve the UX for EP here. And if you don't mind me asking: when the logic of `_flatten` was designed, reviewed, and approved, it looks like you were part of that as well. So what was the assumption then, and what has changed now to make you change your mind? Or, more broadly, how did we make this call in the past? I have not heard any negative feedback regarding the slicing of a flattened `mesh_dim` from the root mesh.

The reason I am asking is that, per @ezyang's suggestion, we are trying to understand your design principles and mental models so that we can keep the same design philosophy as much as possible. This is not to say this bookkeeping logic is for or against that design principle, but I find it confusing why we decided to make a shift here. Furthermore, IIUC, here is what I have heard or observed so far:
- We are trying to mimic the behaviors of torch.tensor and c10d as much as possible. Most of the mesh API tries to have the same names and similar signatures as torch.tensor, because DeviceMesh is essentially a device representation backed by a mesh tensor. This makes sense, but it also somewhat justifies the slicing of flattened/unflattened dim_names, since even with flatten and unflatten they are still one mesh (tensor)?
- We also borrowed concepts from XLA/JAX's device mesh for the design. And IIUC, under the hood it is an N-D array, which aligns with why we mimic tensor behavior in the first point.
- We believe the DeviceMesh abstractions inside Monarch and PyTorch sit at two different levels, as mentioned here: [RFC] Support slicing of submesh and more flexible operations to one device mesh such as reshape/split, etc. #159013 (comment). So we won't try to converge it with Monarch.
Anything more you want to add? @wanchaol
The unflatten API itself is the correct UX I agreed on, but the way of accessing the unflattened device meshes from the root mesh is not. So basically the last line is the discrepancy: `device_mesh["cp", "tp"]["tp"]`; instead, the user should use the unflattened result to access the device mesh, I think.

The reason I am against this behavior is that accessing from the root mesh is quite an implicit behavior. This is not really my design principle, but the Python/PyTorch design principle: explicit is better than implicit!

To answer your question: when the `flatten` API was designed, I was also NOT a fan of this behavior, but given that `flatten` only adds one dimension to the root mesh, it did not violate this principle too badly. Concerns about this were also the main reason I wanted to keep `flatten` private at first. But now, with the design of this `unflatten` API, I think the contamination becomes much worse. Why? Ideally, the user shouldn't be aware of the root mesh concept at all, as that is mainly an implementation detail; the "root mesh" and the unflattened mesh should simply both be device meshes that represent an n-d array of devices. But with the `unflatten` + root mesh behavior, a 1-D root device mesh suddenly has a bunch of additional dimensions! Taking the example you gave:

- `device_mesh` is a 1-D device mesh, but suddenly after the `unflatten` it has 6 more dimensions! Six more dimensions on a 1-D device mesh do not make sense as a 1-D device-array abstraction.
- What's worse, it might not stop at 6: it could have 8, 10, or even more dimensions from calling `unflatten` twice or more!

The implicit addition of device mesh dimensions makes the abstraction unclean, and IMO we should avoid this implicit behavior.
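To make the contrast concrete, here is the difference in access styles, reusing the 32-rank example from the thread above (and assuming `_unflatten` returns the new mesh, which the thread implies but does not show):

```python
mesh_3d = device_mesh._unflatten(0, (4, 2, 4), mesh_dim_names=["dp", "cp", "tp"])

tp_implicit = device_mesh["cp", "tp"]["tp"]  # implicit: sliced from the 1-D root mesh
tp_explicit = mesh_3d["tp"]                  # explicit: sliced from the returned mesh
```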
I see, thanks for sharing your thinking. I do agree that explicit is better than implicit, so it makes sense that ideally we have a more functional-like behavior. Since we already have that principle, we should have rejected the contamination earlier; I was under the impression that we did have a long discussion around the `flatten` API as well, yet we still ended up with a somewhat contaminated API, which is suboptimal in the first place. OK, let me also ask for input from users like @tianyu-l on this.
@@ -77,6 +77,10 @@ def __init__(self) -> None:
    self.flatten_name_to_root_dims: dict[
        DeviceMesh, dict[str, tuple[int, ...]]
    ] = {}
    self.root_to_unflatten_mesh: dict[
        DeviceMesh, dict[str, set[DeviceMesh]]
Can you only bookkeep the mesh_dim instead of the actual device mesh?
If we only bookkeep the mesh_dim, it does not help, because this is unflatten, which means we are creating new dims; it is not like flatten, where we consolidate existing dims inside the root mesh. That's why we need to swap the mesh for slicing and need this map. But I am open to better suggestions if any.
# We need to check whether one dim_name is unflattened into more than one shape/dim.
# For example, if we have already unflattened into a 3D mesh with mesh_shape (2, 2, 2)
# and mesh_dim_names ("dp", "cp", "tp"), then unflattening into a 2D mesh with
# mesh_shape (4, 2) and mesh_dim_names ("dp", "cp") is wrong. But unflattening into
# another 3D mesh with mesh_shape (2, 2, 2) and mesh_dim_names ("dp", "ep", "ep_tp") is legit.
hmmm, if we unflatten to the same shape, I guess it would just return the "cached" unflattened results? Could you explain why it needs all those checks here, and why it only reuses the dp dimension's PG instead of reusing all dimensions' PGs?
ok, I guess we need to define the behavior here.

- If we unflatten into the same shape and same dim_names, yes, we just return the cached unflatten results. This is clear.
- If we unflatten into the same shape but different dim_names (sometimes the shape value is dynamic, so this can happen incidentally), my current thinking is that we should treat it as a new PG for the dim whose dim_name differs, because users might need a different stream for that dim (same PG means same stream). But as I mentioned in the comment, I am open to discussion on this and have no strong opinion. This is also where I am curious about your comment.
> If unflatten into same shape but different dim_names (sometimes the value of shape is dynamic so it can happen incidentally), my current thinking is that we should treat it as a new PG for the dim where dim_name is different because users might need a different stream for that dim?

I think if the set of ranks is the same, they should be the same communicator. They could accidentally have the same shape because of dynamism, but IMO if they have the same shape, they should just reuse the same communicators; if a user really wants a separate stream/group with the exact same ranks, they should create it by hand to make it explicit. I think this is also the behavior of the process group API.
I see. Based on your comment above, and since we want to hold to the principle of being explicit, I agree we can just reuse the PG to keep the behavior simple and predictable.
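A minimal sketch of the reuse rule being converged on here, assuming a cache keyed by the concrete ranks; the helper and cache names are hypothetical, while `dist.new_group` is the existing c10d API:

```python
import torch.distributed as dist

_pg_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}


def get_or_create_pg(ranks: list[int]) -> dist.ProcessGroup:
    # Same set of ranks -> same communicator, regardless of dim_name.
    # A user who truly wants a second group over the same ranks can
    # still call dist.new_group by hand to make it explicit.
    key = tuple(sorted(ranks))
    if key not in _pg_cache:
        _pg_cache[key] = dist.new_group(ranks=ranks)
    return _pg_cache[key]
```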
    # Reuse the existing PG.
    unchanged_dim_numel *= mesh_sizes[idx]
else:
    # This dim_name has never been unflattened into, so we need to create a new PG.
Looking through the logic of unflatten, one thing I observe is that it is quite complicated. I wonder if we should cut down the assumptions or features of this PR and keep them as few as possible.
I agree the logic is complicated, and we can definitely try to simplify it. But before we do that, we need to at least agree on what UX we are providing. The current logic/checks support the following (sketched below):

- unflatten from a root mesh, with slicing from a submesh or the root mesh
- unflatten from a submesh, with slicing from a submesh or the root mesh

We could allow unflatten only from the root mesh, but that would simplify the logic only a little, unless we want to cut the UX further.
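For reference, the two flows above might look like this; a sketch assuming `_unflatten` returns the new mesh, with sizes chosen for illustration rather than taken from the PR's tests:

```python
from torch.distributed.device_mesh import init_device_mesh

root = init_device_mesh("cuda", (32,), mesh_dim_names=["world"])

# 1. Unflatten from the root mesh, then slice from the result.
mesh_a = root._unflatten(0, (4, 8), mesh_dim_names=["dp", "tp"])
dp_mesh = mesh_a["dp"]

# 2. Unflatten from a submesh, then slice from the result.
tp_mesh = mesh_a["tp"]  # a 1-D submesh of size 8
mesh_b = tp_mesh._unflatten(0, (2, 4), mesh_dim_names=["cp", "ep"])
cp_mesh = mesh_b["cp"]
```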
@lw I think overall this makes sense to me! We probably need to get into the details of how exactly the tensor <-> storage mapping would work, but as long as it simplifies the existing implementation, I'm all for it :)
Thanks for your comment and input again; I really appreciate it, @wanchaol. To your comment, I do have some follow-up questions:

However, in EP's use case, because the mesh shape is dynamic, we will run into this case; this is exactly the explicit case @tianyu-l gave me as well.

Could you kindly answer my question in #159482 (comment)? It sounds like we have not yet agreed on the final UX, so I want to make sure we are on the same page from the very beginning. Also, if you can share why flatten had such "surprising behavior" in the first place, that would be helpful as well.

I totally agree that we shouldn't make things complicated, and this is something I try to avoid as well. However, we do need to support the UX we promised, unless we walk back those promises. I hope this makes sense, and I look forward to your comments.
Stack from ghstack (oldest at bottom):
-> `_unflatten` api for device mesh to support better UX for some use cases like EP and replicate #159482