Skip to content

Commit fdcecef

Browse files
author
Stephan Herhut
committed
Add lowering for loop.parallel to cfg.
Summary: This also removes the explicit pattern for loop.terminator to ensure that the terminator is only erased if the parent op is rewritten. Reductions are not yet supported. Reviewers: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73348
1 parent 8ed47b7 commit fdcecef

File tree

3 files changed

+82
-8
lines changed

3 files changed

+82
-8
lines changed

mlir/include/mlir/Dialect/LoopOps/LoopOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ def ParallelOp : Loop_Op<"parallel",
177177
Variadic<Index>:$step);
178178
let results = (outs Variadic<AnyType>:$results);
179179
let regions = (region SizedRegion<1>:$body);
180+
181+
let extraClassDeclaration = [{
182+
// Returns a range over all arguments of the first block of the body region.
// NOTE(review): presumably these block arguments are exactly the loop's
// induction variables, one per dimension -- confirm against the op verifier.
iterator_range<Block::args_iterator> getInductionVars() {
183+
Block &block = body().front();
184+
return {block.args_begin(), block.args_end()};
185+
}
186+
}];
180187
}
181188

182189
def ReduceOp : Loop_Op<"reduce", [HasParent<"ParallelOp">]> {

mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
1515
#include "mlir/Dialect/LoopOps/LoopOps.h"
1616
#include "mlir/Dialect/StandardOps/Ops.h"
17+
#include "mlir/IR/BlockAndValueMapping.h"
1718
#include "mlir/IR/Builders.h"
1819
#include "mlir/IR/MLIRContext.h"
1920
#include "mlir/IR/Module.h"
@@ -142,14 +143,11 @@ struct IfLowering : public OpRewritePattern<IfOp> {
142143
PatternRewriter &rewriter) const override;
143144
};
144145

145-
struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
146-
using OpRewritePattern<TerminatorOp>::OpRewritePattern;
146+
struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
147+
using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
147148

148-
PatternMatchResult matchAndRewrite(TerminatorOp op,
149-
PatternRewriter &rewriter) const override {
150-
rewriter.eraseOp(op);
151-
return matchSuccess();
152-
}
149+
PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
150+
PatternRewriter &rewriter) const override;
153151
};
154152
} // namespace
155153

@@ -178,6 +176,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
178176
// Append the induction variable stepping logic to the last body block and
179177
// branch back to the condition block. Construct an expression f :
180178
// (x -> x+step) and apply this expression to the induction variable.
179+
rewriter.eraseOp(lastBodyBlock->getTerminator());
181180
rewriter.setInsertionPointToEnd(lastBodyBlock);
182181
auto step = forOp.step();
183182
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
@@ -220,6 +219,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
220219
// place it before the continuation block, and branch to it.
221220
auto &thenRegion = ifOp.thenRegion();
222221
auto *thenBlock = &thenRegion.front();
222+
rewriter.eraseOp(thenRegion.back().getTerminator());
223223
rewriter.setInsertionPointToEnd(&thenRegion.back());
224224
rewriter.create<BranchOp>(loc, continueBlock);
225225
rewriter.inlineRegionBefore(thenRegion, continueBlock);
@@ -231,6 +231,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
231231
auto &elseRegion = ifOp.elseRegion();
232232
if (!elseRegion.empty()) {
233233
elseBlock = &elseRegion.front();
234+
rewriter.eraseOp(elseRegion.back().getTerminator());
234235
rewriter.setInsertionPointToEnd(&elseRegion.back());
235236
rewriter.create<BranchOp>(loc, continueBlock);
236237
rewriter.inlineRegionBefore(elseRegion, continueBlock);
@@ -246,9 +247,42 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
246247
return matchSuccess();
247248
}
248249

250+
// Rewrites a result-less loop.parallel into a nest of loop.for ops, one per
// (induction variable, lower bound, upper bound, step) tuple, and clones the
// parallel body into the innermost loop. Returns matchFailure() without
// modifying anything when the op has results, and matchSuccess() after the
// original op has been erased.
PatternMatchResult
251+
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
252+
PatternRewriter &rewriter) const {
253+
Location loc = parallelOp.getLoc();
254+
BlockAndValueMapping mapping;
255+
256+
if (parallelOp.getNumResults() != 0) {
257+
// TODO: Implement lowering of parallelOp with reductions.
258+
return matchFailure();
259+
}
260+
261+
// For a parallel loop, we essentially need to create an n-dimensional loop
262+
// nest. We do this by translating to loop.for ops and have those lowered in
263+
// a further rewrite.
264+
for (auto loop_operands :
265+
llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
266+
parallelOp.upperBound(), parallelOp.step())) {
267+
Value iv, lower, upper, step;
268+
std::tie(iv, lower, upper, step) = loop_operands;
269+
ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
270+
// Remember which new induction variable stands in for the parallel loop's
// one; the mapping is consulted when the body is cloned below.
mapping.map(iv, forOp.getInductionVar());
271+
// Moving the insertion point into the body of the loop just created makes
// the next iteration's loop.for (and finally the cloned body) land inside
// it, which is what produces the nesting.
rewriter.setInsertionPointToStart(forOp.getBody());
272+
}
273+
274+
// Now copy over the contents of the body.
// The terminator of the parallel body is deliberately not cloned
// (without_terminator()); cloning through `mapping` substitutes the new
// loop.for induction variables for the original ones.
275+
for (auto &op : parallelOp.body().front().without_terminator())
276+
rewriter.clone(op, mapping);
277+
278+
// Erase the original op only after its body has been cloned.
rewriter.eraseOp(parallelOp);
279+
280+
return matchSuccess();
281+
}
282+
249283
void mlir::populateLoopToStdConversionPatterns(
250284
OwningRewritePatternList &patterns, MLIRContext *ctx) {
251-
patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
285+
patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
252286
}
253287

254288
void LoopToStandardPass::runOnOperation() {

mlir/test/Conversion/convert-to-cfg.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,36 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
147147
}
148148
return
149149
}
150+
151+
// CHECK-LABEL: func @parallel_loop(
152+
// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
153+
// CHECK: [[VAL_5:%.*]] = constant 1 : index
154+
// CHECK: br ^bb1([[VAL_0]] : index)
155+
// CHECK: ^bb1([[VAL_6:%.*]]: index):
156+
// CHECK: [[VAL_7:%.*]] = cmpi "slt", [[VAL_6]], [[VAL_2]] : index
157+
// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6
158+
// CHECK: ^bb2:
159+
// CHECK: br ^bb3([[VAL_1]] : index)
160+
// CHECK: ^bb3([[VAL_8:%.*]]: index):
161+
// CHECK: [[VAL_9:%.*]] = cmpi "slt", [[VAL_8]], [[VAL_3]] : index
162+
// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5
163+
// CHECK: ^bb4:
164+
// CHECK: [[VAL_10:%.*]] = constant 1 : index
165+
// CHECK: [[VAL_11:%.*]] = addi [[VAL_8]], [[VAL_5]] : index
166+
// CHECK: br ^bb3([[VAL_11]] : index)
167+
// CHECK: ^bb5:
168+
// CHECK: [[VAL_12:%.*]] = addi [[VAL_6]], [[VAL_4]] : index
169+
// CHECK: br ^bb1([[VAL_12]] : index)
170+
// CHECK: ^bb6:
171+
// CHECK: return
172+
// CHECK: }
173+
174+
func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
175+
%arg3 : index, %arg4 : index) {
176+
%step = constant 1 : index
177+
loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
178+
step (%arg4, %step) {
179+
%c1 = constant 1 : index
180+
}
181+
return
182+
}

0 commit comments

Comments
 (0)