Skip to content

Commit e653d30

Browse files
author
Nicolas Vasilache
committed
[mlir][Linalg] Update ReshapeOp::build to be more idiomatic
Summary: This diff makes it easier to create a `linalg.reshape` op and adds an EDSC builder api test to exercise the new builders. Reviewers: ftynse, jpienaar Subscribers: mehdi_amini, rriddle, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72580
1 parent b4a99a0 commit e653d30

File tree

4 files changed

+82
-9
lines changed

4 files changed

+82
-9
lines changed

mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace edsc {
1717
namespace intrinsics {
1818

1919
using linalg_fill = OperationBuilder<linalg::FillOp>;
20-
using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
20+
using linalg_reshape = ValueBuilder<linalg::ReshapeOp>;
2121
using linalg_yield = OperationBuilder<linalg::YieldOp>;
2222

2323
} // namespace intrinsics

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,17 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
100100
```
101101
}];
102102

103-
let builders = [OpBuilder<
104-
"Builder *b, OperationState &result, Value view, "
105-
"ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
103+
let builders = [
104+
// Builder for a contracting reshape whose result type is computed from
105+
// `view` and `reassociation`.
106+
OpBuilder<"Builder *b, OperationState &result, Value view, "
107+
"ArrayRef<ArrayRef<AffineExpr>> reassociation, "
108+
"ArrayRef<NamedAttribute> attrs = {}">,
109+
// Builder for a reshape whose result type is passed explicitly. This may be
110+
// either a contracting or expanding reshape.
111+
OpBuilder<"Builder *b, OperationState &result, Type resultType, Value view,"
112+
"ArrayRef<ArrayRef<AffineExpr>> reassociation, "
113+
"ArrayRef<NamedAttribute> attrs = {}">];
106114

107115
let extraClassDeclaration = [{
108116
static StringRef getReassociationAttrName() { return "reassociation"; }

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -465,14 +465,52 @@ static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
465465
[](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
466466
}
467467

468-
void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
469-
Value view, ArrayAttr reassociation,
470-
ArrayRef<NamedAttribute> attrs) {
471-
auto maps = getAffineMaps(reassociation);
468+
template <typename AffineExprTy>
469+
unsigned getMaxPosOfType(ArrayRef<ArrayRef<AffineExpr>> exprArrays) {
470+
unsigned pos = 0;
471+
for (auto exprs : exprArrays) {
472+
for (auto expr : exprs) {
473+
expr.walk([&pos](AffineExpr e) {
474+
if (auto d = e.dyn_cast<AffineExprTy>())
475+
pos = std::max(pos, d.getPosition());
476+
});
477+
}
478+
}
479+
return pos;
480+
}
481+
482+
static SmallVector<AffineMap, 4>
483+
getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
484+
unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
485+
unsigned maxSym = getMaxPosOfType<AffineSymbolExpr>(reassociation);
486+
assert(maxSym == 0 && "Expected symbol-less expressions");
487+
SmallVector<AffineMap, 4> maps;
488+
maps.reserve(reassociation.size());
489+
for (auto exprs : reassociation)
490+
maps.push_back(AffineMap::get(maxDim + 1, 0, exprs));
491+
return maps;
492+
}
493+
494+
void mlir::linalg::ReshapeOp::build(
495+
Builder *b, OperationState &result, Value view,
496+
ArrayRef<ArrayRef<AffineExpr>> reassociation,
497+
ArrayRef<NamedAttribute> attrs) {
498+
auto maps = getSymbolLessAffineMaps(reassociation);
472499
auto memRefType = view.getType().cast<MemRefType>();
473500
auto resultType = computeReshapeCollapsedType(memRefType, maps);
474501
build(b, result, resultType, view, attrs);
475-
result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
502+
result.addAttribute(ReshapeOp::getReassociationAttrName(),
503+
b->getAffineMapArrayAttr(maps));
504+
}
505+
506+
void mlir::linalg::ReshapeOp::build(
507+
Builder *b, OperationState &result, Type resultType, Value view,
508+
ArrayRef<ArrayRef<AffineExpr>> reassociation,
509+
ArrayRef<NamedAttribute> attrs) {
510+
auto maps = getSymbolLessAffineMaps(reassociation);
511+
build(b, result, resultType, view, attrs);
512+
result.addAttribute(ReshapeOp::getReassociationAttrName(),
513+
b->getAffineMapArrayAttr(maps));
476514
}
477515

478516
static void print(OpAsmPrinter &p, ReshapeOp op) {

mlir/test/EDSC/builder-api-test.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/AffineOps/AffineOps.h"
1212
#include "mlir/Dialect/Linalg/EDSC/Builders.h"
13+
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
1314
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
1415
#include "mlir/Dialect/StandardOps/Ops.h"
1516
#include "mlir/EDSC/Builders.h"
@@ -962,6 +963,32 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
962963
f.erase();
963964
}
964965

966+
// clang-format off
967+
// CHECK-LABEL: func @linalg_metadata_ops
968+
// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<4x8x16xf32> into memref<32x16xf32>
969+
// CHECK: linalg.reshape {{.*}} [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] : memref<32x16xf32> into memref<4x8x16xf32>
970+
// clang-format on
971+
TEST_FUNC(linalg_metadata_ops) {
972+
using namespace edsc;
973+
using namespace edsc::intrinsics;
974+
975+
auto f32Type = FloatType::getF32(&globalContext());
976+
auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0);
977+
auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});
978+
979+
OpBuilder builder(f.getBody());
980+
ScopedContext scope(builder, f.getLoc());
981+
AffineExpr i, j, k;
982+
bindDims(&globalContext(), i, j, k);
983+
ValueHandle v(f.getArgument(0));
984+
auto reshaped = linalg_reshape(v, ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
985+
linalg_reshape(memrefType, reshaped,
986+
ArrayRef<ArrayRef<AffineExpr>>{{i, j}, k});
987+
988+
f.print(llvm::outs());
989+
f.erase();
990+
}
991+
965992
int main() {
966993
RUN_TESTS();
967994
return 0;

0 commit comments

Comments
 (0)