Skip to content

Conversation

PragmaTwice
Copy link
Contributor

@PragmaTwice PragmaTwice commented Sep 4, 2025

Currently in MLIR python bindings, operations with inferable result types (e.g. with InferTypeOpInterface or SameOperandsAndResultType) will generate such builder functions:

def my_op(arg1, arg2 .. argN, *, loc=None, ip=None):
  ... # result types will be inferred automatically

However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method inferResultTypes(..) can return a failure and then we cannot build the op in that way. Also, in the C++ side we have multiple build methods for both explicitly specify the result types and automatically inferring them.

In this PR, we change the signature of this builder function to:

def my_op(arg1, arg2 .. argN, *, results=None, loc=None, ip=None):
  ... # result types will be inferred automatically if results is None

If the results is not provided, it will be inferred automatically, otherwise the provided result types will be utilized. Also, __init__ methods of the generated op classes are changed correspondingly. Note that for operations without inferable result types, the signature remain unchanged, i.e. def my_op(res1 .. resN, arg1 .. argN, *, loc=None, ip=None).


Previously I have considered an approach like my_op(arg, *, res1=None, res2=None, loc=None, ip=None), but I quickly realized it had some issues. For example, if the user only provides some of the arguments—say my_op(v1, res1=i32)—this could lead to problems. Moreover, we don’t seem to have a mechanism for inferring only part of result types. A unified results parameter seems to be more simple and straightforward.

@PragmaTwice PragmaTwice marked this pull request as ready for review September 4, 2025 07:26
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Sep 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Twice (PragmaTwice)

Changes

Currently in MLIR python bindings, operations with inferable result types (e.g. with InferTypeOpInterface or SameOperandsAndResultType) will generate such builder functions:

def my_op(arg1, arg2 .. argN, *, loc=None, ip=None):
  ... # result types will be inferred automatically

However, in some cases we may want to provide the result types explicitly. For example, the implementation of interface method inferResultTypes(..) can return a failure and then we cannot build the op in that way. Also, in the C++ side we have multiple build methods for both explicitly specify the result types and automatically inferring them.

In this PR, we change the signature of this builder function to:

def my_op(arg1, arg2 .. argN, *, results=None, loc=None, ip=None):
  ... # result types will be inferred automatically if results is None

If the results is not provided, it will be inferred automatically, otherwise the provided result types will be utilized. Also, __init__ methods of the generated op classes are changed correspondingly.


Patch is 20.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156818.diff

3 Files Affected:

  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+33-34)
  • (modified) mlir/test/python/ir/auto_location.py (+3-3)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+27-19)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index c2bd86819666b..4d5d7ee26b775 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -23,12 +23,12 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
                                  [AttrSizedOperandSegments]> {
   // CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_results_or_values(variadic1))
   // CHECK:   operands.append(non_variadic)
   // CHECK:   operands.append(variadic2)
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -71,9 +71,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
                                [AttrSizedResultSegments]> {
   // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
+  // CHECK:   results = []
   // CHECK:   if variadic1 is not None: results.append(variadic1)
   // CHECK:   results.append(non_variadic)
   // CHECK:   results.append(variadic2)
@@ -120,7 +120,6 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 def AttributedOp : TestOp<"attributed_op"> {
   // CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   attributes["i32attr"] = (i32attr if (
@@ -131,6 +130,7 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   attributes["in"] = (in_
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -170,7 +170,6 @@ def AttributedOp : TestOp<"attributed_op"> {
 def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_gen_arg_0)
@@ -178,6 +177,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   if is_ is not None: attributes["is"] = (is_
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -205,11 +205,11 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   // CHECK: def __init__(self, *, arr=None, unsupported=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   if arr is not None: attributes["arr"] = (arr
   // CHECK:   if unsupported is not None: attributes["unsupported"] = (unsupported
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -226,21 +226,21 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, type_, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type_, *, results=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
-  // CHECK:   _ods_result_type_source_attr = attributes["type"]
-  // CHECK:   _ods_derived_result_type = (
+  // CHECK:   if results is None:
+  // CHECK:     _ods_result_type_source_attr = attributes["type"]
+  // CHECK:     _ods_derived_result_type = (
   // CHECK:       _ods_ir.TypeAttr(_ods_result_type_source_attr).value
   // CHECK:       if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
   // CHECK:       _ods_result_type_source_attr.type)
-  // CHECK:   results.extend([_ods_derived_result_type] * 2)
+  // CHECK:   results = [_ods_derived_result_type] * 2
   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, AnyType);
 }
 
-// CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return DeriveResultTypesOp(type_=type_, loc=loc, ip=ip).results
+// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None)
+// CHECK:   return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -258,9 +258,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 def EmptyOp : TestOp<"empty">;
   // CHECK: def __init__(self, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -272,31 +272,31 @@ def EmptyOp : TestOp<"empty">;
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
-  // CHECK:  def __init__(self, *, loc=None, ip=None):
+  // CHECK:  def __init__(self, *, results=None, loc=None, ip=None):
   // CHECK:    _ods_context = _ods_get_default_loc_context(loc)
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
-  // CHECK:     attributes=attributes, operands=operands,
+  // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
   let results = (outs I32:$i32, F32:$f32);
 }
 
-// CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return InferResultTypesImpliedOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None)
+// CHECK:   return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
-  // CHECK:  def __init__(self, *, loc=None, ip=None):
+  // CHECK:  def __init__(self, *, results=None, loc=None, ip=None):
   // CHECK:    operands = []
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
-  // CHECK:     attributes=attributes, operands=operands,
+  // CHECK:     attributes=attributes, results=results, operands=operands,
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip)
   let results = (outs AnyType, AnyType, AnyType);
 }
 
-// CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return InferResultTypesOp(loc=loc, ip=ip).results
+// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None)
+// CHECK:   return InferResultTypesOp(results=results, loc=loc, ip=ip).results
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -304,12 +304,12 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_gen_arg_0)
   // CHECK:   operands.append(f32)
   // CHECK:   operands.append(_gen_arg_2)
+  // CHECK:   results = []
   // CHECK:   results.append(i32)
   // CHECK:   results.append(_gen_res_1)
   // CHECK:   results.append(i64)
@@ -346,11 +346,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
   // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(non_optional)
   // CHECK:   if optional is not None: operands.append(optional)
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -377,11 +377,11 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(non_variadic)
   // CHECK:   operands.extend(_get_op_results_or_values(variadic))
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -410,9 +410,9 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 def OneVariadicResultOp : TestOp<"one_variadic_result"> {
   // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
+  // CHECK:   results = []
   // CHECK:   results.extend(variadic)
   // CHECK:   results.append(non_variadic)
   // CHECK:   _ods_successors = None
@@ -442,10 +442,10 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK: def __init__(self, in_, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(in_)
+  // CHECK:   results = []
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -463,17 +463,16 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
-  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, in1, in2, *, results=None, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: operands.append
-  // CHECK: results.extend([operands[0].type] * 1)
+  // CHECK: if results is None: results = [operands[0].type] * 1
   let arguments = (ins AnyType:$in1, AnyType:$in2);
   let results = (outs AnyType:$res);
 }
 
-// CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)
+// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None)
+// CHECK:   return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip)
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -544,11 +543,11 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 def SimpleOp : TestOp<"simple"> {
   // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(i32)
   // CHECK:   operands.append(f32)
+  // CHECK:   results = []
   // CHECK:   results.append(i64)
   // CHECK:   results.append(f64)
   // CHECK:   _ods_successors = None
@@ -584,9 +583,9 @@ def SimpleOp : TestOp<"simple"> {
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:    operands = []
-  // CHECK:    results = []
   // CHECK:    attributes = {}
   // CHECK:    regions = None
+  // CHECK:    results = []
   // CHECK:    _ods_successors = None
   // CHECK:    regions = 2 + num_variadic
   // CHECK:    super().__init__(
@@ -612,9 +611,9 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK:  def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:    operands = []
-  // CHECK:    results = []
   // CHECK:    attributes = {}
   // CHECK:    regions = None
+  // CHECK:    results = []
   // CHECK:    _ods_successors = None
   // CHECK:    regions = 0 + num_variadic
   // CHECK:    super().__init__(
diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py
index a45ca48b5c484..01b5542119b4e 100644
--- a/mlir/test/python/ir/auto_location.py
+++ b/mlir/test/python/ir/auto_location.py
@@ -51,7 +51,7 @@ def testInferLocations():
         _cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
         three = arith.constant(IndexType.get(), 3)
         # fmt: off
-        # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
+        # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
         # fmt: on
         print(three.location)
 
@@ -60,14 +60,14 @@ def foo():
             print(four.location)
 
         # fmt: off
-        # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
+        # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
         # fmt: on
         foo()
 
         _cext.globals.register_traceback_file_exclusion(__file__)
 
         # fmt: off
-        # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
+        # CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
         # fmt: on
         foo()
 
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 038f56d5a2150..6a7aa9e3432d5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -492,7 +492,6 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
 constexpr const char *initTemplate = R"Py(
   def __init__(self, {0}):
     operands = []
-    results = []
     attributes = {{}
     regions = None
     {1}
@@ -738,18 +737,24 @@ populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
   }
 }
 
-/// Python code template for deriving the operation result types from its
-/// attribute:
+/// Python code template of generating result types for
+/// FirstAttrDerivedResultType trait
 ///   - {0} is the name of the attribute from which to derive the types.
-constexpr const char *deriveTypeFromAttrTemplate =
-    R"Py(_ods_result_type_source_attr = attributes["{0}"]
-_ods_derived_result_type = (
+///   - {1} is the number of results.
+constexpr const char *firstAttrDerivedResultTypeTemplate =
+    R"Py(if results is None:
+  _ods_result_type_source_attr = attributes["{0}"]
+  _ods_derived_result_type = (
     _ods_ir.TypeAttr(_ods_result_type_source_attr).value
     if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
-    _ods_result_type_source_attr.type))Py";
+    _ods_result_type_source_attr.type)
+  results = [_ods_derived_result_type] * {1})Py";
 
-/// Python code template appending {0} type {1} times to the results list.
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+/// Python code template of generating result types for
+/// SameOperandsAndResultType trait
+///   - {0} is the number of results.
+constexpr const char *sameOperandsAndResultTypeTemplate =
+    R"Py(if results is None: results = [operands[0].type] * {0})Py";
 
 /// Appends the given multiline string as individual strings into
 /// `builderLines`.
@@ -768,11 +773,10 @@ static void appendLineByLine(StringRef string,
 static void
 populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
                            SmallVectorImpl<std::string> &builderLines) {
-  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
-
   if (hasSameArgumentAndResultTypes(op)) {
-    builderLines.push_back(formatv(appendSameResultsTemplate,
-                                   "operands[0].type", op.getNumResults()));
+    appendLineByLine(
+        formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
+        builderLines);
     return;
   }
 
@@ -780,17 +784,19 @@ populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
     const NamedAttribute &firstAttr = op.getAttribute(0);
     assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
                                       "from which the type is derived");
-    appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
+    appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
+                             op.getNumResults())
+                         .str(),
                      builderLines);
-    builderLines.push_back(formatv(appendSameResultsTemplate,
-                                   "_ods_derived_result_type",
-                                   op.getNumResults()));
     return;
   }
 
   if (hasInferTypeInterface(op))
     return;
 
+  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
+  builderLines.push_back("results = []");
+
   // For each element, find or generate a name.
   for (int i = 0, e = op.getNumResults(); i < e; ++i...
[truncated]

@makslevental
Copy link
Contributor

Ya this all makes sense expect for one smell part:

A unified results parameter seems to be more simple and straightforward.

I don't understand - why not just expect users to provide a list with the correct types instead of doing [type] * n?

@makslevental
Copy link
Contributor

I know it's tedious but can you add a python example? I think we're both pretty sure the code is correct (vis-a-vis the td test) but the reason is the python tests also serve as examples on how to use the bindings.

@PragmaTwice
Copy link
Contributor Author

I know it's tedious but can you add a python example? I think we're both pretty sure the code is correct (vis-a-vis the td test) but the reason is the python tests also serve as examples on how to use the bindings.

Sure! Added in 7d434c9.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool this is a nice improvement (also surprising we never ran into issues with the previous implementation...).

@makslevental
Copy link
Contributor

Let me know when you're ready for me to merge this.

@PragmaTwice
Copy link
Contributor Author

Let me know when you're ready for me to merge this.

Thank you for your review. I think it is ready now : )

@makslevental makslevental merged commit 55d9c91 into llvm:main Sep 5, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants