
Conversation

Muzammiluddin-Syed-ECE (Contributor)

The ScaledMFMAOp accepts scales as a vector of 4 bytes (vector<4xf8E8M0FNU>) that can be stored in a single register, with a particular scale selected via the OpSel attribute. Currently we populate only one byte of this 4-byte vector, so each scale occupies a register that could hold four, wasting 3 registers out of every 4.

This is fixed by identifying when single-byte extractions feed a scale operand and rewriting them into extractions of 4-byte vectors, with OpSel selecting the byte within the register.

Example:

  %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
  %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0] * ...

to

  %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU> 
  %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0-3] * ...
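
For intuition, the mapping from an extract position to a (register, opsel) pair is just row-major flattening of the position, followed by division and modulus by 4 after the reshape. Below is a minimal standalone sketch of that arithmetic (plain C++, independent of the MLIR implementation; flattenIndex is an illustrative name, not a helper from this patch):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // Row-major flattening of an extract position:
  // flat = sum over dims of offset[i] * stride[i].
  int64_t flattenIndex(const std::vector<int64_t> &offsets,
                       const std::vector<int64_t> &shape) {
    int64_t idx = 0, cumul = 1;
    for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
      idx += offsets[i] * cumul;
      cumul *= shape[i];
    }
    return idx;
  }

  int main() {
    // Extract position [1, 0, 4, 0] from vector<2x1x8x1xf8E8M0FNU>.
    int64_t flat = flattenIndex({1, 0, 4, 0}, {2, 1, 8, 1});
    // After reshaping to vector<4x4xf8E8M0FNU>, the byte lives in row
    // flat / 4 of the reshaped vector, selected with opsel flat % 4.
    assert(flat == 12 && flat / 4 == 3 && flat % 4 == 0);
    return 0;
  }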

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
llvmbot (Member) commented Aug 29, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/155951.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+142)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt (+1)
  • (modified) mlir/test/Dialect/AMDGPU/canonicalize.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pack the scales of a ScaledMFMAOp so that all four bytes of each scale
+/// register are used. Bail out if any use of the scales already has a
+/// nonzero opsel, since that means the scales have already been packed.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a nonzero opsel, packing has already been
+    // done. Conservatively treat non-scaled_mfma uses as unpackable.
+    auto checkIfUnpackable = [](OpOperand &operand) {
+      auto smfma = dyn_cast<ScaledMFMAOp>(operand.getOwner());
+      if (!smfma)
+        return true;
+      switch (operand.getOperandNumber()) {
+      case 3:
+        return smfma.getScalesIdxA() != 0;
+      case 4:
+        return smfma.getScalesIdxB() != 0;
+      default:
+        return true;
+      }
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index from the extract op's offsets and source shape.
+    auto getIdxFromExtract = [](vector::ExtractOp extractOp) {
+      auto ty = cast<ShapedType>(extractOp.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] : llvm::reverse(llvm::zip_equal(
+               extractOp.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Recover the offsets into the new shape from the flat index.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      auto shapedTy = cast<ShapedType>(ty);
+      int64_t stride = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        stride /= size;
+        res.push_back(idx / stride);
+        idx %= stride;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale follows
+    // the pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU
+    //             from vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite it to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU>
+    //             to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU>
+    //             from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but CSE will
+    // remove the duplicates.
+    bool didPack = false;
+    for (int64_t opIdx : {3, 4}) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp)
+        continue;
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable))
+        continue;
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp)
+        continue;
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype)
+        continue;
+      // We do not handle dynamic dims yet; assume the input has been padded
+      // to a static shape.
+      if (!stype.hasStaticShape())
+        continue;
+
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4 || numElements % 4 != 0)
+        continue;
+
+      Type newSrcType =
+          VectorType::get({numElements / 4, 4}, stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets = getOffsetsFromIdx(idx, newSrcType);
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, ArrayRef<int64_t>{offsets[0], 0},
+          ArrayRef<int64_t>{1, 4}, ArrayRef<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      // Notify the rewriter of the in-place update rather than mutating the
+      // op silently, and only report success if something was rewritten.
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, scale);
+        setOpsel(opIdx, offsets[1]);
+      });
+      didPack = true;
+    }
+    return success(didPack);
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
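
As a sanity check on the CHECK lines above: flattening the four extract positions in row-major order over vector<2x1x8x1xf8E8M0FNU> gives flat indices 3, 6, 9, and 12, and dividing by 4 yields exactly the (row, opsel) pairs the test expects. A small sketch verifying this by hand (illustrative C++, not part of the patch; decompose is a hypothetical helper):

  #include <cassert>
  #include <cstdint>

  // Split a flat byte index into the row of the reshaped
  // vector<4x4xf8E8M0FNU> and the opsel byte within that row.
  struct RowOpsel { int64_t row, opsel; };
  RowOpsel decompose(int64_t flat) { return {flat / 4, flat % 4}; }

  int main() {
    // Flat indices of [0,0,3,0], [0,0,6,0], [1,0,1,0], [1,0,4,0].
    assert(decompose(3).row == 0 && decompose(3).opsel == 3);   // SCALE_1[3]
    assert(decompose(6).row == 1 && decompose(6).opsel == 2);   // SCALE_2[2]
    assert(decompose(9).row == 2 && decompose(9).opsel == 1);   // SCALE_3[1]
    assert(decompose(12).row == 3 && decompose(12).opsel == 0); // SCALE_4[0]
    return 0;
  }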


⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Note: the reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.

View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4107ec53a..8b044669c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -706,17 +706,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
     // For every scale operand of this ScaledMFMAOp, if the scale follows the
     // following pattern:
     //
-    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
-    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
-    // amdgpu.scaled_mfma(%scale[0] * ...
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
+    // vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
+    // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
     //
     // rewrite to:
     //
-    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
-    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
+    // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
+    // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
     // amdgpu.scaled_mfma(%scale[0-3] * ...
     //
-    // This creates duplicate shape_casts for every use but these will be removed in CSE.
+    // This creates duplicate shape_casts for every use but these will be
+    // removed in CSE.
     for (auto opIdx : SmallVector<int64_t>({3, 4})) {
       auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
       if (!insertOp) {
