[WIP] Initial implementation of Grouped Gemm API #148531
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148531
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 37a054c with merge base e0d4c43.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
@@ -95,6 +95,26 @@ if(INTERN_BUILD_ATEN_OPS)
  endif()
  list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
  set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")

  set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu")
Unrelated: it seems these non-portable arches are becoming more and more important; we might want to figure out a more generalizable approach.
bool use_fast_accum) {
#ifndef USE_ROCM
  bool allowed_device = _scaled_mm_allowed_device();
  TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
Nit: maybe remove the ROCm part here.
#include <c10/util/irange.h>

// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
Nit/TODO: we should double-check whether these pragmas are still needed.
Sprinkled in some comments; overall this looks good.
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot revert -m "Sorry but this broke ROCm jobs on trunk" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
@ngimel your PR has been successfully reverted.
This reverts commit ff29791. Reverted #148531 on behalf of https://github.com/janeyx99 due to "Sorry but this broke ROCm jobs on trunk" ([comment](#148531 (comment)))
@ngimel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR provides an initial CUTLASS implementation of the grouped GEMM API as described in this document. Any combination of 2D and 3D inputs is supported, with 2D inputs being jagged and the offsets of the jagged dimension given by the device tensor `offs`. Only H100 is supported, and only fp8_e4m3 inputs with bf16 output and rowwise scaling. All dimensions of each individual GEMM have to be a multiple of 16; that's a CUTLASS limitation. I'll need to add those checks; for dynamic dimensions the checks will unfortunately have to be a device assert.
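To make the intended call shape concrete, here is a minimal usage sketch for the jagged-2D-A / 3D-B case. The op name `torch._scaled_grouped_mm`, the argument order, and the exact scale shapes and layout requirements are assumptions for illustration and may not match the final signature.

```python
# Sketch only: op name, argument order, scale shapes, and layouts are assumptions,
# not the finalized API.
import torch

device = "cuda"                  # H100 (SM90) only, per the PR description
G, K, N = 4, 64, 128             # every GEMM dimension must be a multiple of 16
m_sizes = [16, 32, 48, 16]       # per-group M, each a multiple of 16
M_total = sum(m_sizes)

# Jagged 2D operand A: groups stacked along dim 0, in fp8_e4m3.
a = torch.randn(M_total, K, device=device).to(torch.float8_e4m3fn)
# 3D operand B: one K x N matrix per group (built column-major here; the real
# layout requirement may differ).
b = torch.randn(G, N, K, device=device).to(torch.float8_e4m3fn).transpose(-2, -1)

# Rowwise scales (shapes are a guess: one per row of A, one per output column of B per group).
scale_a = torch.rand(M_total, device=device, dtype=torch.float32)
scale_b = torch.rand(G, N, device=device, dtype=torch.float32)

# Device tensor of cumulative group boundaries along A's jagged dimension.
offs = torch.cumsum(torch.tensor(m_sizes, device=device), dim=0, dtype=torch.int32)

out = torch._scaled_grouped_mm(      # assumed op name
    a, b, scale_a, scale_b, offs=offs,
    out_dtype=torch.bfloat16, use_fast_accum=True,
)
print(out.shape)  # expected: (M_total, N), in bf16
```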
I had to copy-paste CUTLASS's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays; ideally those should be part of CUTLASS itself. I copied the schedules from the similar grouped GEMM in FBGEMM, but there's a lot of room to improve perf, especially for `fast_accum=False`.

Next steps would be perf tuning and extending coverage to B100; I don't know how the CUTLASS grouped GEMM example handles blockwise scaling on B100.
cc @vkuzo @drisspg @lw