diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether the scales inputs are also used by other scaled MFMAs, if
+/// any exist. If those uses are still unpacked, pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non-zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &op) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+        switch (op.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index from the extract's static position and the
+    // source shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = cast<ShapedType>(op.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] : llvm::reverse(
+               llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain the offsets into the new shape from a flat index.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedTy = cast<ShapedType>(ty);
+      int64_t numElements = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced
+    // by the following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<...xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<...xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
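+    //
+    // The flat index of the original extract determines both results of the
+    // rewrite: the packed scale vector is row (flat index / 4) of the
+    // reshaped source, and the opsel selecting within it is (flat index % 4).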
+    //
+    // This creates duplicate shape_casts for every use, but these will be
+    // removed through CSE.
+    for (auto opIdx : SmallVector<int64_t>({3, 4})) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet; assume that the input has been
+      // padded to a static shape.
+      if (llvm::any_of(llvm::seq<int64_t>(0, stype.getRank()),
+                       [&](int64_t i) { return stype.isDynamicDim(i); })) {
+        return failure();
+      }
+
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4) {
+        return failure();
+      }
+
+      Type newSrcType = VectorType::get(
+          SmallVector<int64_t>({numElements / 4, 4}), stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      op.setOperand(opIdx, scale);
+      setOpsel(opIdx, offsets[1]);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
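+  // The extracts below read flat indices 3, 6, 9, and 12 of their sources,
+  // which map to rows 0-3 of the reshaped 4x4 scale vectors with opsel
+  // values 3, 2, 1, and 0 respectively.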
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}