Ensure outer aliasing on DTensor matches inner aliasing #158954

Closed · wants to merge 5 commits
10 changes: 8 additions & 2 deletions torch/distributed/tensor/_dispatch.py
@@ -23,6 +23,7 @@
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
from torch.utils._python_dispatch import return_and_correct_aliasing


try:
@@ -164,7 +165,8 @@ def dispatch(
assert output_sharding is not None, "output sharding should not be None"

mesh = op_info.compute_mesh
if mesh.get_coordinate() is not None:
participating = mesh.get_coordinate() is not None
Contributor:

When is this false? Is this due to uneven sharding leaving one rank 'empty'?

Contributor Author:

A DTensor is associated with a device mesh, and that mesh does not necessarily have to cover all of the GPUs that are actually known. We are still SPMD over everything, including ranks that are not in the mesh; those ranks are simply not participating.
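To make the participation check concrete, here is a minimal sketch of my own (not code from this PR), assuming a 4-process CPU job launched with torchrun: the DeviceMesh covers only ranks 0 and 1, and on the remaining ranks get_coordinate() returns None, which is the participating = False path in dispatch().

```python
# Illustrative sketch, not part of this PR: a 4-process job whose mesh covers
# only ranks 0 and 1. Every rank runs the same SPMD program, but ranks outside
# the mesh get no coordinate and therefore skip the local computation.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

dist.init_process_group("gloo")  # assumes torchrun-style env:// rendezvous

mesh = DeviceMesh("cpu", [0, 1])  # mesh spans ranks 0 and 1 only

coord = mesh.get_coordinate()
if coord is not None:
    print(f"rank {dist.get_rank()}: participating at mesh coordinate {coord}")
else:
    # Ranks 2 and 3 land here: they are not in the mesh, which is exactly
    # the `participating = False` case in dispatch().
    print(f"rank {dist.get_rank()}: not in this mesh, nothing to compute locally")
```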

if participating:
# computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
# If sharding propagation decision needs redistribute, perform redistribute
@@ -299,7 +301,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
if participating and op_info.schema.is_view_op():
return return_and_correct_aliasing(op_call, args, kwargs, ret)
else:
return ret

@staticmethod
def redistribute_local_args(
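As a usage-level illustration of what the _dispatch.py change above is meant to guarantee (a hedged sketch of my own, not code from this PR, assuming a 1-D mesh over every launched rank): for a view op such as aten.view, the per-rank local tensors were already aliased, and return_and_correct_aliasing makes the outer DTensor wrapper reflect that same aliasing.

```python
# Illustrative sketch, not from this PR: "outer aliasing matches inner aliasing"
# for a view op on DTensor. Assumes a 1-D mesh over all launched ranks.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
y = x.view(8, 2, 2)  # aten.view: the schema marks the output as a non-write alias of self

# Inner aliasing (already true before this change): each rank's local piece of y
# is a view over the same storage as its local piece of x.
assert y.to_local().data_ptr() == x.to_local().data_ptr()

# Outer aliasing is what return_and_correct_aliasing adds: the returned DTensor
# wrapper is registered as an alias of the input wrapper, so the view relationship
# is visible at the DTensor level too, not only on the local shards.
```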
6 changes: 6 additions & 0 deletions torch/distributed/tensor/_op_schema.py
@@ -450,6 +450,12 @@ def is_out_variant_op(self) -> bool:
# be entirely correct, but it's good enough for now.
return "out" in self.op._schema.overload_name

def is_view_op(self) -> bool:
return any(
a.alias_info is not None and not a.alias_info.is_write
Contributor:

What does 'not is_write' have to do with it being a view? Would a modified view not set this?

Contributor Author:

I'm being a little defensive here. Mutating ops carry a write alias, e.g. mul_(Tensor(a!) self) -> Tensor(a!). The logic applied here would also work for those ops, but it isn't necessary because DTensor already handles them manually.
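To show the schema metadata that is_view_op() reads, here is a small sketch of my own inspecting two existing aten overloads: a view op annotates an argument as a non-write alias of the output, while an in-place op sets the write bit.

```python
# Illustrative sketch, not from this PR: how alias_info distinguishes a view op
# (alias without the write bit) from an in-place op (alias with is_write=True).
import torch

def describe_aliasing(op):
    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            print(f"{op._schema.name}: argument '{arg.name}' is annotated as aliasing "
                  f"the output, is_write={arg.alias_info.is_write}")

describe_aliasing(torch.ops.aten.view.default)  # is_write=False -> counted by is_view_op()
describe_aliasing(torch.ops.aten.mul_.Tensor)   # is_write=True  -> a mutation DTensor already handles
```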

for a in self.op._schema.arguments
)

def __hash__(self) -> int:
# Only hash args and kwargs that op indicates to hash
if not self.schema_info: