-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][TOSA] Add missing SameOperandsAndResultShape Trait to tosa.cast #153826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][TOSA] Add missing SameOperandsAndResultShape Trait to tosa.cast #153826
Conversation
@llvm/pr-subscribers-mlir Author: Jonas Rickert (jorickert) ChangesFull diff: https://github.com/llvm/llvm-project/pull/153826.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
//===----------------------------------------------------------------------===//
// Operator: cast
//===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>]> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 930bb9fe96811..b6d68c94f436b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
// CHECK: }
- %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
- %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
- return %2 : tensor<3x600x1200xf32>
+ %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+ return %1 : tensor<3x600x1200xf32>
}
// -----
@@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK-LABEL: @do_not_fold_reciprocal_int
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
// CHECK: tosa.reciprocal
- %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
- %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
- return %2 : tensor<3x600x1200xi32>
+ %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+ return %1 : tensor<3x600x1200xi32>
}
// -----
|
@llvm/pr-subscribers-mlir-tosa Author: Jonas Rickert (jorickert) ChangesFull diff: https://github.com/llvm/llvm-project/pull/153826.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
//===----------------------------------------------------------------------===//
// Operator: cast
//===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>]> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 930bb9fe96811..b6d68c94f436b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
// CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32>
// CHECK: }
- %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
- %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
- return %2 : tensor<3x600x1200xf32>
+ %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+ return %1 : tensor<3x600x1200xf32>
}
// -----
@@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
// CHECK-LABEL: @do_not_fold_reciprocal_int
func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
// CHECK: tosa.reciprocal
- %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
- %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
- %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
- return %2 : tensor<3x600x1200xi32>
+ %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+ %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+ return %1 : tensor<3x600x1200xi32>
}
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix @jorickert! Would it be possible to add a negative test case under invalid.mlir as well?
Signed-off-by: Rickert, Jonas <jonas.rickert@amd.com>
e703175
to
23bdbe5
Compare
Thanks for the suggestion, I added a test that checks for the error message |
@lhutton1 Could you please review this PR or assign it to someone else? Thank you |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for missing this, was away last week, LGTM!
According to the TOSA spec, tosa.cast is only changing the elementtype, and not the shape of the input tensor