Skip to content

Conversation

snarang181
Copy link
Contributor

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.

@snarang181 snarang181 marked this pull request as ready for review August 22, 2025 02:10
@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Samarth Narang (snarang181)

Changes

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.


Full diff: https://github.com/llvm/llvm-project/pull/154872.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (+27-3)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  if (auto rt = dyn_cast<RankedTensorType>(t))
+    return rt.getRank() == 0; // 0-D tensors are scalar-like
+  return false;
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(
+      types, [](Type type) { return isa<RankedTensorType>(type); });
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    // Invalid if neither ranked tensor nor scalar-like.
+    if (llvm::isa<RankedTensorType>(t))
+      return false;
+    if (isScalarLike(t))
+      return false;
+    return true; // Could be a memref, unranked tensor, vector, etc.
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2025

@llvm/pr-subscribers-mlir

Author: Samarth Narang (snarang181)

Changes

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.


Full diff: https://github.com/llvm/llvm-project/pull/154872.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (+27-3)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  if (auto rt = dyn_cast<RankedTensorType>(t))
+    return rt.getRank() == 0; // 0-D tensors are scalar-like
+  return false;
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(
+      types, [](Type type) { return isa<RankedTensorType>(type); });
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    // Invalid if neither ranked tensor nor scalar-like.
+    if (llvm::isa<RankedTensorType>(t))
+      return false;
+    if (isScalarLike(t))
+      return false;
+    return true; // Could be a memref, unranked tensor, vector, etc.
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

@rengolin
Copy link
Member

Please expand the test coverage. Should be as simple as add new cases to the existing files.

Make check more adaptive to include broadcasting of scalars
rank-aware scalar map (d0,…,dn) -> () during lowering.
@snarang181 snarang181 force-pushed the elemwise_linalg_opt branch from 492de32 to 9009a13 Compare August 25, 2025 16:11
@snarang181
Copy link
Contributor Author

@adam-smnk -- Thank you for your review, I addressed your comments.
I still have

 bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>);

to meet the rank = 0 operand requirements / test cases.

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@adam-smnk
Copy link
Contributor

to meet the rank = 0 operand requirements / test cases.

Could you also add a negative test for this?

@snarang181
Copy link
Contributor Author

LGTM

Thanks @adam-smnk. Is this good to merge?

@snarang181
Copy link
Contributor Author

to meet the rank = 0 operand requirements / test cases.

Could you also add a negative test for this?

Sorry, would you mind expanding a bit more?

@adam-smnk
Copy link
Contributor

adam-smnk commented Aug 25, 2025

to meet the rank = 0 operand requirements / test cases.

Could you also add a negative test for this?

Sorry, would you mind expanding a bit more?

Of course, I take that this check anyRankedTensor guards against scalar-only eltwise op as linalg requires a shaped type for its outs types. It'd be great to have a one test case that covers this scenario and ensures that the pattern doesn't match (and nothing blows up 😉).

You could add sth like this:

func.func @negative_scalar_only_eltwise(...) {
// CHECK-NOT: linalg
arith.addf ... : f32

@adam-smnk
Copy link
Contributor

Thanks for the extra test 👍
Should be good to merge

@snarang181 snarang181 merged commit 9801a0f into llvm:main Aug 25, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants