
autograd: Add VJP and JVP rules for aten::aminmax #158241


Open: wants to merge 12 commits into main from fix-aminmax-autograd-rules

Conversation


@vijayabhaskar-ev vijayabhaskar-ev commented Jul 14, 2025

Adds functionally correct backward (VJP) and forward (JVP) autograd rules for the aten::aminmax operator to derivatives.yaml using existing helper functions. This ensures correct eager mode differentiation.

Fixes #148808
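
To make the claim checkable, here is a minimal, hedged verification sketch (not the PR's actual tests, which live in test_autograd.py and common_methods_invocations.py): with the new formulas built in, gradcheck should validate the VJP, and check_forward_ad=True should exercise the JVP as well.

```python
# Illustrative only; assumes a build that includes the new aminmax formulas.
import torch
from torch.autograd import gradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

def fn(t):
    # aminmax returns a (min, max) pair; here reduced over dim 1.
    return torch.aminmax(t, dim=1, keepdim=False)

# The default check exercises the backward (VJP) rule;
# check_forward_ad=True additionally exercises the forward-mode (JVP) rule.
assert gradcheck(fn, (x,), check_forward_ad=True)
```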


pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158241

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit b7818f6 with merge base c8c221c:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@albanD removed their request for review on July 14, 2025 18:16
@albanD added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Jul 14, 2025
Adds functionally correct backward (VJP) and forward (JVP) autograd
rules for the aten::aminmax operator to derivatives.yaml using existing
helper functions. This ensures correct eager mode differentiation.

Fixes pytorch#148808
… to handle amin and amax correctly.

- Modified derivatives.yaml to reflect the changes in autograd behavior.
- Updated test cases in test_autograd.py to validate the updated rules.
- Updated common_methods_invocations.py to include amin and amax in relevant test cases.
Makes backward pass more efficient by:
- Avoiding zeros tensor creation when only one gradient defined
- Using in-place operations where possible
- Added a test case to validate the new logic.
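
To make the efficiency note above concrete, below is a pure-Python sketch of the intended backward semantics (the helper name aminmax_vjp_sketch and its signature are illustrative, not the PR's C++): each defined output gradient is routed to the tied extrema and split evenly among them, and an undefined gradient simply skips its branch instead of materializing a zeros tensor.

```python
import torch

def aminmax_vjp_sketch(grad_min, grad_max, self, dim, keepdim):
    """Reference semantics only; names are illustrative, not the PR's code."""
    def branch(grad, reduced):
        if not keepdim:
            # Restore the reduced dimension so the mask broadcasts against self.
            grad = grad.unsqueeze(dim)
            reduced = reduced.unsqueeze(dim)
        mask = self == reduced
        # Split the incoming gradient evenly among tied extrema.
        return mask * (grad / mask.sum(dim, keepdim=True))

    mn, mx = torch.aminmax(self, dim=dim, keepdim=keepdim)
    total = None
    if grad_min is not None:      # no zeros tensor when grad_min is undefined
        total = branch(grad_min, mn)
    if grad_max is not None:
        part = branch(grad_max, mx)
        total = part if total is None else total + part
    return total
```
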
@vijayabhaskar-ev
Author

@soulitzer Adding restore_reduced_dims on grad_max and grad_min is causing a segmentation fault (core dumped).

@vijayabhaskar-ev force-pushed the fix-aminmax-autograd-rules branch from 4c871fc to 889bbf3 on July 16, 2025 15:57
@vijayabhaskar-ev
Author

@soulitzer Adding restore_reduced_dims on grad_max and grad_min is causing a segmentation fault (core dumped).

Removed restore_reduced_dims on grad_min/grad_max in aminmax_backward.
scale_grad_by_count already handles broadcasting, and the extra expansion created zero-stride views that caused gradcheck/functorch crashes. Reverting to the earlier logic fixes the segfault without altering gradient results.
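
For reference, the zero-stride views mentioned above are what expand produces: it broadcasts by giving the repeated dimension stride 0 instead of copying data, and contiguous() is what turns such a view into a real copy. A tiny illustration:

```python
import torch

g = torch.randn(3, 1)
v = g.expand(3, 4)               # a view into g's storage, no copy
print(v.stride())                # (1, 0): the broadcast dim has stride 0
print(v.contiguous().stride())   # (4, 1): contiguous() materializes a copy
```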

@soulitzer
Contributor

> scale_grad_by_count already handles broadcasting

Broadcasting doesn't work when certain dimensions are missing. restore_reduced_dims is responsible for restoring those dimensions so broadcasting can happen.
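
A small shape illustration of this point (sizes are arbitrary): with keepdim=False the reduced output drops the reduced dimension entirely, so it cannot broadcast against the input until that dimension is put back, which is what restore_reduced_dims does on the C++ side.

```python
import torch

x = torch.randn(3, 4)
mn, mx = torch.aminmax(x, dim=1, keepdim=False)   # mn and mx have shape (3,)

# x == mx would raise: shapes (3, 4) and (3,) do not broadcast along dim 1.
mask = x == mx.unsqueeze(1)   # restoring dim 1 gives (3, 1), which broadcasts
print(mask.shape)             # torch.Size([3, 4])
```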

@vijayabhaskar-ev
Author

> scale_grad_by_count already handles broadcasting
>
> Broadcasting doesn't work when certain dimensions are missing. restore_reduced_dims is responsible for restoring those dimensions so broadcasting can happen.

If I add restore_reduced_dims on grad_max and grad_min, it causes a segmentation fault (core dumped):

auto max_reduced = restore_reduced_dims(max, dims, keepdim);
auto max_mask = (self == max_reduced);
auto grad_max_expanded = restore_reduced_dims(grad_max, dims, keepdim);
auto grad_max_result = scale_grad_by_count(grad_max_expanded, max_mask, dims);

This is the commit that caused the segfault: https://github.com/pytorch/pytorch/pull/158241/commits/c1ad91c9244a691b5e574f04724587574a64f52a

@vijayabhaskar-ev
Author

Let me rebuild and try again.

@soulitzer PR updated.

auto grad_min_full = restore_reduced_dims(grad_min, dims, keepdim)
                         .expand_as(min_mask)
                         .contiguous();
Contributor


This doesn't look right; we don't want a copy.

Author


Removed contiguous(), as it causes an extra memory footprint.
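
To quantify the concern about the copy: calling contiguous() on an expanded view allocates the fully materialized tensor, while the expanded view itself keeps sharing the original storage. A small illustration with arbitrary sizes (untyped_storage() assumes a PyTorch 2.x build):

```python
import torch

g = torch.randn(1024, 1)
v = g.expand(1024, 1024)   # still backed by g's 1024 floats (4 KB)
c = v.contiguous()         # allocates 1024 * 1024 floats (4 MB)

print(v.untyped_storage().nbytes())   # 4096
print(c.untyped_storage().nbytes())   # 4194304
```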

@vijayabhaskar-ev
Author

@soulitzer Can you review this PR? I resolved the issues mentioned in the comments.

Labels
open source, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Inductor] Inference failed with the compiled model with aminmax operator
4 participants