Labels: module: cuda, module: floatx (formerly float8), module: performance, triaged
Description
Summary
On SM100 with CUDA 12.8, cuBLAS supports all four operand-layout variants. We should update our per-tensor scaling kernel to allow these layouts.
We can also update our recipes in TorchAO so they no longer require this data transposition. Since the MMA atom supports the TN, NN, NT, and TT layouts, we should also update our row-wise scaling kernel to drop this layout requirement. A sketch of the current constraint is shown below.
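To make the layout constraint concrete, here is a minimal sketch using `torch._scaled_mm`, the entry point behind the per-tensor float8 matmul. This assumes the current interface (argument names and return type have changed across PyTorch versions), and the shapes/scales here are illustrative only:

```python
import torch

M, K, N = 128, 256, 64
device = "cuda"

# FP8 operands produced by casting higher-precision activations/weights.
a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)  # row-major
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)  # row-major

# Per-tensor scales (float32 scalars on device).
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

# Today the per-tensor path expects the "TN" layout: mat1 row-major,
# mat2 column-major. This is why recipes transpose the weight up front.
out_tn = torch._scaled_mm(
    a, b.t(), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
)

# If the kernel accepted all four cuBLAS layout variants (TN, NN, NT, TT),
# a call with a row-major mat2 (the "NN" case below) would no longer need
# the explicit transposition / contiguous copy in the recipe. Left commented
# out because it is rejected by the current layout check.
b_row_major = torch.randn(K, N, device=device).to(torch.float8_e4m3fn)
# out_nn = torch._scaled_mm(
#     a, b_row_major, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
# )
```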
cc @msaroufim @jerryzh168 @ptrblck @eqy @yanbing-j @vkuzo @albanD @kadeng @penguinwu