
[WIP] Initial implementation of Grouped Gemm API #148531


Closed
wants to merge 20 commits

Conversation

@ngimel (Collaborator) commented Mar 5, 2025

This PR provides an initial cutlass implementation of the grouped gemm API as described in this document. Any combination of 2d and 3d inputs is supported, with 2d inputs being jagged and the offsets of the jagged input given by the device tensor offs. Only H100 is supported, and only fp8_e4m3 with bf16 output and rowwise scaling. All dimensions of each individual gemm have to be multiples of 16; that's a cutlass limitation.
I'll still need to add those checks; for dynamic dimensions the checks will unfortunately have to be a device assert.
I had to copy-paste cutlass's Sm90RowBroadcast and Sm90ColBroadcast structs with minor changes to enable scales given as pointer arrays; ideally those should be part of cutlass itself.
I copied the schedules from the similar grouped gemm in FBGEMM, but there's a lot of room to improve perf, especially for fast_accum=False.
Next steps are perf tuning and extending coverage to B100; I don't know how the cutlass grouped gemm example handles blockwise scaling on B100.
cc @vkuzo @drisspg @lw
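
For illustration, here is a minimal usage sketch of the 2d-jagged by 3d case described above. The Python binding name torch._scaled_grouped_mm, the exact argument names, and the scale shapes are assumptions made for readability, not something confirmed in this thread.

import torch

# Hypothetical sketch of the grouped gemm described in the PR; names and
# shapes below are assumptions, not the authoritative API.
device = "cuda"                      # H100 (SM 9.0) only, per the description
G, K, N = 3, 64, 128                 # 3 groups; every dim a multiple of 16
m_sizes = [32, 48, 16]               # per-group M, each a multiple of 16

# 2d "jagged" LHS: groups stacked along dim 0, delimited by cumulative
# offsets in the device tensor offs.
a = torch.randn(sum(m_sizes), K, device=device).to(torch.float8_e4m3fn)
# 3d RHS: one (N, K) weight matrix per group.
b = torch.randn(G, N, K, device=device).to(torch.float8_e4m3fn)
offs = torch.tensor(m_sizes, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)

# Rowwise scales in fp32 (shapes assumed: one scale per row of a, one per
# output column of each group's weight matrix).
scale_a = torch.rand(sum(m_sizes), device=device)
scale_b = torch.rand(G, N, device=device)

# Assumed entry point; bf16 output, rowwise scaling, optional fast accum.
out = torch._scaled_grouped_mm(
    a, b.transpose(-2, -1),
    offs=offs, scale_a=scale_a, scale_b=scale_b,
    out_dtype=torch.bfloat16, use_fast_accum=True,
)
# out has shape (sum(m_sizes), N); rows offs[i-1]:offs[i] hold group i's result.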

@ngimel requested review from eqy and syed-ahmed as code owners, March 5, 2025 04:45
pytorch-bot bot commented Mar 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148531

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 37a054c with merge base e0d4c43:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions bot (Contributor) commented Mar 5, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@@ -95,6 +95,26 @@ if(INTERN_BUILD_ATEN_OPS)
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")

set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu")
@drisspg (Contributor) commented Mar 5, 2025

Unrelated: it seems these non-portable arches are becoming more and more important; we might want to figure out a more generalizable approach.

bool use_fast_accum) {
#ifndef USE_ROCM
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
Contributor commented:
Nit: maybe remove the rocm part here
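
As a side illustration of the device gating discussed above (this path targets H100 only), a caller-side guard could be sketched as follows; grouped_mm_supported is a hypothetical helper, and only torch.cuda.get_device_capability is an existing API.

import torch

def grouped_mm_supported() -> bool:
    # Hypothetical helper: the grouped gemm kernel in this PR targets H100 (SM 9.0) only.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) == (9, 0)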

#include <c10/util/irange.h>

// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
Contributor commented:
nit/TODO: we should double check if these pragmas are still needed

@drisspg (Contributor) left a comment
Sprinkled in some random comments; overall this looks good.

linux-foundation-easycla bot commented Mar 5, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@ngimel added the release notes: cuda (release notes category) label Mar 5, 2025
@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 10, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@janeyx99 (Contributor) commented:
@pytorchbot revert -m "Sorry but this broke ROCm jobs on trunk" -c nosignal

@pytorchmergebot (Collaborator) commented:
@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented:
@ngimel your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Mar 11, 2025
This reverts commit ff29791.

Reverted #148531 on behalf of https://github.com/janeyx99 due to: Sorry but this broke ROCm jobs on trunk (see comment).
@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Mar 11, 2025
@facebook-github-bot (Contributor) commented:
@ngimel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ngimel added the ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300) label Mar 11, 2025
@ngimel (Collaborator, Author) commented Mar 11, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels
ci-no-td (Do not run TD on this PR), ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: cuda (release notes category), Reverted
6 participants