diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index f497d2db3bf7c..ab57557f3f13d 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); - if (!isa(operandETy)) - operand = b.create(f32Ty, operand); if (!isa(resultETy)) return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN"); + if (!isa(operandETy)) + operand = b.create(f32Ty, operand); Value c0x1 = createConst(loc, i4Ty, 1, rewriter); Value c0x3 = createConst(loc, i4Ty, 3, rewriter); @@ -657,6 +657,7 @@ struct ScalingExtFOpConverter : public OpRewritePattern { scaleOperand = b.create(scaleTy, scaleOperand, nullptr, op.getFastmathAttr()); } + // Catch scale types like f8E5M2. if (!llvm::isa(scaleETy)) { return rewriter.notifyMatchFailure( op, "scaling_extf is using scales of type which can not be converted " @@ -777,7 +778,7 @@ struct ArithExpandOpsPass if (includeBf16) legalTypes &= !(inETy.isF32() && outETy.isBF16()); if (includeF8E8M0) - legalTypes &= !(llvm::isa(outETy)); + legalTypes &= !(llvm::isa(outETy)); if (includeF4E2M1) legalTypes &= !llvm::isa(outETy); return legalTypes; @@ -832,7 +833,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { MaximumMinimumFOpConverter, MaximumMinimumFOpConverter, MaxNumMinNumFOpConverter, - MaxNumMinNumFOpConverter + MaxNumMinNumFOpConverter >(patterns.getContext()); // clang-format on }