[Draft][CUDA] Upgrade torch._scaled_grouped_mm to SM100+ #156806
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156806
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 3 New Failures, 1 Unrelated Failure as of commit b9529e1 with merge base e9d27aa.
NEW FAILURES: the following jobs have failed:
UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing" "module: cuda"
Force-pushed from 43b4d74 to 21d3935 (Compare)
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

namespace at::cuda::detail {

GroupCountInfo get_group_count(
I noticed that the 2D/3D logic is used across multiple functions, so I consolidated it into a struct for better reuse and readability.
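To make that concrete for readers of this thread, here is a rough sketch of what such a struct could look like, pieced together from the hunks quoted above; apart from GroupCountInfo, input_matrix_type, and the 2D/2D enumerator, every name and field here is an assumption rather than the PR's actual definition.

```cpp
#include <cstdint>

// Hypothetical sketch only: the 2D/2D enumerator and input_matrix_type appear
// in the quoted hunks; the remaining enumerators and fields are assumptions.
enum class GroupMMInputMatrixType {
  GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D,
  GroupMMInputMatrixType_MatrixA_2D_MatrixB_3D, // assumed
  GroupMMInputMatrixType_MatrixA_3D_MatrixB_2D, // assumed
  GroupMMInputMatrixType_MatrixA_3D_MatrixB_3D, // assumed
};

struct GroupCountInfo {
  int64_t group_count = 1;                  // number of GEMMs in the group (assumed field)
  int64_t M = 0, N = 0, K = 0;              // per-group problem sizes (assumed fields)
  GroupMMInputMatrixType input_matrix_type; // which operands are 2D vs 3D
};
```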
Can you give a description of what scaling formats you intend to support?
typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void prepare_grouped_gemm_data_sm100(
Why does this need to be separate? Are there enough changes to justify that? prepare_grouped_gemm_data has been changed recently to relax restrictions (e.g. CUTLASS 4 no longer requires K > 0).
aten/src/ATen/native/cuda/Blas.cpp (Outdated)
scale.dim(),
"D, arg ",
arg_idx);
mat_a.size(-1) % 128 == 0,
Oh wow, 128 is a big multiplier; is there a reference for why it's required?
Tile shapes are large on sm90 too, but that doesn't translate into input shape requirements.
GroupMMInputMatrixType input_matrix_type;
};

GroupCountInfo get_group_count(
this function doesn't do what it says
type = GroupMMInputMatrixType::GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D;

// stack on the K dimension
K = K / group_count;
This is super confusing; why would you do that? What does this K even mean, an average K of the grouped matrix multiply? The same goes for the other averaged values.
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/ScaledGroupMM.cu#L436
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/ScaledGroupMM.cu#L236
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/GroupMM.cu#L331
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/GroupMM.cu#L186
To reduce duplication and improve maintainability, I’ve consolidated this logic into a single helper function.
aten/src/ATen/native/cuda/Blas.cpp (Outdated)
TORCH_CHECK(
    scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
scale_a.dim() == 3,
Without a description of what you expect scale to be, expecting a 3D scale for a matrix that can be either 2D or 3D doesn't seem correct.
Can you make sure you are rebased on top of #158037; there might be conflicts.
aten/src/ATen/native/cuda/Blas.cpp (Outdated)
TORCH_CHECK(
    scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
    scale.size(0) ==
        mat.size(dim) *
            ((info.input_matrix_type ==
              at::cuda::detail::GroupMMInputMatrixType::GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D)
                 ? info.group_count
                 : 1),
    "scale must have the same length as mat for arg ", arg_idx);
Nit: maybe use a using expression so that this line is easier to parse.
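A minimal sketch of one way to read that suggestion, written against the quoted check; both_2d and expected_len are illustrative names, and the surrounding identifiers (info, mat, dim, scale, arg_idx) come from the hunk above.

```cpp
// Sketch only: shorten the enum path with a using-declaration and name the
// intermediate values so the TORCH_CHECK reads one idea per line.
using at::cuda::detail::GroupMMInputMatrixType;

const bool both_2d = info.input_matrix_type ==
    GroupMMInputMatrixType::GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D;
const int64_t expected_len = mat.size(dim) * (both_2d ? info.group_count : 1);

TORCH_CHECK(
    scale.size(0) == expected_len,
    "scale must have the same length as mat for arg ", arg_idx);
```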
if (!transpose) {
  *layout_sfa_ptr =
      ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
This PR is adding rowwise scaling support, not MX, right?
namespace {

using Strides = at::cuda::detail::Strides;

int ceildiv(int a, int b) {
we have these here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/ceil_div.h
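For reference, a minimal usage sketch of that shared helper (assuming the usual at::ceil_div signature from ATen/ceil_div.h), which could replace a local ceildiv:

```cpp
#include <ATen/ceil_div.h>
#include <cstdint>

// at::ceil_div performs rounding-up integer division and is usable on host
// and device; num_scale_blocks is just an illustrative wrapper name.
inline int64_t num_scale_blocks(int64_t k) {
  return at::ceil_div<int64_t>(k, 128); // e.g. k = 384 -> 3 blocks of 128
}
```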
aten/src/ATen/native/cuda/Blas.cpp (Outdated)
scale.dim(),
"D, arg ",
arg_idx);
mat_a.size(-1) % 128 == 0,
sm90 uses a 1D tensor for scale factors and performs row-wise and column-wise broadcasting across the A and B matrices (see code here). However, sm100 requires a more explicit scale factor layout, as documented here. I've tried to stay compatible with the existing sm90 scale factor tensor shape; I think Sm100BlockwiseScaleConfig<1,1,128,MN,MN> most closely mimics the row-wise broadcasting used on sm90. However, my performance benchmarks revealed that this ScaleConfig leads to a performance penalty, so for now I basically use the ScaleConfig from sgl-kernel, which requires the 128 size factor you see in the code. Other layouts may need to be tuned carefully to ensure maximum performance.
#157950 might provide more clarity. I am curious as to what specific scaling strategy this adds support for on Blackwell.
I think the scale config for the current implementation is scale_a: [M, K//128], scale_b: [K//128, N//128].
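To make that convention concrete, here is a small compile-time shape sketch for a single group; the 128 block size matches the divisibility check quoted earlier, while the struct and function names are purely illustrative.

```cpp
#include <array>
#include <cstdint>

constexpr int64_t kScaleBlock = 128; // blockwise granularity, per the K % 128 check

struct GroupScaleShapes {
  std::array<int64_t, 2> scale_a; // [M, K / 128]
  std::array<int64_t, 2> scale_b; // [K / 128, N / 128]
};

constexpr GroupScaleShapes expected_scale_shapes(int64_t M, int64_t N, int64_t K) {
  return {{M, K / kScaleBlock}, {K / kScaleBlock, N / kScaleBlock}};
}

// e.g. M = 256, N = 512, K = 1024 gives scale_a: [256, 8] and scale_b: [8, 4].
static_assert(expected_scale_shapes(256, 512, 1024).scale_a[1] == 8, "1024 / 128 == 8");
static_assert(expected_scale_shapes(256, 512, 1024).scale_b[1] == 4, "512 / 128 == 4");
```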
Force-pushed from f9ba45b to 567e707 (Compare)
Force-pushed from 567e707 to ec3835a (Compare)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM100OrLater, "Grouped gemm supported on SM100")
def test_scaled_grouped_gemm_3d_3d_sm100(self):
_scaled_mm is not yet available for SM100, so the tests use a block-wise emulated reference implementation for comparison.
Can you elaborate on which version (scaling strategy + input types) is not supported on sm100? We should have pretty good coverage now.
Then I question the wisdom of enabling DeepSeek-like scaling for _scaled_grouped_mm on Blackwell if just _scaled_mm is not supported. Should we start by supporting mx scales in grouped mm?
This PR adds support for the Blackwell-specific scaling factor layouts, following the Scaling Factor Types outlined in issues #157950 and #158037. The operator now supports the following configurations:
GroupMMInputMatrixType and a helper function get_group_info have been introduced to generalize the logic for handling various input tensor combinations (2D/3D, representing batched or ragged inputs). This abstracts away the complexity of determining matrix dimensions (M, N, K) and group counts for different scenarios.
The refactoring supports ragged inputs (as introduced in PR [WIP] Initial implementation of Grouped Gemm API #148531). For the initial Blackwell implementation with ragged tensors, the scaling factors are expected to be 3D tensors with shapes like [n_groups, M, K//128] for scale_A and [n_groups, K//128, N//128] for scale_B.
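As a hedged illustration of those shapes (dtype, zeros() initialization, and the N divisibility requirement are assumptions made for the sketch; only the 3D shapes come from the description above):

```cpp
#include <ATen/ATen.h>

// Sketch of the 3D scale tensors described above for the ragged case.
void example_blackwell_group_scales(int64_t n_groups, int64_t M, int64_t N, int64_t K) {
  TORCH_CHECK(K % 128 == 0 && N % 128 == 0, "sketch assumes 128-divisible K and N");
  // scale_A: [n_groups, M, K // 128], presumably one scale per 1x128 block of A.
  at::Tensor scale_a = at::zeros({n_groups, M, K / 128}, at::kFloat);
  // scale_B: [n_groups, K // 128, N // 128], presumably one scale per 128x128 block of B.
  at::Tensor scale_b = at::zeros({n_groups, K / 128, N / 128}, at::kFloat);
  // Group g would then consume the contiguous 2D slices scale_a[g] and scale_b[g].
  TORCH_CHECK(scale_a.dim() == 3 && scale_b.dim() == 3);
}
```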
As discussed in issue Upgrade torch._scaled_grouped_mm to SM100+ #156238, we are going to support more flexible input layouts in the future.
_scaled_mm is not yet available for SM100, so the tests use a block-wise emulated reference implementation for comparison.