-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir] Add helper to check elementwise-mappable ops with tensors and scalars #154872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Samarth Narang (snarang181) ChangesThis patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, Full diff: https://github.com/llvm/llvm-project/pull/154872.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
using namespace mlir;
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+ if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+ return true;
+ if (auto rt = dyn_cast<RankedTensorType>(t))
+ return rt.getRank() == 0; // 0-D tensors are scalar-like
+ return false;
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(
+ types, [](Type type) { return isa<RankedTensorType>(type); });
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ // Invalid if neither ranked tensor nor scalar-like.
+ if (llvm::isa<RankedTensorType>(t))
+ return false;
+ if (isScalarLike(t))
+ return false;
+ return true; // Could be a memref, unranked tensor, vector, etc.
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
|
@llvm/pr-subscribers-mlir Author: Samarth Narang (snarang181) ChangesThis patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, Full diff: https://github.com/llvm/llvm-project/pull/154872.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
using namespace mlir;
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+ if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+ return true;
+ if (auto rt = dyn_cast<RankedTensorType>(t))
+ return rt.getRank() == 0; // 0-D tensors are scalar-like
+ return false;
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(
+ types, [](Type type) { return isa<RankedTensorType>(type); });
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ // Invalid if neither ranked tensor nor scalar-like.
+ if (llvm::isa<RankedTensorType>(t))
+ return false;
+ if (isScalarLike(t))
+ return false;
+ return true; // Could be a memref, unranked tensor, vector, etc.
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
|
Please expand the test coverage. Should be as simple as add new cases to the existing files. |
fe6a390
to
e74a9ce
Compare
e74a9ce
to
492de32
Compare
Make check more adaptive to include broadcasting of scalars
rank-aware scalar map (d0,…,dn) -> () during lowering.
492de32
to
9009a13
Compare
@adam-smnk -- Thank you for your review, I addressed your comments. bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>); to meet the rank = 0 operand requirements / test cases. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Could you also add a negative test for this? |
Thanks @adam-smnk. Is this good to merge? |
Sorry, would you mind expanding a bit more? |
Of course, I take that this check You could add sth like this: func.func @negative_scalar_only_eltwise(...) {
// CHECK-NOT: linalg
arith.addf ... : f32 |
Thanks for the extra test 👍 |
This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility,
isElementwiseMappableOpOnRankedTensors
, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.