Skip to content

[cuDNN] cuDNN frontend for LayerNorm RMSNorm #159682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

AaronWang04
Copy link
Contributor

cuDNN performance as of 9.10 is still not great

if it ever becomes good, this PR could streamline process of adding experimental cudnn backend for layernorm and rmsnorm

Follows examples at: https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/norm

Copy link

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159682

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 2 New Failures, 2 Unrelated Failures

As of commit 6030457 with merge base aeb5321 (image):

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Contributor

github-actions bot commented Aug 1, 2025

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Skylion007
Copy link
Collaborator

FYI: @eqy had a stale PR doing it

@eqy
Copy link
Collaborator

eqy commented Aug 3, 2025

haha yes this is based on that

};

void setLayerNormParams(LayerNormParams& params, const Tensor& X, int64_t M, int64_t N) {
memset(&params, 0, sizeof(params));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use stdlib std::memset

{Y_fe, Y->data_ptr()}};
variant_pack = std::move(variant_pack_);
auto result = std::make_tuple(layernorm_graph, X_fe, mean_fe, inv_variance_fe, scale_fe, bias_fe, Y_fe);
layernorm_forward_graph_cache.update(key, result);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can result not be moved here too?

{DX_fe, dX->data_ptr()}};
variant_pack = std::move(variant_pack_);
auto result = std::make_tuple(layernorm_graph, X_fe, DY_fe, mean_fe, inv_variance_fe, scale_fe, dscale_fe, dbias_fe, DX_fe);
layernorm_backward_graph_cache.update(key, result);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

};

void setRMSNormParams(RMSNormParams& params, const Tensor& X, int64_t M, int64_t N) {
memset(&params, 0, sizeof(params));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use std

return &(it->second);
}

void update(const KeyType& key, T& results) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use perfect forwarding, no?

{DX_fe, dX->data_ptr()}};
variant_pack = std::move(variant_pack_);
auto result = std::make_tuple(rmsnorm_graph, X_fe, DY_fe, inv_variance_fe, scale_fe, dscale_fe, DX_fe);
rmsnorm_backward_graph_cache.update(key, result);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise

@AaronWang04
Copy link
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased cudnn_layer_norm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnn_layer_norm && git pull --rebase)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants