[DTensor] add op support: aten.squeeze_.dim #159532

Draft · wants to merge 2 commits into base: gh/XilunWu/161/base

Conversation

XilunWu (Contributor) commented Jul 30, 2025

Stack from ghstack (oldest at bottom):

Summary
This PR enables the in-place op aten.squeeze_.dim on DTensor with a change to
the DTensor dispatch logic: when processing an in-place operator, we should assign
output_sharding.output_spec back to the first argument, because the in-place
op_call on arg._local_tensor may also change its tensor meta.

Test
pytest test/distributed/tensor/test_view_ops.py -s -k test_squeeze_
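
For illustration (not part of the PR's test plan), a minimal sketch of the behavior this enables; the mesh/placement setup below is assumed, and it requires an already-initialized 4-rank process group:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# assumes torch.distributed is already initialized with 4 ranks
mesh = init_device_mesh("cuda", (4,))

x = torch.randn(1, 8, device="cuda")
dx = DTensor.from_local(x, mesh, [Shard(1)])  # global shape (1, 32)

dx.squeeze_(0)  # in-place aten.squeeze_.dim: the local tensor and the DTensor spec/meta must both update
assert dx.ndim == 1
```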

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @tianyu-l

pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159532

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures, 1 Unrelated Failure

As of commit 0477c8f with merge base ddbdcdc:

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XilunWu added a commit that referenced this pull request Jul 30, 2025
ghstack-source-id: b4a560c
Pull Request resolved: #159532
pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Jul 30, 2025
# in-place op: propagate the computed sharding back onto the first argument,
# since the local in-place op_call may have already changed its tensor meta
output_spec = output_sharding.output_spec
assert isinstance(output_spec, DTensorSpec)
assert isinstance(args[0], DTensor)
args[0]._spec = output_spec
XilunWu (Contributor, Author) commented Jul 31, 2025:

@tianyu-l pointed out that, besides DTensor._spec.tensor_meta, the subclass's metadata also needs to be overridden (but that does not seem doable...)

XilunWu (Contributor, Author) commented:

Edward suggests trying return_and_correct_aliasing to fix the outer tensor meta; will try.
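
For context, a hedged sketch of what that suggestion could look like inside DTensor.__torch_dispatch__; the dispatcher call mirrors the current structure, while wrapping it in return_and_correct_aliasing is the suggestion being explored here, not something this PR already does, and the snippet is a method sketch rather than standalone code:

```python
from torch.utils._python_dispatch import return_and_correct_aliasing

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    # run sharding propagation and the local op as before
    out = DTensor._op_dispatcher.dispatch(func, args, kwargs)
    # let the helper reconcile the outer wrapper's aliasing/metadata with the
    # result, which is what in-place ops like aten.squeeze_.dim need
    return return_and_correct_aliasing(func, args, kwargs, out)
```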

XilunWu (Contributor, Author) commented:

confirmed that #158954 fixes the outer and inner aliasing mismatch issue. cc @ezyang

XilunWu added the topic: not user facing and module: dtensor labels Jul 31, 2025
x = torch.randn((1, 4), device=self.device_type)
# global dim 0 has size 1, so the in-place squeeze on dim 0 is valid
dist_x = DTensor.from_local(x, mesh_2d, [Partial(), Shard(1)])
self._test_op_on_dtensor(
    torch.ops.aten.squeeze_.dim,
A reviewer (Member) commented:

Should we also check if dist_x is changed or not?
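
A possible shape of that check, as a sketch (the exact expected placements depend on the test's mesh setup, so the assertions below are assumptions):

```python
# verify the in-place op also updated dist_x itself, not just the returned value
self.assertEqual(dist_x.ndim, 1)                    # the size-1 dim is gone globally
self.assertEqual(dist_x._local_tensor.ndim, 1)      # and on the local shard
self.assertEqual(dist_x.placements, (Partial(), Shard(0)))  # Shard(1) shifts to Shard(0)
```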

XilunWu (Contributor, Author) commented Aug 5, 2025

need to debug why "distributed/tensor/test_math_ops.py::DistMathOpsTest::test_rms_norm_bwd" is broken by this change

XilunWu added a commit that referenced this pull request Aug 5, 2025
ghstack-source-id: fe05405
Pull Request resolved: #159532
XilunWu marked this pull request as draft on August 5, 2025, 10:14
AaronWang04 (Contributor) commented Aug 7, 2025

@XilunWu this is the table dump of the forward pass. As you said, it is missing a collective and all the values are off :/

Maybe something with how rsqrt uses squeeze_.dim? Not sure right now; I can investigate further if needed. (In the main branch, rsqrt's input is still Partial(avg) and an all_reduce is issued before it; with this PR the input is already marked Replicate(), so that all_reduce is skipped.)

Interesting part of the dump (note the Partial vs Replicate sharding of rsqrt's input):
main branch

        **aten.mean.dim
        **aten.add_.Scalar
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.add_.Scalar
        **aten.rsqrt.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.rsqrt.default

this PR

        **aten.mean.dim
        **aten.add_.Scalar
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.add_.Scalar
        **aten.rsqrt.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.rsqrt.default

Full dump of forward pass
main branch

RMSNorm
    *module type: class 'torch.nn.modules.normalization.RMSNorm'
    *Parameter List
     *weight: (Replicate(),)
      FORWARD PASS
        *c10d_functional.all_reduce: 2
        **aten.view.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.view.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.view.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.view.default
        **aten._fused_rms_norm.default
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype
        **aten.pow.Tensor_Scalar
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.pow.Tensor_Scalar
        **aten.mean.dim
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.mean.dim
        **aten.add_.Scalar
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.add_.Scalar
        **aten.rsqrt.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.rsqrt.default
        **aten.mul.Tensor
          shape: [torch.Size([20, 5, 10]), torch.Size([1, 1, 1])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.mul.Tensor
        **aten.mul.Tensor
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.chunk.default
        **aten.clone.default
        **aten.mul.Tensor
        **aten.type_as.default
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype_layout
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default

this PR

RMSNorm
    *module type: class 'torch.nn.modules.normalization.RMSNorm'
    *Parameter List
     *weight: (Replicate(),)
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        **aten.view.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.view.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.detach.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default
        **aten.view.default
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.view.default
        **aten._fused_rms_norm.default
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype
        **aten.pow.Tensor_Scalar
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.pow.Tensor_Scalar
        **aten.mean.dim
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.mean.dim
        **aten.add_.Scalar
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Partial(avg),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **_c10d_functional.all_reduce.default
        **_c10d_functional.wait_tensor.default
        **aten.add_.Scalar
        **aten.rsqrt.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.rsqrt.default
        **aten.mul.Tensor
          shape: [torch.Size([20, 5, 10]), torch.Size([1, 1, 1])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.mul.Tensor
        **aten.mul.Tensor
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.chunk.default
        **aten.clone.default
        **aten.mul.Tensor
        **aten.type_as.default
          shape: [torch.Size([20, 5, 10]), torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),), (Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.to.dtype_layout
          shape: [torch.Size([20, 5, 10])]
          sharding: [(Shard(dim=0),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
          shape: [torch.Size([1, 1, 1])]
          sharding: [(Replicate(),)]
          device mesh: DeviceMesh((4,), device: 'cuda', stride: (1,))
        **aten.detach.default
        **aten.detach.default

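For reference, a hedged sketch of how a trace like the above might be reproduced; the use of CommDebugMode and its table method is an assumption about how this dump was generated, and the model/mesh setup is illustrative:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor
from torch.distributed.tensor.debug import CommDebugMode

# assumes torch.distributed is already initialized with 4 ranks
mesh = init_device_mesh("cuda", (4,))

norm = nn.RMSNorm(10, device="cuda")
# replicate the weight, shard the input on dim 0 (matches the shardings in the dump)
norm.weight = nn.Parameter(distribute_tensor(norm.weight.detach(), mesh, [Replicate()]))
x = distribute_tensor(torch.randn(20, 5, 10, device="cuda"), mesh, [Shard(0)])

comm_mode = CommDebugMode()
with comm_mode:
    out = norm(x)

# prints a per-module table of ops, shardings, and collectives like the one above
print(comm_mode.generate_comm_debug_tracing_table(noise_level=3))
```
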
Labels: ciflow/inductor, module: dtensor, oncall: distributed, topic: not user facing