-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][Python] Add optional results
parameter for building op with inferable result types
#156818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Twice (PragmaTwice) ChangesCurrently in MLIR python bindings, operations with inferable result types (e.g. with def my_op(arg1, arg2 .. argN, *, loc=None, ip=None):
... # result types will be inferred automatically However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method In this PR, we change the signature of this builder function to: def my_op(arg1, arg2 .. argN, *, results=None, loc=None, ip=None):
... # result types will be inferred automatically if results is None If the Patch is 20.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156818.diff 3 Files Affected:
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index c2bd86819666b..4d5d7ee26b775 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -23,12 +23,12 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_get_op_results_or_values(variadic1))
// CHECK: operands.append(non_variadic)
// CHECK: operands.append(variadic2)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -71,9 +71,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: results.append(non_variadic)
// CHECK: results.append(variadic2)
@@ -120,7 +120,6 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
def AttributedOp : TestOp<"attributed_op"> {
// CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: attributes["i32attr"] = (i32attr if (
@@ -131,6 +130,7 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: attributes["in"] = (in_
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -170,7 +170,6 @@ def AttributedOp : TestOp<"attributed_op"> {
def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_gen_arg_0)
@@ -178,6 +177,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = (is_
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -205,11 +205,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: if arr is not None: attributes["arr"] = (arr
// CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -226,21 +226,21 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, type_, *, loc=None, ip=None):
+ // CHECK: def __init__(self, type_, *, results=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
- // CHECK: _ods_result_type_source_attr = attributes["type"]
- // CHECK: _ods_derived_result_type = (
+ // CHECK: if results is None:
+ // CHECK: _ods_result_type_source_attr = attributes["type"]
+ // CHECK: _ods_derived_result_type = (
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
// CHECK: _ods_result_type_source_attr.type)
- // CHECK: results.extend([_ods_derived_result_type] * 2)
+ // CHECK: results = [_ods_derived_result_type] * 2
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, AnyType);
}
-// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK: return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
+// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None)
+// CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -258,9 +258,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
def EmptyOp : TestOp<"empty">;
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -272,31 +272,31 @@ def EmptyOp : TestOp<"empty">;
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
- // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
- // CHECK: attributes=attributes, operands=operands,
+ // CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
let results = (outs I32:$i32, F32:$f32);
}
-// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK: return InferResultTypesImpliedOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None)
+// CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
- // CHECK: def __init__(self, *, loc=None, ip=None):
+ // CHECK: def __init__(self, *, results=None, loc=None, ip=None):
// CHECK: operands = []
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
- // CHECK: attributes=attributes, operands=operands,
+ // CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
let results = (outs AnyType, AnyType, AnyType);
}
-// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK: return InferResultTypesOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None)
+// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -304,12 +304,12 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32)
// CHECK: operands.append(_gen_arg_2)
+ // CHECK: results = []
// CHECK: results.append(i32)
// CHECK: results.append(_gen_res_1)
// CHECK: results.append(i64)
@@ -346,11 +346,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
// CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(non_optional)
// CHECK: if optional is not None: operands.append(optional)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -377,11 +377,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(non_variadic)
// CHECK: operands.extend(_get_op_results_or_values(variadic))
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -410,9 +410,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
def OneVariadicResultOp : TestOp<"one_variadic_result"> {
// CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: results.extend(variadic)
// CHECK: results.append(non_variadic)
// CHECK: _ods_successors = None
@@ -442,10 +442,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: def __init__(self, in_, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(in_)
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -463,17 +463,16 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK-LABEL: OPERATION_NAME = "test.same_results"
def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
- // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+ // CHECK: def __init__(self, in1, in2, *, results=None, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: operands.append
- // CHECK: results.extend([operands[0].type] * 1)
+ // CHECK: if results is None: results = [operands[0].type] * 1
let arguments = (ins AnyType:$in1, AnyType:$in2);
let results = (outs AnyType:$res);
}
-// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK: return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
+// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None)
+// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip)
// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -544,11 +543,11 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
def SimpleOp : TestOp<"simple"> {
// CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(i32)
// CHECK: operands.append(f32)
+ // CHECK: results = []
// CHECK: results.append(i64)
// CHECK: results.append(f64)
// CHECK: _ods_successors = None
@@ -584,9 +583,9 @@ def SimpleOp : TestOp<"simple"> {
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: regions = 2 + num_variadic
// CHECK: super().__init__(
@@ -612,9 +611,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
// CHECK: operands = []
- // CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
+ // CHECK: results = []
// CHECK: _ods_successors = None
// CHECK: regions = 0 + num_variadic
// CHECK: super().__init__(
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index a45ca48b5c484..01b5542119b4e 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -51,7 +51,7 @@ def testInferLocations():
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
three = arith.constant(IndexType.get(), 3)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
# fmt: on
print(three.location)
@@ -60,14 +60,14 @@ def foo():
print(four.location)
# fmt: off
- # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+ # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
# fmt: on
foo()
_cext.globals.register_traceback_file_exclusion(__file__)
# fmt: off
- # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
+ # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
# fmt: on
foo()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 038f56d5a2150..6a7aa9e3432d5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -492,7 +492,6 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
- results = []
attributes = {{}
regions = None
{1}
@@ -738,18 +737,24 @@ populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
}
}
-/// Python code template for deriving the operation result types from its
-/// attribute:
+/// Python code template of generating result types for
+/// FirstAttrDerivedResultType trait
/// - {0} is the name of the attribute from which to derive the types.
-constexpr const char *deriveTypeFromAttrTemplate =
- R"Py(_ods_result_type_source_attr = attributes["{0}"]
-_ods_derived_result_type = (
+/// - {1} is the number of results.
+constexpr const char *firstAttrDerivedResultTypeTemplate =
+ R"Py(if results is None:
+ _ods_result_type_source_attr = attributes["{0}"]
+ _ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
- _ods_result_type_source_attr.type))Py";
+ _ods_result_type_source_attr.type)
+ results = [_ods_derived_result_type] * {1})Py";
-/// Python code template appending {0} type {1} times to the results list.
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+/// Python code template of generating result types for
+/// SameOperandsAndResultType trait
+/// - {0} is the number of results.
+constexpr const char *sameOperandsAndResultTypeTemplate =
+ R"Py(if results is None: results = [operands[0].type] * {0})Py";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
@@ -768,11 +773,10 @@ static void appendLineByLine(StringRef string,
static void
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
- bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
-
if (hasSameArgumentAndResultTypes(op)) {
- builderLines.push_back(formatv(appendSameResultsTemplate,
- "operands[0].type", op.getNumResults()));
+ appendLineByLine(
+ formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
+ builderLines);
return;
}
@@ -780,17 +784,19 @@ populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
const NamedAttribute &firstAttr = op.getAttribute(0);
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
"from which the type is derived");
- appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
+ appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
+ op.getNumResults())
+ .str(),
builderLines);
- builderLines.push_back(formatv(appendSameResultsTemplate,
- "_ods_derived_result_type",
- op.getNumResults()));
return;
}
if (hasInferTypeInterface(op))
return;
+ bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+ builderLines.push_back("results = []");
+
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i...
[truncated]
|
Ya this all makes sense expect for one smell part:
I don't understand - why not just expect users to provide a list with the correct types instead of doing |
I know it's tedious but can you add a python example? I think we're both pretty sure the code is correct (vis-a-vis the td test) but the reason is the python tests also serve as examples on how to use the bindings. |
Sure! Added in 7d434c9. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool this is a nice improvement (also surprising we never ran into issues with the previous implementation...).
Let me know when you're ready for me to merge this. |
Thank you for your review. I think it is ready now : ) |
Currently in MLIR python bindings, operations with inferable result types (e.g. with
InferTypeOpInterface
orSameOperandsAndResultType
) will generate such builder functions:However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method
inferResultTypes(..)
can return a failure and then we cannot build the op in that way. Also, in the C++ side we have multiplebuild
methods for both explicitly specify the result types and automatically inferring them.In this PR, we change the signature of this builder function to:
If the
results
is not provided, it will be inferred automatically, otherwise the provided result types will be utilized. Also,__init__
methods of the generated op classes are changed correspondingly. Note that for operations without inferable result types, the signature remain unchanged, i.e.def my_op(res1 .. resN, arg1 .. argN, *, loc=None, ip=None)
.Previously I have considered an approach like
my_op(arg, *, res1=None, res2=None, loc=None, ip=None)
, but I quickly realized it had some issues. For example, if the user only provides some of the arguments—saymy_op(v1, res1=i32)
—this could lead to problems. Moreover, we don’t seem to have a mechanism for inferring only part of result types. A unifiedresults
parameter seems to be more simple and straightforward.