Skip to content

Conversation

jorickert
Copy link
Contributor

@jorickert jorickert commented Aug 15, 2025

According to the TOSA spec, tosa.cast is only changing the elementtype, and not the shape of the input tensor

@llvmbot
Copy link
Member

llvmbot commented Aug 15, 2025

@llvm/pr-subscribers-mlir

Author: Jonas Rickert (jorickert)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/153826.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+6-8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
 //===----------------------------------------------------------------------===//
 // Operator: cast
 //===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>]> {
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 930bb9fe96811..b6d68c94f436b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
   // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
   // CHECK:           return %[[VAL_0]] : tensor<3x600x1200xf32>
   // CHECK:         }
-  %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
-  %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
-  return %2 : tensor<3x600x1200xf32>
+  %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+  return %1 : tensor<3x600x1200xf32>
 }
 
 // -----
@@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
 // CHECK-LABEL: @do_not_fold_reciprocal_int
 func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
   // CHECK:           tosa.reciprocal
-  %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
-  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
-  return %2 : tensor<3x600x1200xi32>
+  %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+  return %1 : tensor<3x600x1200xi32>
 }
 
 // -----

@llvmbot
Copy link
Member

llvmbot commented Aug 15, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Jonas Rickert (jorickert)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/153826.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+6-8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 20889558be314..416df6e87b11f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
 //===----------------------------------------------------------------------===//
 // Operator: cast
 //===----------------------------------------------------------------------===//
-def Tosa_CastOp: Tosa_Op<"cast", [Pure,
+def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>]> {
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 930bb9fe96811..b6d68c94f436b 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
   // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32>
   // CHECK:           return %[[VAL_0]] : tensor<3x600x1200xf32>
   // CHECK:         }
-  %0 = "tosa.const"(){ values = dense<116.0>: tensor<f32> }: () -> tensor<f32>
-  %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<3x600x1200xf32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
-  return %2 : tensor<3x600x1200xf32>
+  %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32>
+  return %1 : tensor<3x600x1200xf32>
 }
 
 // -----
@@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> {
 // CHECK-LABEL: @do_not_fold_reciprocal_int
 func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
   // CHECK:           tosa.reciprocal
-  %0 = "tosa.const"(){ values = dense<11>: tensor<i32> }: () -> tensor<i32>
-  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<3x600x1200xi32>
-  %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
-  return %2 : tensor<3x600x1200xi32>
+  %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32>
+  %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
+  return %1 : tensor<3x600x1200xi32>
 }
 
 // -----

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix @jorickert! Would it be possible to add a negative test case under invalid.mlir as well?

Signed-off-by: Rickert, Jonas <jonas.rickert@amd.com>
@jorickert jorickert force-pushed the jrickert.upstream.tosa_cast_same_shape branch from e703175 to 23bdbe5 Compare August 18, 2025 06:50
@jorickert
Copy link
Contributor Author

Thanks for the fix @jorickert! Would it be possible to add a negative test case under invalid.mlir as well?

Thanks for the suggestion, I added a test that checks for the error message

@jorickert jorickert requested a review from lhutton1 August 18, 2025 06:53
@jorickert
Copy link
Contributor Author

@lhutton1 Could you please review this PR or assign it to someone else? Thank you

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for missing this, was away last week, LGTM!

@lhutton1 lhutton1 merged commit 2191f5a into llvm:main Aug 26, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants