diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -19,7 +19,7 @@
   """Specialization for the constant op class."""

   def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-    super().__init__(result, value, loc=loc, ip=ip)
+    super().__init__(value, results=[result], loc=loc, ip=ip)

   @property
   def type(self):
@@ -273,11 +273,11 @@
                          "to a function")
       super().__init__(
-          calleeOrResults.type.results,
           FlatSymbolRefAttr.get(
               calleeOrResults.name.value,
               context=_get_default_loc_context(loc)),
           argumentsOrCallee,
+          results=calleeOrResults.type.results,
           loc=loc,
           ip=ip)
       return
@@ -289,12 +289,12 @@
     if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
       super().__init__(
-          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+          argumentsOrCallee, arguments, results=calleeOrResults, loc=loc, ip=ip)
     elif isinstance(argumentsOrCallee, str):
       super().__init__(
-          calleeOrResults,
           FlatSymbolRefAttr.get(
               argumentsOrCallee, context=_get_default_loc_context(loc)),
           arguments,
+          results=calleeOrResults,
           loc=loc,
           ip=ip)
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -34,4 +34,4 @@
     indices_resolved = [] if indices is None else _get_op_results_or_values(
        indices)
     return_type = MemRefType(memref_resolved.type).element_type
-    super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
+    super().__init__(memref, indices_resolved, results=[return_type], loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -79,7 +79,7 @@
                ip=None):
     name = _get_str_attr(name)
     args = _get_values(args)
-    super().__init__(results, name, args, loc=loc, ip=ip)
+    super().__init__(name, args, results=results, loc=loc, ip=ip)


 class AttributeOp:
@@ -93,7 +93,7 @@
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.AttributeType.get()
-    super().__init__(result, type=type, value=value, loc=loc, ip=ip)
+    super().__init__(value, type=type, results=[result], loc=loc, ip=ip)


 class EraseOp:
@@ -118,7 +118,7 @@
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.ValueType.get()
-    super().__init__(result, type=type, loc=loc, ip=ip)
+    super().__init__(type=type, results=[result], loc=loc, ip=ip)


 class OperandsOp:
@@ -131,7 +131,7 @@
                ip=None):
     types = types if types is None else _get_value(types)
     result = pdl.RangeType.get(pdl.ValueType.get())
-    super().__init__(result, type=types, loc=loc, ip=ip)
+    super().__init__(type=types, results=[result], loc=loc, ip=ip)


 class OperationOp:
@@ -155,7 +155,7 @@
     attributeNames = ArrayAttr.get(attributeNames)
     types = _get_values(types)
     result = pdl.OperationType.get()
-    super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
+    super().__init__(args, attributeValues, attributeNames, types, name=name, results=[result], loc=loc, ip=ip)


 class PatternOp:
@@ -207,7 +207,7 @@
     index = _get_int_attr(32, index)
     parent = _get_value(parent)
     result = pdl.ValueType.get()
-    super().__init__(result, parent, index, loc=loc, ip=ip)
+    super().__init__(parent, index, results=[result], loc=loc, ip=ip)


 class ResultsOp:
@@ -222,7 +222,7 @@
                ip=None):
     parent = _get_value(parent)
     index = index if index is None else _get_int_attr(32, index)
-    super().__init__(result, parent, index=index, loc=loc, ip=ip)
+    super().__init__(parent, index=index, results=[result], loc=loc, ip=ip)


 class RewriteOp:
@@ -261,7 +261,7 @@
                ip=None):
     type = type if type is None else _get_type_attr(type)
     result = pdl.TypeType.get()
-    super().__init__(result, type=type, loc=loc, ip=ip)
+    super().__init__(type=type, results=[result], loc=loc, ip=ip)


 class TypesOp:
@@ -275,4 +275,4 @@
     types = _get_array_attr([_get_type_attr(ty) for ty in types])
     types = None if not types else types
     result = pdl.RangeType.get(pdl.TypeType.get())
-    super().__init__(result, types=types, loc=loc, ip=ip)
+    super().__init__(types=types, results=[result], loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -196,7 +196,7 @@
       [AffineMapAttr.get(am) for am in indexing_maps])

   generic_op = linalg.GenericOp(
-      result_tensors=result_types,
+      results=result_types,
       inputs=ins,
       outputs=outs,
       indexing_maps=indexing_maps_attr,
@@ -342,18 +342,18 @@
   operand_type = operand.type
   if _is_floating_point_type(operand_type):
     if is_unsigned_cast:
-      return arith.FPToUIOp(to_type, operand).result
-    return arith.FPToSIOp(to_type, operand).result
+      return arith.FPToUIOp(operand, results=[to_type]).result
+    return arith.FPToSIOp(operand, results=[to_type]).result
   if _is_index_type(operand_type):
-    return arith.IndexCastOp(to_type, operand).result
+    return arith.IndexCastOp(operand, results=[to_type]).result
   # Assume integer.
   from_width = IntegerType(operand_type).width
   if to_width > from_width:
     if is_unsigned_cast:
-      return arith.ExtUIOp(to_type, operand).result
-    return arith.ExtSIOp(to_type, operand).result
+      return arith.ExtUIOp(operand, results=[to_type]).result
+      return arith.ExtSIOp(operand, results=[to_type]).result
   elif to_width < from_width:
-    return arith.TruncIOp(to_type, operand).result
+    return arith.TruncIOp(operand, results=[to_type]).result
   raise ValueError(f"Unable to cast body expression from {operand_type} to "
                    f"{to_type}")
@@ -362,15 +362,15 @@
   operand_type = operand.type
   if _is_integer_type(operand_type):
     if is_unsigned_cast:
-      return arith.UIToFPOp(to_type, operand).result
-    return arith.SIToFPOp(to_type, operand).result
+      return arith.UIToFPOp(operand, results=[to_type]).result
+    return arith.SIToFPOp(operand, results=[to_type]).result
   # Assume FloatType.
   to_width = _get_floating_point_width(to_type)
   from_width = _get_floating_point_width(operand_type)
   if to_width > from_width:
-    return arith.ExtFOp(to_type, operand).result
+    return arith.ExtFOp(operand, results=[to_type]).result
   elif to_width < from_width:
-    return arith.TruncFOp(to_type, operand).result
+    return arith.TruncFOp(operand, results=[to_type]).result
   raise ValueError(f"Unable to cast body expression from {operand_type} to "
                    f"{to_type}")
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -23,7 +23,6 @@
                        [AttrSizedOperandSegments]> {
   // CHECK: def __init__(self, variadic1, non_variadic, *, variadic2=None, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_results_or_values(variadic1))
@@ -65,14 +64,10 @@
   // CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
 def AttrSizedResultsOp : TestOp<"attr_sized_results",
                                 [AttrSizedResultSegments]> {
-  // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, *, loc=None, ip=None, results):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
-  // CHECK: if variadic1 is not None: results.append(variadic1)
-  // CHECK: results.append(non_variadic)
-  // CHECK: results.append(variadic2)
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -112,7 +107,6 @@
 def AttributedOp : TestOp<"attributed_op"> {
   // CHECK: def __init__(self, i32attr, in_, *, optionalF32Attr=None, unitAttr=None, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: attributes["i32attr"] = i32attr
@@ -154,7 +148,6 @@
 def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK: def __init__(self, _gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
@@ -181,22 +174,21 @@
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type, *, loc=None, ip=None, results=[]):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: _ods_result_type_source_attr = attributes["type"]
   // CHECK: _ods_derived_result_type = (
   // CHECK:   _ods_ir.TypeAttr(_ods_result_type_source_attr).value
   // CHECK:   if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
   // CHECK:   _ods_result_type_source_attr.type)
-  // CHECK: results.extend([_ods_derived_result_type] * 2)
+  // CHECK: results=[_ods_derived_result_type] * 2

   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, AnyType);
 }

 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type, *, loc=None, ip=None, results):
   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }
@@ -207,7 +199,6 @@
 def EmptyOp : TestOp<"empty">;
   // CHECK: def __init__(self, *, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: _ods_successors = None
@@ -218,9 +209,8 @@
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
-  // CHECK: def __init__(self, *, loc=None, ip=None):
+  // CHECK: def __init__(self, *, loc=None, ip=None, results=[]):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: _ods_context = _ods_get_default_loc_context(loc)
   // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
   // CHECK:     operands=operands,
@@ -232,9 +222,8 @@
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
-  // CHECK: def __init__(self, *, loc=None, ip=None):
+  // CHECK: def __init__(self, *, loc=None, ip=None, results=[]):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: _ods_context = _ods_get_default_loc_context(loc)
   // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
   // CHECK:     operands=operands,
@@ -248,17 +237,13 @@
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
 def MissingNamesOp : TestOp<"missing_names"> {
-  // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None, results):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
   // CHECK: operands.append(_get_op_result_or_value(f32))
   // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
-  // CHECK: results.append(i32)
-  // CHECK: results.append(_gen_res_1)
-  // CHECK: results.append(i64)
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -288,7 +273,6 @@
   let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
   // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(non_optional))
@@ -316,7 +300,6 @@
 def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(non_variadic))
@@ -343,13 +326,10 @@
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneVariadicResultOp : TestOp<"one_variadic_result"> {
-  // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
+  // CHECK: def __init__(self, *, loc=None, ip=None, results):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
-  // CHECK: results.extend(variadic)
-  // CHECK: results.append(non_variadic)
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -373,7 +353,6 @@
 def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK: def __init__(self, in_, *, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(in_))
@@ -389,18 +368,17 @@
 }

 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
-  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None, results=[]):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: operands.append
-  // CHECK: results.extend([operands[0].type] * 1)
+  // CHECK: results=[operands[0].type] * 1

   let arguments = (ins AnyType:$in1, AnyType:$in2);
   let results = (outs AnyType:$res);
 }

 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
-  // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None, results):
   let arguments = (ins AnyType:$in1, AnyType:$in2);
   let results = (outs Variadic<AnyType>:$res);
 }
@@ -456,15 +434,12 @@
 // CHECK: class SimpleOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.simple"
 def SimpleOp : TestOp<"simple"> {
-  // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
+  // CHECK: def __init__(self, i32, f32, *, loc=None, ip=None, results):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: operands.append(_get_op_result_or_value(i32))
   // CHECK: operands.append(_get_op_result_or_value(f32))
-  // CHECK: results.append(i64)
-  // CHECK: results.append(f64)
   // CHECK: _ods_successors = None
   // CHECK: super().__init__(self.build_generic(
   // CHECK:   attributes=attributes, results=results, operands=operands,
@@ -494,7 +469,6 @@
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: _ods_successors = None
@@ -518,7 +492,6 @@
 def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK: operands = []
-  // CHECK: results = []
   // CHECK: attributes = {}
   // CHECK: regions = None
   // CHECK: _ods_successors = None
diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -45,10 +45,10 @@
                                      F32Type.get(), mask_type], []))
     with InsertionPoint(f.add_entry_block()):
       A, zero, padding, mask = f.arguments
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding, mask=mask)
-      vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                            padding)
+      vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+                            padding, mask=mask, results=[vector_type])
+      vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+                            padding, results=[vector_type])
       func.ReturnOp([])

 # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -490,7 +490,6 @@
 constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
-   results = []
    attributes = {{}
    regions = None
    {1}
@@ -503,7 +502,6 @@
 /// {0} is the field name.
 constexpr const char *singleOperandAppendTemplate =
     "operands.append(_get_op_result_or_value({0}))";
-constexpr const char *singleResultAppendTemplate = "results.append({0})";

 /// Template for appending an optional element to the operand/result list.
 /// {0} is the field name.
@@ -512,8 +510,6 @@
 constexpr const char *optionalAppendAttrSizedOperandsTemplate =
     "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
     "None)";
-constexpr const char *optionalAppendResultTemplate =
-    "if {0} is not None: results.append({0})";

 /// Template for appending a list of elements to the operand/result list.
 /// {0} is the field name.
@@ -521,7 +517,6 @@
     "operands.extend(_get_op_results_or_values({0}))";
 constexpr const char *multiOperandAppendPackTemplate =
     "operands.append(_get_op_results_or_values({0}))";
-constexpr const char *multiResultAppendTemplate = "results.extend({0})";

 /// Template for setting an attribute in the operation builder.
 /// {0} is the attribute name;
@@ -576,30 +571,6 @@
          hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
 }

-/// Populates `builderArgs` with result names if the builder is expected to
-/// accept them as arguments.
-static void
-populateBuilderArgsResults(const Operator &op,
-                           llvm::SmallVectorImpl<std::string> &builderArgs) {
-  if (canInferType(op))
-    return;
-
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    std::string name = op.getResultName(i).str();
-    if (name.empty()) {
-      if (op.getNumResults() == 1) {
-        // Special case for one result, make the default name be 'result'
-        // to properly match the built-in result accessor.
-        name = "result";
-      } else {
-        name = llvm::formatv("_gen_res_{0}", i);
-      }
-    }
-    name = sanitizeName(name);
-    builderArgs.push_back(name);
-  }
-}
-
 /// Populates `builderArgs` with the Python-compatible names of builder function
 /// arguments using intermixed attributes and operands in the same order as they
 /// appear in the `arguments` field of the op definition. Additionally,
@@ -731,25 +702,25 @@
 /// attribute:
 ///   - {0} is the name of the attribute from which to derive the types.
 constexpr const char *deriveTypeFromAttrTemplate =
-    R"PY(_ods_result_type_source_attr = attributes["{0}"]
-_ods_derived_result_type = (
+    R"PY(  _ods_result_type_source_attr = attributes["{0}"]
+  _ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

 /// Python code template appending {0} type {1} times to the results list.
-constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
+constexpr const char *setSameResultsTemplate = "  results=[{0}] * {1}";

 /// Python code template for inferring the operation results using the
 /// corresponding interface:
 ///   - {0} is the name of the class for which the types are inferred.
 constexpr const char *inferTypeInterfaceTemplate =
-    R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
-    operands=operands,
-    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
-    context=_ods_context,
-    loc=loc)
+    R"PY(  _ods_context = _ods_get_default_loc_context(loc)
+  results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+      operands=operands,
+      attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
+      context=_ods_context,
+      loc=loc)
 )PY";

 /// Appends the given multiline string as individual strings into
@@ -768,13 +739,11 @@
 /// builder to set up op results.
 static void
 populateBuilderLinesResult(const Operator &op,
-                           llvm::ArrayRef<std::string> names,
                            llvm::SmallVectorImpl<std::string> &builderLines) {
-  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
-
+  builderLines.push_back("if not results:");
   if (hasSameArgumentAndResultTypes(op)) {
     builderLines.push_back(llvm::formatv(
-        appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
+        setSameResultsTemplate, "operands[0].type", op.getNumResults()));
     return;
   }
@@ -785,7 +754,7 @@
     appendLineByLine(
         llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
         builderLines);
-    builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
+    builderLines.push_back(llvm::formatv(setSameResultsTemplate,
                                          "_ods_derived_result_type",
                                          op.getNumResults()));
     return;
@@ -797,31 +766,6 @@
                      builderLines);
     return;
   }
-
-  // For each element, find or generate a name.
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    const NamedTypeConstraint &element = op.getResult(i);
-    std::string name = names[i];
-
-    // Choose the formatting string based on the element kind.
-    llvm::StringRef formatString;
-    if (!element.isVariableLength()) {
-      formatString = singleResultAppendTemplate;
-    } else if (element.isOptional()) {
-      formatString = optionalAppendResultTemplate;
-    } else {
-      assert(element.isVariadic() && "unhandled element group type");
-      // If emitting with sizedSegments, then we add the actual list-typed
-      // element. Otherwise, we extend the actual operands.
-      if (sizedSegments) {
-        formatString = singleResultAppendTemplate;
-      } else {
-        formatString = multiResultAppendTemplate;
-      }
-    }
-
-    builderLines.push_back(llvm::formatv(formatString.data(), name));
-  }
 }

 /// If the operation has variadic regions, adds a builder argument to specify
@@ -861,43 +805,34 @@
   llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                       op.getNumNativeAttributes() + op.getNumSuccessors());
-  populateBuilderArgsResults(op, builderArgs);
-  size_t numResultArgs = builderArgs.size();
   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
-  size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
+  size_t numOperandAttrArgs = builderArgs.size();
   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

   populateBuilderLinesOperand(op, operandArgNames, builderLines);
-  populateBuilderLinesAttr(
-      op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
-      builderLines);
-  populateBuilderLinesResult(
-      op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
-      builderLines);
+  populateBuilderLinesAttr(op, llvm::makeArrayRef(builderArgs), builderLines);
   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
   populateBuilderRegions(op, builderArgs, builderLines);

   // Layout of builderArgs vector elements:
-  // [ result_args operand_attr_args successor_args regions ]
+  // [ operand_attr_args successor_args regions ]

   // Determine whether the argument corresponding to a given index into the
   // builderArgs vector is a python keyword argument or not.
   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
-    // All result, successor, and region arguments are positional arguments.
-    if ((builderArgIndex < numResultArgs) ||
-        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
+    // All successor and region arguments are positional arguments.
+    if (builderArgIndex >= numOperandAttrArgs)
       return false;
     // Keyword arguments:
     //  - optional named attributes (including unit attributes)
     //  - default-valued named attributes
     //  - optional operands
-    Argument a = op.getArg(builderArgIndex - numResultArgs);
+    Argument a = op.getArg(builderArgIndex);
     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
-    else if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
+    if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
       return ntype->isOptional();
-    else
-      return false;
+    return false;
   };

   // StringRefs in functionArgs refer to strings allocated by builderArgs.
@@ -923,6 +858,19 @@
   }
   functionArgs.push_back("loc=None");
   functionArgs.push_back("ip=None");
+
+  if (op.getNumResults()) {
+    if (canInferType(op)) {
+      builderArgs.push_back("results=[]");
+      populateBuilderLinesResult(op, builderLines);
+    } else {
+      builderArgs.push_back("results");
+    }
+    functionArgs.push_back(builderArgs.back());
+  } else {
+    builderLines.push_back("results=[]");
+  }
+
   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                       llvm::join(builderLines, "\n    "));
 }
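A short usage sketch of the calling convention this patch introduces (not taken from the diff: the f64-to-f32 truncation function and every name below are illustrative). Result types now travel through the `results` keyword argument of generated builders; ops whose result types can be inferred (SameOperandsAndResultType, FirstAttrDerivedResultType, or InferTypeOpInterface) take an optional `results=[]` and compute the types when the list is left empty, while all other ops require the caller to pass `results` explicitly.

# Usage sketch only: assumes an MLIR build that includes this patch; the
# module/function scaffolding below is illustrative and not part of the change.
from mlir.ir import (Context, F32Type, F64Type, FunctionType, InsertionPoint,
                     Location, Module)
from mlir.dialects import arith, func

with Context(), Location.unknown():
    module = Module.create()
    f64, f32 = F64Type.get(), F32Type.get()
    with InsertionPoint(module.body):
        fn = func.FuncOp("trunc", FunctionType.get([f64], [f32]))
        with InsertionPoint(fn.add_entry_block()):
            (arg,) = fn.arguments
            # Before this patch: arith.TruncFOp(f32, arg) -- the result type
            # was a leading positional argument. Now it is passed by keyword.
            t = arith.TruncFOp(arg, results=[f32])
            func.ReturnOp([t.result])
    print(module)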