Skip to content

Conversation

ShivaChen
Copy link
Collaborator

The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.

The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (ShivaChen)

Changes

The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.


Patch is 20.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155967.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+249-96)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+28)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed type
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the multiplier
+    // values in a buffer.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getMultiplier().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the shift
+    // values in a buffer.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getShift().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[3];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend zeropoint for sub-32bits widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[4];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant = false;
+    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      isShiftConstant = true;
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant = false;
+    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      isMultiplierConstant = true;
+
+    llvm::SmallVector<int8_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    StringAttr roundingMode;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+      // explicit cast is required here
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
-
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
+      // Double round only occurs if shift is greater than 31, check that this
+      // is ever true.
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
 
-    bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+                               : rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // If we are rescaling per-channel then we need to store the multiplier
     // values in a buffer.
     Value multiplierConstant;
-    int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
-
+    setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant);
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
-    int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMapForShift(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It acts as a broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+    }
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
     }
 
     // Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
-
+          auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+                                          nestedLoc, blockArgs);
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+                                         nestedLoc, blockArgs);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
-
-          Value multiplier = multiplierConstant ? multiplierConstant
-                                                : blockArgs[multiplierArg];
-          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+          Value multiplier =
+              multiplierConstant ? multiplierConstant : blockArgs[1];
+          Value shift = shiftConstant ? shiftConstant : blockArgs[2];
 
           if (valueTy.isUnsignedInteger()) {
             value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir-tosa

Author: None (ShivaChen)

Changes

The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.


Patch is 20.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155967.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+249-96)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+28)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed type
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the multiplier
+    // values in a buffer.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getMultiplier().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the shift
+    // values in a buffer.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getShift().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[3];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend zeropoint for sub-32bits widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[4];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant = false;
+    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      isShiftConstant = true;
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant = false;
+    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      isMultiplierConstant = true;
+
+    llvm::SmallVector<int8_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    StringAttr roundingMode;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+      // explicit cast is required here
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
-
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
+      // Double round only occurs if shift is greater than 31, check that this
+      // is ever true.
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
 
-    bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+                               : rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // If we are rescaling per-channel then we need to store the multiplier
     // values in a buffer.
     Value multiplierConstant;
-    int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
-
+    setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant);
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
-    int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMapForShift(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It acts as a broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+    }
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
     }
 
     // Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
-
+          auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+                                          nestedLoc, blockArgs);
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+                                         nestedLoc, blockArgs);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
-
-          Value multiplier = multiplierConstant ? multiplierConstant
-                                                : blockArgs[multiplierArg];
-          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+          Value multiplier =
+              multiplierConstant ? multiplierConstant : blockArgs[1];
+          Value shift = shiftConstant ? shiftConstant : blockArgs[2];
 
           if (valueTy.isUnsignedInteger()) {
             value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants