diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py
index faa2a1ba4941..b562153ad507 100644
--- a/torch/distributed/tensor/_dispatch.py
+++ b/torch/distributed/tensor/_dispatch.py
@@ -23,6 +23,7 @@
 )
 from torch.distributed.tensor._utils import try_find_mesh_from_args
 from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
+from torch.utils._python_dispatch import return_and_correct_aliasing
 
 
 try:
@@ -164,7 +165,8 @@ def dispatch(
         assert output_sharding is not None, "output sharding should not be None"
 
         mesh = op_info.compute_mesh
-        if mesh.get_coordinate() is not None:
+        participating = mesh.get_coordinate() is not None
+        if participating:
             # computation that happens in the current rank of the mesh, normal case
             if output_sharding.needs_redistribute:
                 # If sharding propagation decision needs redistribute, perform redistribute
@@ -299,7 +301,11 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
             assert len(out_dts) >= 1, "out variant should have at least one out arg"
             return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
         else:
-            return self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            ret = self.wrap(local_results, output_sharding.output_spec)  # type: ignore[possibly-undefined]
+            if participating and op_info.schema.is_view_op():
+                return return_and_correct_aliasing(op_call, args, kwargs, ret)
+            else:
+                return ret
 
     @staticmethod
     def redistribute_local_args(
diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py
index b892d8883527..b60373ea6f83 100644
--- a/torch/distributed/tensor/_op_schema.py
+++ b/torch/distributed/tensor/_op_schema.py
@@ -450,6 +450,12 @@ def is_out_variant_op(self) -> bool:
        # be entirely correct, but it's good enough for now.
        return "out" in self.op._schema.overload_name
 
+    def is_view_op(self) -> bool:
+        return any(
+            a.alias_info is not None and not a.alias_info.is_write
+            for a in self.op._schema.arguments
+        )
+
     def __hash__(self) -> int:
         # Only hash args and kwargs that op indicates to hash
         if not self.schema_info:
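
For reference, a minimal sketch (not part of the diff) of the aliasing heuristic that the new OpSchema.is_view_op() relies on: an op is treated as a view op when some schema argument carries alias info that is not a write (e.g. `Tensor(a) self` in aten::view), whereas in-place ops mark their alias as a write and out-of-place ops carry no alias info at all. The standalone is_view_op helper below is hypothetical and only exercises the same check outside of OpSchema, using real ATen op schemas.

# Sketch only: replicates the is_view_op() heuristic from the diff above
# to show which ATen ops it classifies as view ops.
import torch

def is_view_op(op: torch._ops.OpOverload) -> bool:
    # An argument whose alias_info is present but not a write means the
    # output aliases that input without mutating it, i.e. a view op.
    return any(
        a.alias_info is not None and not a.alias_info.is_write
        for a in op._schema.arguments
    )

print(is_view_op(torch.ops.aten.view.default))  # True: `Tensor(a) self` aliases the output
print(is_view_op(torch.ops.aten.add_.Tensor))   # False: `Tensor(a!) self` is a write (in-place)
print(is_view_op(torch.ops.aten.add.Tensor))    # False: no alias info at all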