Skip to content

Commit 5c352e6

Browse files
Ehsan Toosidfki-mako
Ehsan Toosi
authored and committed
Providing buffer assignment for MLIR
We have provided a generic buffer assignment transformation ported from TensorFlow. This generic transformation pass automatically analyzes the values and their aliases (also in other blocks) and returns the valid positions for Alloc and Dealloc operations. To find these positions, the algorithm uses the block Dominator and Post-Dominator analyses. In our proposed algorithm, we have considered aliasing, liveness, nested regions, branches, conditional branches, critical edges, and independence from custom block terminators. This implementation doesn't support block loops. However, we have considered this in our design. For this purpose, it is only required to have a loop analysis to insert Alloc and Dealloc operations outside of these loops in some special cases. Differential Revision: https://reviews.llvm.org/D78484
1 parent f03b505 commit 5c352e6

File tree

10 files changed

+1371
-0
lines changed

10 files changed

+1371
-0
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
//===- BufferPlacement.h - Buffer Assignment Utilities ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines buffer assignment helper methods to compute correct
10+
// and valid positions for placing Alloc and Dealloc operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TRANSFORMS_BUFFERPLACEMENT_H
15+
#define MLIR_TRANSFORMS_BUFFERPLACEMENT_H
16+
17+
#include "mlir/Analysis/Dominance.h"
18+
#include "mlir/Analysis/Liveness.h"
19+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
20+
#include "mlir/IR/Builders.h"
21+
#include "mlir/IR/Operation.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
24+
namespace mlir {
25+
26+
/// Prepares a buffer placement phase. It can place (user-defined) alloc
27+
/// nodes. This simplifies the integration of the actual buffer-placement
28+
/// pass. Sample usage:
29+
/// BufferAssignmentPlacer baHelper(regionOp);
30+
/// -> determine alloc positions
31+
/// auto allocPosition = baHelper.computeAllocPosition(result);
32+
/// -> place alloc
33+
/// allocBuilder.restoreInsertionPoint(allocPosition);
34+
/// <create alloc>
35+
/// Note: this class is intended to be used during legalization. In order
36+
/// to move alloc and dealloc nodes into the right places you can use the
37+
/// createBufferPlacementPass() function.
38+
class BufferAssignmentPlacer {
39+
public:
40+
/// Creates a new assignment builder for the given operation. Only declared
/// here; the definition is not part of this header.
41+
explicit BufferAssignmentPlacer(Operation *op);
42+
43+
/// Returns the operation this analysis was constructed from.
44+
Operation *getOperation() const { return operation; }
45+
46+
/// Computes the actual position to place allocs for the given result. The
/// position is returned as an OpBuilder::InsertPoint. Only declared here.
47+
OpBuilder::InsertPoint computeAllocPosition(OpResult result);
48+
49+
private:
50+
/// The operation this analysis was constructed from (raw pointer, exposed
/// through getOperation()).
51+
Operation *operation;
52+
};
53+
54+
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
55+
/// instance. Sample usage:
56+
/// class CustomConversionPattern : public
57+
/// BufferAssignmentOpConversionPattern<MyOpT>
58+
/// {
59+
/// ... matchAndRewrite(...) {
60+
/// -> Access stored BufferAssignmentPlacer
61+
/// bufferAssignment->computeAllocPosition(resultOp);
62+
/// }
63+
/// };
64+
template <typename SourceOp>
65+
class BufferAssignmentOpConversionPattern
66+
: public OpConversionPattern<SourceOp> {
67+
public:
68+
/// Forwards `context` and `benefit` to the OpConversionPattern base and
/// stores `bufferAssignment` and `converter` unchanged. Both pointers default
/// to nullptr and are not checked here, so derived patterns that dereference
/// them must be constructed with non-null values.
explicit BufferAssignmentOpConversionPattern(
69+
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
70+
TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
71+
: OpConversionPattern<SourceOp>(context, benefit),
72+
bufferAssignment(bufferAssignment), converter(converter) {}
73+
74+
protected:
75+
/// Placer handed to the constructor; may be null (constructor default). Kept
/// as a raw pointer that this class never deletes.
BufferAssignmentPlacer *bufferAssignment;
76+
/// Type converter handed to the constructor; may be null (constructor
/// default). Kept as a raw pointer that this class never deletes.
TypeConverter *converter;
77+
};
78+
79+
/// This conversion adds an extra argument for each function result which makes
80+
/// the converted function a void function. A type converter must be provided
81+
/// for this conversion to convert a non-shaped type to memref.
82+
/// BufferAssignmentTypeConverter is a helper TypeConverter for this
83+
/// purpose. All the non-shaped types of the input function will be converted to
84+
/// memref.
85+
class FunctionAndBlockSignatureConverter
86+
: public BufferAssignmentOpConversionPattern<FuncOp> {
87+
public:
88+
// Inherit the base-class constructors (context, placer, converter, benefit).
using BufferAssignmentOpConversionPattern<
89+
FuncOp>::BufferAssignmentOpConversionPattern;
90+
91+
/// Performs the actual signature rewriting step. Only declared here; the
/// definition is not part of this header.
92+
LogicalResult
93+
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
94+
ConversionPatternRewriter &rewriter) const final;
95+
};
96+
97+
/// This pattern converter transforms a non-void ReturnOpSourceTy into a void
98+
/// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy
99+
/// to copy the results to the output buffer.
/// The output buffers are taken to be the trailing arguments of the entry
/// block of the region that contains the return operation: the i-th operand is
/// copied into entry-block argument (numFuncArgs - numReturnValues + i).
100+
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
101+
typename CopyOpTy>
102+
class NonVoidToVoidReturnOpConverter
103+
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
104+
public:
105+
// Inherit the base-class constructors (context, placer, converter, benefit).
using BufferAssignmentOpConversionPattern<
106+
ReturnOpSourceTy>::BufferAssignmentOpConversionPattern;
107+
108+
/// Performs the actual return-op conversion step: emits one CopyOpTy per
/// operand that is not already its output buffer, replaces `returnOp` with an
/// operand-less ReturnOpTargetTy, and always returns success().
109+
LogicalResult
110+
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
111+
ConversionPatternRewriter &rewriter) const final {
112+
unsigned numReturnValues = returnOp.getNumOperands();
113+
Block &entryBlock = returnOp.getParentRegion()->front();
114+
unsigned numFuncArgs = entryBlock.getNumArguments();
115+
Location loc = returnOp.getLoc();
116+
117+
// Find the corresponding output buffer for each operand.
// NOTE: the check below is an assert only; with asserts disabled the
// unsigned subtraction further down would wrap around if it were violated.
118+
assert(numReturnValues <= numFuncArgs &&
119+
"The number of operands of return operation is more than the "
120+
"number of function argument.");
121+
// The output buffers occupy the last numReturnValues entry-block arguments.
unsigned firstReturnParameter = numFuncArgs - numReturnValues;
122+
for (auto operand : llvm::enumerate(operands)) {
123+
unsigned returnArgNumber = firstReturnParameter + operand.index();
124+
BlockArgument dstBuffer = entryBlock.getArgument(returnArgNumber);
125+
// Nothing to copy when the operand already is its output buffer.
if (dstBuffer == operand.value())
126+
continue;
127+
128+
// Insert the copy operation to copy before the return.
129+
rewriter.setInsertionPoint(returnOp);
130+
// Reuse the argument fetched above instead of looking it up a second time.
rewriter.create<CopyOpTy>(loc, operand.value(),
131+
dstBuffer);
132+
}
133+
// Insert the new target return operation.
134+
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp);
135+
return success();
136+
}
137+
};
138+
139+
/// A helper type converter class for using inside Buffer Assignment operation
140+
/// conversion patterns. The default constructor keeps all the types intact
141+
/// except for the ranked-tensor types, which are converted to memref types.
142+
class BufferAssignmentTypeConverter : public TypeConverter {
143+
public:
144+
/// Default constructor; declared here without a body (the definition is not
/// part of this header). See the class comment for the intended conversions.
BufferAssignmentTypeConverter();
145+
};
146+
147+
} // end namespace mlir
148+
149+
#endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H

mlir/include/mlir/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class ModuleOp;
2626
class Pass;
2727
template <typename T> class OperationPass;
2828

29+
/// Creates an instance of the BufferPlacement pass and returns it as an owning
/// std::unique_ptr<Pass>.
30+
std::unique_ptr<Pass> createBufferPlacementPass();
31+
2932
/// Creates an instance of the Canonicalizer pass.
3033
std::unique_ptr<Pass> createCanonicalizerPass();
3134

mlir/include/mlir/Transforms/Passes.td

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,68 @@ def AffinePipelineDataTransfer
102102
let constructor = "mlir::createPipelineDataTransferPass()";
103103
}
104104

105+
// Pass definition for "buffer-placement". The description below shows an
// input/output example, and instances are created through
// mlir::createBufferPlacementPass().
def BufferPlacement : Pass<"buffer-placement"> {
106+
let summary = "Optimizes placement of alloc and dealloc operations";
107+
let description = [{
108+
This pass implements an algorithm to optimize the placement of alloc and
109+
dealloc operations. This pass also inserts missing dealloc operations
110+
automatically to reclaim memory.
111+
112+
113+
Input
114+
115+
```mlir
116+
#map0 = affine_map<(d0) -> (d0)>
117+
module {
118+
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
119+
cond_br %arg0, ^bb1, ^bb2
120+
^bb1:
121+
br ^bb3(%arg1 : memref<2xf32>)
122+
^bb2:
123+
%0 = alloc() : memref<2xf32>
124+
linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 {
125+
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
126+
%tmp1 = exp %gen1_arg0 : f32
127+
linalg.yield %tmp1 : f32
128+
}: memref<2xf32>, memref<2xf32>
129+
br ^bb3(%0 : memref<2xf32>)
130+
^bb3(%1: memref<2xf32>):
131+
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
132+
return
133+
}
134+
}
135+
136+
```
137+
138+
Output
139+
140+
```mlir
141+
#map0 = affine_map<(d0) -> (d0)>
142+
module {
143+
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
144+
%0 = alloc() : memref<2xf32>
145+
cond_br %arg0, ^bb1, ^bb2
146+
^bb1: // pred: ^bb0
147+
br ^bb3(%arg1 : memref<2xf32>)
148+
^bb2: // pred: ^bb0
149+
linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg1, %0 {
150+
^bb0(%arg3: f32, %arg4: f32): // no predecessors
151+
%2 = exp %arg3 : f32
152+
linalg.yield %2 : f32
153+
}: memref<2xf32>, memref<2xf32>
154+
br ^bb3(%0 : memref<2xf32>)
155+
^bb3(%1: memref<2xf32>): // 2 preds: ^bb1, ^bb2
156+
linalg.copy(%1, %arg2) : memref<2xf32>, memref<2xf32>
157+
dealloc %0 : memref<2xf32>
158+
return
159+
}
160+
}
161+
```
162+
163+
}];
164+
let constructor = "mlir::createBufferPlacementPass()";
165+
}
166+
105167
def Canonicalizer : Pass<"canonicalize"> {
106168
let summary = "Canonicalize operations";
107169
let description = [{

0 commit comments

Comments
 (0)