14
14
#include " mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
15
15
#include " mlir/Dialect/LoopOps/LoopOps.h"
16
16
#include " mlir/Dialect/StandardOps/Ops.h"
17
+ #include " mlir/IR/BlockAndValueMapping.h"
17
18
#include " mlir/IR/Builders.h"
18
19
#include " mlir/IR/MLIRContext.h"
19
20
#include " mlir/IR/Module.h"
@@ -142,14 +143,11 @@ struct IfLowering : public OpRewritePattern<IfOp> {
142
143
PatternRewriter &rewriter) const override ;
143
144
};
144
145
145
- struct TerminatorLowering : public OpRewritePattern <TerminatorOp > {
146
- using OpRewritePattern<TerminatorOp >::OpRewritePattern;
146
+ struct ParallelLowering : public OpRewritePattern <mlir::loop::ParallelOp > {
147
+ using OpRewritePattern<mlir::loop::ParallelOp >::OpRewritePattern;
147
148
148
- PatternMatchResult matchAndRewrite (TerminatorOp op,
149
- PatternRewriter &rewriter) const override {
150
- rewriter.eraseOp (op);
151
- return matchSuccess ();
152
- }
149
+ PatternMatchResult matchAndRewrite (mlir::loop::ParallelOp parallelOp,
150
+ PatternRewriter &rewriter) const override ;
153
151
};
154
152
} // namespace
155
153
@@ -178,6 +176,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
178
176
// Append the induction variable stepping logic to the last body block and
179
177
// branch back to the condition block. Construct an expression f :
180
178
// (x -> x+step) and apply this expression to the induction variable.
179
+ rewriter.eraseOp (lastBodyBlock->getTerminator ());
181
180
rewriter.setInsertionPointToEnd (lastBodyBlock);
182
181
auto step = forOp.step ();
183
182
auto stepped = rewriter.create <AddIOp>(loc, iv, step).getResult ();
@@ -220,6 +219,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
220
219
// place it before the continuation block, and branch to it.
221
220
auto &thenRegion = ifOp.thenRegion ();
222
221
auto *thenBlock = &thenRegion.front ();
222
+ rewriter.eraseOp (thenRegion.back ().getTerminator ());
223
223
rewriter.setInsertionPointToEnd (&thenRegion.back ());
224
224
rewriter.create <BranchOp>(loc, continueBlock);
225
225
rewriter.inlineRegionBefore (thenRegion, continueBlock);
@@ -231,6 +231,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
231
231
auto &elseRegion = ifOp.elseRegion ();
232
232
if (!elseRegion.empty ()) {
233
233
elseBlock = &elseRegion.front ();
234
+ rewriter.eraseOp (elseRegion.back ().getTerminator ());
234
235
rewriter.setInsertionPointToEnd (&elseRegion.back ());
235
236
rewriter.create <BranchOp>(loc, continueBlock);
236
237
rewriter.inlineRegionBefore (elseRegion, continueBlock);
@@ -246,9 +247,42 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
246
247
return matchSuccess ();
247
248
}
248
249
250
+ PatternMatchResult
251
+ ParallelLowering::matchAndRewrite (ParallelOp parallelOp,
252
+ PatternRewriter &rewriter) const {
253
+ Location loc = parallelOp.getLoc ();
254
+ BlockAndValueMapping mapping;
255
+
256
+ if (parallelOp.getNumResults () != 0 ) {
257
+ // TODO: Implement lowering of parallelOp with reductions.
258
+ return matchFailure ();
259
+ }
260
+
261
+ // For a parallel loop, we essentially need to create an n-dimensional loop
262
+ // nest. We do this by translating to loop.for ops and have those lowered in
263
+ // a further rewrite.
264
+ for (auto loop_operands :
265
+ llvm::zip (parallelOp.getInductionVars (), parallelOp.lowerBound (),
266
+ parallelOp.upperBound (), parallelOp.step ())) {
267
+ Value iv, lower, upper, step;
268
+ std::tie (iv, lower, upper, step) = loop_operands;
269
+ ForOp forOp = rewriter.create <ForOp>(loc, lower, upper, step);
270
+ mapping.map (iv, forOp.getInductionVar ());
271
+ rewriter.setInsertionPointToStart (forOp.getBody ());
272
+ }
273
+
274
+ // Now copy over the contents of the body.
275
+ for (auto &op : parallelOp.body ().front ().without_terminator ())
276
+ rewriter.clone (op, mapping);
277
+
278
+ rewriter.eraseOp (parallelOp);
279
+
280
+ return matchSuccess ();
281
+ }
282
+
249
283
void mlir::populateLoopToStdConversionPatterns (
250
284
OwningRewritePatternList &patterns, MLIRContext *ctx) {
251
- patterns.insert <ForLowering, IfLowering, TerminatorLowering >(ctx);
285
+ patterns.insert <ForLowering, IfLowering, ParallelLowering >(ctx);
252
286
}
253
287
254
288
void LoopToStandardPass::runOnOperation () {
0 commit comments