Skip to content

Commit 51ba5b5

Browse files
committed
[mlir] add lowering from affine.min to std
Summary: Affine minimum computation will be used in tiling transformation. The implementation is mostly boilerplate as we already lower the minimum in the upper bound of an affine loop. Differential Revision: https://reviews.llvm.org/D73488
1 parent e7e0437 commit 51ba5b5

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -271,20 +271,41 @@ Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
271271
builder);
272272
}
273273

274+
// Emit instructions that correspond to computing the minimum value amoung the
275+
// values of a (potentially) multi-output affine map applied to `operands`.
276+
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
277+
ValueRange operands) {
278+
if (auto values =
279+
expandAffineMap(builder, loc, map, llvm::to_vector<4>(operands)))
280+
return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder);
281+
return nullptr;
282+
}
283+
274284
// Emit instructions that correspond to the affine map in the upper bound
275285
// applied to the respective operands, and compute the minimum value across
276286
// the results.
277287
Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
278-
SmallVector<Value, 8> boundOperands(op.getUpperBoundOperands());
279-
auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
280-
boundOperands);
281-
if (!ubValues)
282-
return nullptr;
283-
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
284-
builder);
288+
return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
289+
op.getUpperBoundOperands());
285290
}
286291

287292
namespace {
293+
class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
294+
public:
295+
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
296+
297+
PatternMatchResult matchAndRewrite(AffineMinOp op,
298+
PatternRewriter &rewriter) const override {
299+
Value reduced =
300+
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
301+
if (!reduced)
302+
return matchFailure();
303+
304+
rewriter.replaceOp(op, reduced);
305+
return matchSuccess();
306+
}
307+
};
308+
288309
// Affine terminators are removed.
289310
class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
290311
public:
@@ -520,10 +541,19 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
520541

521542
void mlir::populateAffineToStdConversionPatterns(
522543
OwningRewritePatternList &patterns, MLIRContext *ctx) {
544+
// clang-format off
523545
patterns.insert<
524-
AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering,
525-
AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering,
526-
AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx);
546+
AffineApplyLowering,
547+
AffineDmaStartLowering,
548+
AffineDmaWaitLowering,
549+
AffineLoadLowering,
550+
AffineMinLowering,
551+
AffinePrefetchLowering,
552+
AffineStoreLowering,
553+
AffineForLowering,
554+
AffineIfLowering,
555+
AffineTerminatorLowering>(ctx);
556+
// clang-format on
527557
}
528558

529559
namespace {

mlir/test/Transforms/lower-affine.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,18 @@ func @affine_dma_wait(%arg0 : index) {
590590
// CHECK-NEXT: dma_wait %0[%[[b]]], %c64 : memref<1xi32>
591591
return
592592
}
593+
594+
// CHECK-LABEL: func @affine_min
595+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
596+
func @affine_min(%arg0: index, %arg1: index) -> index{
597+
// CHECK: %[[Cm1:.*]] = constant -1
598+
// CHECK: %[[neg1:.*]] = muli %[[ARG1]], %[[Cm1:.*]]
599+
// CHECK: %[[first:.*]] = addi %[[ARG0]], %[[neg1]]
600+
// CHECK: %[[Cm2:.*]] = constant -1
601+
// CHECK: %[[neg2:.*]] = muli %[[ARG0]], %[[Cm2:.*]]
602+
// CHECK: %[[second:.*]] = addi %[[ARG1]], %[[neg2]]
603+
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[first]], %[[second]]
604+
// CHECK: select %[[cmp]], %[[first]], %[[second]]
605+
%0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
606+
return %0 : index
607+
}

0 commit comments

Comments
 (0)