Skip to content

Commit ce674b1

Browse files
committed
[mlir] Add support for marking 'unknown' operations as dynamically legal.
Summary: This allows for providing a default "catchall" legality check that is not dependent on specific operations or dialects. For example, this can be useful to check legality based on the specific types of operation operands or results. Differential Revision: https://reviews.llvm.org/D73379
1 parent 4953213 commit ce674b1

File tree

5 files changed

+66
-22
lines changed

5 files changed

+66
-22
lines changed

mlir/docs/DialectConversion.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ struct MyTarget : public ConversionTarget {
100100
/// callback.
101101
addDynamicallyLegalOp<ReturnOp>([](ReturnOp op) { ... });
102102

103+
/// Treat unknown operations, i.e. those without a legalization action
104+
/// directly set, as dynamically legal.
105+
markUnknownOpDynamicallyLegal();
106+
markUnknownOpDynamicallyLegal([](Operation *op) { ... });
107+
103108
//--------------------------------------------------------------------------
104109
// Marking an operation as illegal.
105110

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,8 @@ class ConversionTarget {
416416
/// dynamically legal on the target.
417417
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
418418

419-
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
419+
ConversionTarget(MLIRContext &ctx)
420+
: unknownOpsDynamicallyLegal(false), ctx(ctx) {}
420421
virtual ~ConversionTarget() = default;
421422

422423
//===--------------------------------------------------------------------===//
@@ -532,6 +533,16 @@ class ConversionTarget {
532533
setLegalityCallback(dialectNames, *callback);
533534
}
534535

536+
/// Register unknown operations as dynamically legal. For operations(and
537+
/// dialects) that do not have a set legalization action, treat them as
538+
/// dynamically legal and invoke the given callback if valid or
539+
/// 'isDynamicallyLegal'.
540+
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
541+
unknownOpsDynamicallyLegal = true;
542+
unknownLegalityFn = fn;
543+
}
544+
void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
545+
535546
/// Register the operations of the given dialects as illegal, i.e.
536547
/// operations of this dialect are not supported by the target.
537548
template <typename... Names>
@@ -585,6 +596,9 @@ class ConversionTarget {
585596

586597
/// If some legal instances of this operation may also be recursively legal.
587598
bool isRecursivelyLegal;
599+
600+
/// The legality callback if this operation is dynamically legal.
601+
Optional<DynamicLegalityCallbackFn> legalityFn;
588602
};
589603

590604
/// Get the legalization information for the given operation.
@@ -594,9 +608,6 @@ class ConversionTarget {
594608
/// information.
595609
llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
596610

597-
/// A set of dynamic legality callbacks for given operation names.
598-
DenseMap<OperationName, DynamicLegalityCallbackFn> opLegalityFns;
599-
600611
/// A set of legality callbacks for given operation names that are used to
601612
/// check if an operation instance is recursively legal.
602613
DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
@@ -608,6 +619,13 @@ class ConversionTarget {
608619
/// A set of dynamic legality callbacks for given dialect names.
609620
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
610621

622+
/// An optional legality callback for unknown operations.
623+
Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
624+
625+
/// Flag indicating if unknown operations should be treated as dynamically
626+
/// legal.
627+
bool unknownOpsDynamicallyLegal;
628+
611629
/// The current context this target applies to.
612630
MLIRContext &ctx;
613631
};

mlir/lib/Transforms/DialectConversion.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,19 +1704,11 @@ auto ConversionTarget::isLegal(Operation *op) const
17041704

17051705
// Returns true if this operation instance is known to be legal.
17061706
auto isOpLegal = [&] {
1707-
// Handle dynamic legality.
1708-
if (info->action == LegalizationAction::Dynamic) {
1709-
// Check for callbacks on the operation or dialect.
1710-
auto opFn = opLegalityFns.find(op->getName());
1711-
if (opFn != opLegalityFns.end())
1712-
return opFn->second(op);
1713-
auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
1714-
if (dialectFn != dialectLegalityFns.end())
1715-
return dialectFn->second(op);
1716-
1717-
// Otherwise, invoke the hook on the derived instance.
1718-
return isDynamicallyLegal(op);
1719-
}
1707+
// Handle dynamic legality either with the provided legality function, or
1708+
// the default hook on the derived instance.
1709+
if (info->action == LegalizationAction::Dynamic)
1710+
return info->legalityFn ? (*info->legalityFn)(op)
1711+
: isDynamicallyLegal(op);
17201712

17211713
// Otherwise, the operation is only legal if it was marked 'Legal'.
17221714
return info->action == LegalizationAction::Legal;
@@ -1726,7 +1718,6 @@ auto ConversionTarget::isLegal(Operation *op) const
17261718

17271719
// This operation is legal, compute any additional legality information.
17281720
LegalOpDetails legalityDetails;
1729-
17301721
if (info->isRecursivelyLegal) {
17311722
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
17321723
if (legalityFnIt != opRecursiveLegalityFns.end())
@@ -1741,7 +1732,11 @@ auto ConversionTarget::isLegal(Operation *op) const
17411732
void ConversionTarget::setLegalityCallback(
17421733
OperationName name, const DynamicLegalityCallbackFn &callback) {
17431734
assert(callback && "expected valid legality callback");
1744-
opLegalityFns[name] = callback;
1735+
auto infoIt = legalOperations.find(name);
1736+
assert(infoIt != legalOperations.end() &&
1737+
infoIt->second.action == LegalizationAction::Dynamic &&
1738+
"expected operation to already be marked as dynamically legal");
1739+
infoIt->second.legalityFn = callback;
17451740
}
17461741

17471742
/// Set the recursive legality callback for the given operation and mark the
@@ -1774,10 +1769,20 @@ auto ConversionTarget::getOpInfo(OperationName op) const
17741769
auto it = legalOperations.find(op);
17751770
if (it != legalOperations.end())
17761771
return it->second;
1777-
// Otherwise, default to checking on the parent dialect.
1772+
// Check for info for the parent dialect.
17781773
auto dialectIt = legalDialects.find(op.getDialect());
1779-
if (dialectIt != legalDialects.end())
1780-
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false};
1774+
if (dialectIt != legalDialects.end()) {
1775+
Optional<DynamicLegalityCallbackFn> callback;
1776+
auto dialectFn = dialectLegalityFns.find(op.getDialect());
1777+
if (dialectFn != dialectLegalityFns.end())
1778+
callback = dialectFn->second;
1779+
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
1780+
callback};
1781+
}
1782+
// Otherwise, check if we mark unknown operations as dynamic.
1783+
if (unknownOpsDynamicallyLegal)
1784+
return LegalizationInfo{LegalizationAction::Dynamic,
1785+
/*isRecursivelyLegal=*/false, unknownLegalityFn};
17811786
return llvm::None;
17821787
}
17831788

mlir/test/Transforms/test-legalizer-full.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,14 @@ func @test_undo_region_clone() {
5858
%ignored = "test.illegal_op_f"() : () -> (i32)
5959
"test.return"() : () -> ()
6060
}
61+
62+
// -----
63+
64+
// Test that unknown operations can be dynamically legal.
65+
func @test_unknown_dynamically_legal() {
66+
"foo.unknown_op"() {test.dynamically_legal} : () -> ()
67+
68+
// expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}}
69+
"foo.unknown_op"() {} : () -> ()
70+
"test.return"() : () -> ()
71+
}

mlir/test/lib/TestDialect/TestPatterns.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,11 @@ struct TestLegalizePatternDriver
399399

400400
// Handle a full conversion.
401401
if (mode == ConversionMode::Full) {
402+
// Check support for marking unknown operations as dynamically legal.
403+
target.markUnknownOpDynamicallyLegal([](Operation *op) {
404+
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
405+
});
406+
402407
(void)applyFullConversion(getModule(), target, patterns, &converter);
403408
return;
404409
}

0 commit comments

Comments
 (0)