
[Draft][CUDA] Upgrade torch._scaled_grouped_mm to SM100+ #156806


Open · wants to merge 7 commits into main from fralin/scaled_grouped_mm_sm100

Conversation

@eee4017 (Collaborator) commented Jun 25, 2025

This PR adds support for the Blackwell-specific scaling factor layouts, following the Scaling Factor Types outlined in issues #157950 and #158037. The operator now supports the following configurations:

| Operation | API | Input A | Input B | Scaling Factor A | Scaling Factor B | Hardware | Limitations |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Hopper Grouped GEMM (FP8) | _scaled_grouped_mm | 2D/3D, Layout: RowMajor | 2D/3D, Layout: ColMajor | shape: [M, 1], Layout: Per-row only | shape: [1, N], Layout: Per-row only | SM90 | FP8_E4M3 only, no bias/scale_result |
| Blackwell Grouped GEMM (FP8) | _scaled_grouped_mm | 2D/3D, Layout: RowMajor | 2D/3D, Layout: ColMajor | shape: [M, K//128], Layout: BlockWise1x128, outer-dim-major | shape: [K//128, N//128], Layout: BlockWise128x128, near-inner-dim-major | SM100+ | FP8_E4M3 only, no bias/scale_result |
  • GroupMMInputMatrixType and a helper function get_group_info have been introduced to generalize the logic for handling various input tensor combinations (2D/3D, representing batched or ragged inputs). This abstracts away the complexity of determining matrix dimensions (M, N, K) and group counts for different scenarios.

    • Note: When inputs are 2D (ragged), the returned M, N, and K dimensions are derived from the total tensor size and group count. These values are used for heuristics and may not represent the actual dimensions of each individual matrix in the group. See dispatch_fp8_grouped_gemm_on_tile_size.
  • The refactoring supports ragged inputs (as introduced in PR [WIP] Initial implementation of Grouped Gemm API #148531). For the initial Blackwell implementation with ragged tensors, the scaling factors are expected to be 3D tensors with shapes like [n_groups, M, K//128] for scale_A and [n_groups, K//128, N//128] for scale_B (see the usage sketch after this list).

  • As discussed in issue Upgrade torch._scaled_grouped_mm to SM100+ #156238, we are going to support more flexible input layouts in the future.

  • Since _scaled_mm is not yet available for SM100, the tests use a block-wise emulated reference implementation for comparison.
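For illustration, a minimal usage sketch of the Blackwell path with the shapes from the table above. The call signature is assumed to follow the grouped-GEMM API from #148531, and the dimensions and dtypes here are made up for the example:

```python
import torch

# Hedged sketch, not the PR's test code: 3D (batched) inputs on SM100+ with
# block-wise scales. Shapes follow the table above; K must be a multiple of 128.
n_groups, M, N, K = 4, 256, 512, 1024
a = torch.randn(n_groups, M, K, device="cuda").to(torch.float8_e4m3fn)                    # RowMajor A
b = torch.randn(n_groups, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)  # ColMajor B
scale_a = torch.rand(n_groups, M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n_groups, K // 128, N // 128, device="cuda", dtype=torch.float32)
# out_dtype is assumed to be an accepted keyword, as in the sm90 path.
out = torch._scaled_grouped_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
```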

@eee4017 eee4017 requested review from eqy and syed-ahmed as code owners June 25, 2025 09:25

pytorch-bot bot commented Jun 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156806

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Unrelated Failure

As of commit b9529e1 with merge base e9d27aa:

NEW FAILURES - The following jobs have failed:

  • pull / cuda12.8-py3.10-gcc9-sm75 / build (gh)
    /var/lib/jenkins/workspace/aten/src/ATen/native/cuda/ScaledGroupMM.cu:621:7: error: typedef ‘using DtypeProblemShape = using UnderlyingProblemShape = struct cute::tuple<int, int, int>’ locally defined but not used [-Werror=unused-local-typedefs]
  • pull / linux-jammy-cuda12.8-py3.10-gcc11 / build (gh)
    /var/lib/jenkins/workspace/aten/src/ATen/native/cuda/ScaledGroupMM.cu:621:7: error: typedef ‘using DtypeProblemShape = using UnderlyingProblemShape = struct cute::tuple<int, int, int>’ locally defined but not used [-Werror=unused-local-typedefs]
  • pull / linux-jammy-cuda12.8-py3.10-gcc11-build-distributed / build (gh)
    /var/lib/jenkins/workspace/aten/src/ATen/native/cuda/ScaledGroupMM.cu:621:7: error: typedef ‘using DtypeProblemShape = using UnderlyingProblemShape = struct cute::tuple<int, int, int>’ locally defined but not used [-Werror=unused-local-typedefs]

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eee4017 (Collaborator, Author) commented Jun 25, 2025

@pytorchbot label "topic: not user facing" "module: cuda"

@pytorch-bot bot added the module: cuda and topic: not user facing labels Jun 25, 2025
@eee4017 eee4017 marked this pull request as draft June 25, 2025 09:30
@eee4017 eee4017 force-pushed the fralin/scaled_grouped_mm_sm100 branch from 43b4d74 to 21d3935 on July 9, 2025 13:15
@eee4017 eee4017 marked this pull request as ready for review July 9, 2025 13:16
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

namespace at::cuda::detail {

GroupCountInfo get_group_count(
eee4017 (Collaborator, Author):
I noticed that the 2D/3D logic is used across multiple functions, so I consolidated it into a struct for better reuse and readability.
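For readers of this thread, a rough Python rendering of what that struct captures. The PR's GroupCountInfo/get_group_info are C++; the mixed 2D/3D cases are omitted here, and the use of offs to define the group count is an assumption:

```python
# Rough illustration only; see the C++ GroupCountInfo/get_group_info in the diff.
def get_group_info(mat_a, mat_b, offs=None):
    if mat_a.dim() == 3 and mat_b.dim() == 3:
        # Plain batched case: one group per leading dimension.
        group_count = mat_a.size(0)
        M, K, N = mat_a.size(1), mat_a.size(2), mat_b.size(-1)
    elif mat_a.dim() == 2 and mat_b.dim() == 2:
        # Ragged case with groups stacked on the K dimension: M and N are exact,
        # while K is a nominal per-group value used only for tile-size heuristics.
        group_count = offs.numel()  # assumption: one offset per group
        M, N = mat_a.size(0), mat_b.size(-1)
        K = mat_a.size(1) // group_count
    else:
        raise NotImplementedError("mixed 2D/3D cases omitted from this sketch")
    return group_count, M, N, K
```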

@mikaylagawarecki mikaylagawarecki requested a review from ngimel July 14, 2025 15:02
@mikaylagawarecki mikaylagawarecki added the triaged label Jul 14, 2025
@drisspg drisspg self-requested a review July 17, 2025 19:59
@ngimel (Collaborator) left a comment:

Can you give a description of what scaling formats you intend to support?

typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void prepare_grouped_gemm_data_sm100(
Collaborator:
why does this need to be separate? Are there enough changes to justify that? prepare_grouped_gemm_data has been changed recently to relax restrictions (e.g. CUTLASS 4 no longer requires K > 0).

scale.dim(),
"D, arg ",
arg_idx);
mat_a.size(-1) % 128 == 0,
Collaborator:

oh wow 128 is a big multiplier, is there a reference for why it's required?

Collaborator:

Tile shapes are large on sm90 too, that doesn't translate into input shape requirements

GroupMMInputMatrixType input_matrix_type;
};

GroupCountInfo get_group_count(
Collaborator:
this function doesn't do what it says

type = GroupMMInputMatrixType::GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D;

// stack on the K dimension
K = K / group_count;
Collaborator:
this is super confusing, why would you do that? What does this K even mean, an average K of the grouped matrix multiply? Same for the other average values.

TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
scale_a.dim() == 3,
Collaborator:
without a description of what you expect scale to be, expecting a 3d scale for a matrix that can be either 2d or 3d doesn't seem correct

@drisspg (Contributor) left a comment:

Can you make sure you are rebased on top of #158037? There might be conflicts.

TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * ( (info.input_matrix_type == at::cuda::detail::GroupMMInputMatrixType::GroupMMInputMatrixType_MatrixA_2D_MatrixB_2D) ? info.group_count : 1), "scale must have the same length as mat for arg ", arg_idx);
Contributor:

nit: maybe a using expression so that this line is easier to parse


if (!transpose) {
*layout_sfa_ptr =
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
Contributor:

This PR is adding row-wise scaling support, not MX, right?


@eee4017 (Collaborator, Author) commented Jul 17, 2025

sm90 utilizes a 1D tensor for scale factors. It performs row-wise and column-wise broadcasting across the A and B matrices; see the code here. However, sm100 requires a more explicit scale factor layout, as documented here.

I've tried to achieve compatibility with the existing sm90 scale factor tensor shape; I think Sm100BlockwiseScaleConfig<1,1,128,MN,MN> most closely mimics the row-wise broadcasting used on sm90. However, my performance benchmarks revealed that this ScaleConfig leads to a performance penalty.

Thus, I basically use the ScaleConfig from sgl-kernel for now, but this requires a size multiple of 128, as you can see in the code. Other layouts may need to be tuned carefully to ensure maximum performance.
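To make the shape difference concrete, a small helper (hypothetical, for illustration only) that spells out the per-group scale shapes implied by the two schemes discussed above:

```python
# Hypothetical helper, not part of the PR: per-group scale-factor shapes.
def rowwise_scale_shapes(M, N, K):
    # sm90-style per-row scaling: one scale per row of A and per column of B.
    return (M, 1), (1, N)

def blockwise_scale_shapes(M, N, K):
    # sm100 path in this PR: 1x128 blocks for A, 128x128 blocks for B.
    # The kernel requires K % 128 == 0 (hence the check in the diff); N is
    # rounded up here only as a defensive assumption.
    assert K % 128 == 0
    return (M, K // 128), (K // 128, (N + 127) // 128)
```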

@drisspg (Contributor) commented Jul 18, 2025

#157950 might provide more clarity. I am curious as to what specific scaling strategy this adds support for on Blackwell.

@eee4017 (Collaborator, Author) commented Jul 18, 2025

I think the scale config for the current implementation is scale_a: [M, K//128], scale_b: [K//128, N//128].

@eee4017 eee4017 force-pushed the fralin/scaled_grouped_mm_sm100 branch from 567e707 to ec3835a on August 6, 2025 15:06

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM100OrLater, "Grouped gemm supported on SM100")
def test_scaled_grouped_gemm_3d_3d_sm100(self):
eee4017 (Collaborator, Author):
Since _scaled_mm is not yet available for SM100, the tests use a block-wise emulated reference implementation for comparison.
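For context, a hedged sketch of what such a block-wise emulated reference could look like for a single group, assuming row-major [M, K] / [K, N] inputs with K and N divisible by 128; the actual test helper in the PR may differ:

```python
import torch

def emulated_blockwise_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Dequantize in float32 using the block-wise scales, then run a plain matmul.
    a = a_fp8.to(torch.float32)  # [M, K]
    b = b_fp8.to(torch.float32)  # [K, N]
    # scale_a: [M, K // 128], each scale covers a 1x128 block of A.
    a = a * scale_a.repeat_interleave(128, dim=1)
    # scale_b: [K // 128, N // 128], each scale covers a 128x128 block of B.
    b = b * scale_b.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
    return (a @ b).to(out_dtype)
```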

Contributor:

Can you elaborate on which version (scaling strategy + input types) is not supported on sm100? We should have pretty good coverage now.

@ngimel (Collaborator) commented Aug 7, 2025:

Then I question the wisdom of enabling DeepSeek-like scaling for _scaled_grouped_mm on Blackwell if just _scaled_mm is not supported. Should we start by supporting mx scales in grouped mm?

Labels
module: cuda · open source · topic: not user facing · triaged