Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d #152094


Open · wants to merge 1 commit into main from fix-pad-grad

Conversation

@xwu-498 xwu-498 commented Apr 24, 2025

Fixes #135447.

When the third-from-last dimension has size 2^16 or greater, MPSGraph returns 0 for the pad gradient. To work around this, we split the problematic dimension into chunks of size no greater than 2^16 - 1.

Test case for nn.ReplicationPad1d:

    shape = [65739, 2, 4]
    x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
    x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
    model = torch.nn.ReplicationPad1d((1, 1))

    out_cpu = model(x_cpu)
    out_mps = model(x_mps)

    # backward
    g_cpu = torch.randn_like(out_cpu)
    g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
    out_cpu.backward(g_cpu)
    out_mps.backward(g_mps)

    print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")

    # Expected Output:
    # ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)

Test case for nn.ReplicationPad2d:

    shape = [2, 65739, 2, 4]
    x_cpu = torch.randn(shape, device='cpu', requires_grad=True)
    x_mps = x_cpu.clone().detach().to('mps').requires_grad_(True)
    model = torch.nn.ReplicationPad2d((1, 1, 1, 1))

    out_cpu = model(x_cpu)
    out_mps = model(x_mps)

    # backward
    g_cpu = torch.randn_like(out_cpu)
    g_mps = g_cpu.clone().detach().to('mps').requires_grad_(False)
    out_cpu.backward(g_cpu)
    out_mps.backward(g_mps)

    print(f"{((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = }")

    # Expected Output:
    # ((x_cpu.grad - x_mps.grad.cpu()).abs() > 1e-5).sum() = tensor(0)

These tests produce the expected output with this workaround.
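As an illustration of the strategy described above, here is a rough Python sketch of the chunked backward. This is only a sketch: the actual fix is implemented in the Objective-C++ MPS backend, the names `pad1d_backward_chunked` and `replication_pad1d_backward` are hypothetical, and the reference backward covers only padding `(1, 1)`.

```python
import torch

# Mirrors max_sub_batch_size in the PR: MPSGraph misbehaves when the
# third-from-last dimension is 2**16 or larger.
MAX_SUB_BATCH = 2**16 - 1

def replication_pad1d_backward(grad_output):
    # Reference backward for ReplicationPad1d((1, 1)): interior positions
    # copy the gradient; the edge positions additionally accumulate the
    # gradient that flowed into the replicated pad columns.
    grad_input = grad_output[:, :, 1:-1].clone()
    grad_input[:, :, 0] += grad_output[:, :, 0]
    grad_input[:, :, -1] += grad_output[:, :, -1]
    return grad_input

def pad1d_backward_chunked(grad_output, backward_fn):
    # Split the problematic (third-from-last) dimension into sub-batches
    # of at most 2**16 - 1, run the backward on each chunk, and
    # concatenate the per-chunk gradients.
    dim = grad_output.dim() - 3
    if grad_output.shape[dim] <= MAX_SUB_BATCH:
        return backward_fn(grad_output)
    chunks = torch.split(grad_output, MAX_SUB_BATCH, dim=dim)
    return torch.cat([backward_fn(c) for c in chunks], dim=dim)
```

Splitting along this dimension is safe because replication padding operates independently on each slice along it, so the per-chunk gradients concatenate exactly to the full gradient.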


pytorch-bot bot commented Apr 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152094

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d64f0c6 with merge base ef1d45b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Apr 24, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Apr 24, 2025
@xwu-498 xwu-498 force-pushed the fix-pad-grad branch 2 times, most recently from 97ba670 to 9256e4e Compare April 24, 2025 23:34
@xwu-498 xwu-498 changed the title Work around MPSGraph issue in handling backward pass of nn.Replicatio… Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d. Apr 24, 2025
@xwu-498 xwu-498 changed the title Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d. Work around MPSGraph issue in backward pass of nn.ReplicationPad1d/2d Apr 24, 2025
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 28, 2025
@xwu-498 xwu-498 force-pushed the fix-pad-grad branch 2 times, most recently from ebd8f3a to 786fd71 Compare May 12, 2025 19:06
@skotapati skotapati added the ciflow/mps Run MPS tests (subset of trunk) label May 14, 2025
Contributor

@malfet malfet left a comment
Please fix lint; also, it looks like it fails on exactly the test you are trying to add.

// we break the tensor into chunks where the problematic dimension is no greater than 2**16-1.
// This is reported in https://github.com/pytorch/pytorch/issues/135447.
// Internal radar for MPSGraph: rdar://149853787.
const int64_t max_sub_batch_size = 65535;
Contributor
Suggested change
const int64_t max_sub_batch_size = 65535;
constexpr auto max_sub_batch_size = 65535;

Author

Thanks @malfet for the comments. I will follow up on these issues.

Author

Hi @malfet, could you have another look?

I've made the following changes:

  • Applied the change you suggested in #152094 (comment).
  • Fixed lint.
  • Changed test_ReplicationPad*_large in test_nn.py to be marked with @expectedFailureMPSPre15 instead of @expectedFailureMPS. The fix in this PR addresses the issue with large dimensions, but it exposed another issue on OS versions older than 15, so we use @expectedFailureMPSPre15 to allow these tests to be validated on macOS 15 and above.
  • Removed the test cases added to test_mps.py, because test_nn.py already has test cases that exercise the large dimensions.

Thank you.

@pytorch-bot pytorch-bot bot removed the ciflow/mps Run MPS tests (subset of trunk) label May 17, 2025
@skotapati skotapati added ciflow/mps Run MPS tests (subset of trunk) keep-going Don't stop on first failure, keep running tests until the end labels May 19, 2025
@skotapati
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix-pad-grad onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix-pad-grad && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/mps Run MPS tests (subset of trunk) label May 19, 2025
@xwu-498 xwu-498 force-pushed the fix-pad-grad branch 2 times, most recently from b7e3c93 to 94564e6 Compare May 28, 2025 20:32
@skotapati skotapati added the ciflow/mps Run MPS tests (subset of trunk) label May 29, 2025
@skotapati
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix-pad-grad onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix-pad-grad && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/mps Run MPS tests (subset of trunk) label May 29, 2025
@skotapati skotapati added the ciflow/mps Run MPS tests (subset of trunk) label May 29, 2025
@xwu-498 xwu-498 requested a review from malfet June 12, 2025 20:26
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 11, 2025
Labels
ciflow/mps Run MPS tests (subset of trunk) keep-going Don't stop on first failure, keep running tests until the end open source release notes: mps Release notes category Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

[MPS] Correctness issue in backward pass of nn.ReplicationPad1d and nn.ReplicationPad2d
6 participants