
[Scaled MM] Support TN, NT, NN, and TT layouts on B200 #152150

@drisspg

Description


Summary

On SM100 with CUDA 12.8, cuBLAS supports all four layout variants (TN, NT, NN, TT). We should update our per-tensor scaling kernel to allow these layouts.

We can also update our recipes in TorchAO so they no longer require this data transposition. Since the MMA atom supports TN, NT, NN, and TT, we should also update our rowwise scaling kernel to not require this layout.

cc @msaroufim @jerryzh168 @ptrblck @eqy @yanbing-j @vkuzo @albanD @kadeng @penguinwu

Metadata


    Labels

    module: cuda - Related to torch.cuda, and CUDA support in general
    module: floatx (formerly float8) - For torch.float8_e5m2 and torch.float8_e4m3 and other sub 8-bit float types
    module: performance - Issues related to performance, either of kernel code or framework glue
    triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
