diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 20889558be314..416df6e87b11f 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2247,7 +2247,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> { //===----------------------------------------------------------------------===// // Operator: cast //===----------------------------------------------------------------------===// -def Tosa_CastOp: Tosa_Op<"cast", [Pure, +def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape, DeclareOpInterfaceMethods]> { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 930bb9fe96811..b6d68c94f436b 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1304,10 +1304,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32> // CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32> // CHECK: } - %0 = "tosa.const"(){ values = dense<116.0>: tensor }: () -> tensor - %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xf32> - %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> - return %2 : tensor<3x600x1200xf32> + %0 = "tosa.const"(){ values = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32> + %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> + return %1 : tensor<3x600x1200xf32> } // ----- @@ -1315,10 +1314,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { // CHECK-LABEL: @do_not_fold_reciprocal_int func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { // CHECK: tosa.reciprocal - %0 = "tosa.const"(){ values = dense<11>: tensor }: () -> tensor - %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xi32> - %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> - return %2 : tensor<3x600x1200xi32> + %0 = "tosa.const"(){ values = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32> + %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> + return %1 : tensor<3x600x1200xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 3bccb32c5b9f4..7c76c178436c7 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -6,6 +6,15 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" + +func.func @test_cast(%arg0: tensor) -> tensor<5xi32> { + // expected-error@+1{{'tosa.cast' op requires the same shape for all operands and results}} + %1 = "tosa.cast"(%arg0) : (tensor) -> tensor<5xi32> + return %1 : tensor<5xi32> +} + +// ----- + func.func @test_const() -> tensor<1xf32> { // expected-error@+1{{'tosa.const' op expected same attr/result element types}} %0 = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xf32>