Skip to content

Commit e3ddbed

Browse files
Xin Zhoutensorflower-gardener
authored andcommitted
[mhlo] Add type inference for stablehlo.select_and_scatter
Integrated from PR openxla/stablehlo#405 of StableHLO. PiperOrigin-RevId: 495641144
1 parent bd003ed commit e3ddbed

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7183,9 +7183,17 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
71837183
// SelectAndScatterOp
71847184
//===----------------------------------------------------------------------===//
71857185

7186+
LogicalResult SelectAndScatterOp::inferReturnTypes(
7187+
MLIRContext*, Optional<Location>, ValueRange operands,
7188+
DictionaryAttr attributes, RegionRange regions,
7189+
SmallVectorImpl<Type>& inferredReturnTypes) {
7190+
SelectAndScatterOp::Adaptor adaptor(operands, attributes, regions);
7191+
return hlo::inferSelectAndScatterOp(adaptor.getOperand(),
7192+
inferredReturnTypes);
7193+
}
7194+
71867195
namespace {
7187-
// Infer the return-type of SelectAndScatterOp.
7188-
TensorType inferSelectAndScatterOpReturnType(
7196+
TensorType inferSelectAndScatterOpWindowReturnType(
71897197
TensorType operandType, const ArrayRef<hlo::WindowDimension> window) {
71907198
if (!operandType.hasRank())
71917199
return UnrankedTensorType::get(operandType.getElementType());
@@ -7203,13 +7211,11 @@ TensorType inferSelectAndScatterOpReturnType(
72037211
// P3. size-of(window_dimension) == rank-of(input),
72047212
// where input is an element of 'inputs'.
72057213
// P4. Verify and collect the window attributes.
7206-
// P5. Verify the return type matches the operand-type.
7207-
// P6. Check if the result type of window operation matches the source type.
7214+
// P5. Check if the result type of window operation matches the source type.
72087215
LogicalResult SelectAndScatterOp::verify() {
72097216
auto operandType = getOperand().getType().cast<TensorType>();
72107217
auto initValueType = getInitValue().getType().cast<TensorType>();
72117218
auto sourceType = getSource().getType().cast<TensorType>();
7212-
auto resultType = getResult().getType().cast<TensorType>();
72137219

72147220
// P1.
72157221
Block& selectBlock = getSelect().front();
@@ -7277,14 +7283,8 @@ LogicalResult SelectAndScatterOp::verify() {
72777283
if (failed(windowOrErr)) return failure();
72787284

72797285
// P5.
7280-
if (!hlo::compatibleShapeAndElementType(operandType, resultType))
7281-
return emitOpError()
7282-
<< "expects the return-type to match the operand-type, but got "
7283-
<< resultType << " and " << operandType << " resp.";
7284-
7285-
// P6.
72867286
auto windowResultType =
7287-
inferSelectAndScatterOpReturnType(operandType, *windowOrErr);
7287+
inferSelectAndScatterOpWindowReturnType(operandType, *windowOrErr);
72887288

72897289
if (!hlo::compatibleShapeAndElementType(windowResultType, sourceType,
72907290
/*ignoreFpPrecision=*/true))

tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2555,7 +2555,7 @@ def MHLO_SelectOp: MHLO_Op<"select", [Pure, HLO_BroadcastingElementwise,
25552555
}
25562556

25572557
def MHLO_SelectAndScatterOp: MHLO_Op<"select_and_scatter",
2558-
[RecursiveMemoryEffects]> {
2558+
[RecursiveMemoryEffects, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
25592559
let summary = "SelectAndScatter operator";
25602560
let description = [{
25612561
Runs a windowed selection `select` function over `operand` with shape

tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,37 @@ func.func @after_all(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
605605
func.return %1 : !mhlo.token
606606
}
607607

608+
// -----
609+
610+
// CHECK: func @select_and_scatter
611+
func.func @select_and_scatter(
612+
%arg0: tensor<10x24x24x64xf32>,
613+
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> {
614+
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
615+
616+
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
617+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
618+
%2 = "mhlo.compare"(%arg3, %arg4) {
619+
compare_type = #mhlo<comparison_type TOTALORDER>,
620+
comparison_direction = #mhlo<comparison_direction GE>
621+
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
622+
"mhlo.return"(%2) : (tensor<i1>) -> ()
623+
}, {
624+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
625+
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
626+
"mhlo.return"(%2) : (tensor<f32>) -> ()
627+
}) {
628+
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
629+
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
630+
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
631+
tensor<10x24x24x64xf32>
632+
%3 = "mhlo_test.get_return_types"(%1) : (tensor<10x24x24x64xf32>) -> tensor<10x24x24x64xindex>
633+
// CHECK: %2 = "mhlo_test.return_types"(%1) {types0 = tensor<10x24x24x64xf32>} : (tensor<10x24x24x64xf32>) -> tensor<10x24x24x64xindex>
634+
func.return %3 : tensor<10x24x24x64xindex>
635+
}
636+
637+
// -----
638+
608639
//===----------------------------------------------------------------------===//
609640
// Sparsity
610641
//===----------------------------------------------------------------------===//

tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ func.func @select_and_scatter_invalid_ret_type(
394394
%arg1: tensor<10x12x12x64xf32>) -> () {
395395
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
396396

397-
// expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x32xf32>' and 'tensor<10x24x24x64xf32>' resp.}}
397+
// expected-error @+1 {{inferred type(s) 'tensor<10x24x24x64xf32>' are incompatible with return type(s) of operation 'tensor<10x24x24x32xf32>'}}
398398
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
399399
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
400400
%2 = "mhlo.compare"(%arg3, %arg4) {
@@ -422,7 +422,7 @@ func.func @select_and_scatter_invalid_ret_type(
422422
%arg1: tensor<10x12x12x64xf32>) -> () {
423423
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
424424

425-
// expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x64xi32>' and 'tensor<10x24x24x64xf32>' resp.}}
425+
// expected-error @+1 {{inferred type(s) 'tensor<10x24x24x64xf32>' are incompatible with return type(s) of operation 'tensor<10x24x24x64xi32>'}}
426426
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
427427
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
428428
%2 = "mhlo.compare"(%arg3, %arg4) {

0 commit comments

Comments
 (0)