diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ b/mlir/python/mlir/dialects/_func_ops_ext.py
@@ -19,7 +19,7 @@
   """Specialization for the constant op class."""

   def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-    super().__init__(result, value, loc=loc, ip=ip)
+    super().__init__(value, results=[result], loc=loc, ip=ip)

   @property
   def type(self):
@@ -273,11 +273,11 @@
                          "to a function")
       super().__init__(
-          calleeOrResults.type.results,
           FlatSymbolRefAttr.get(
               calleeOrResults.name.value,
               context=_get_default_loc_context(loc)),
           argumentsOrCallee,
+          results=calleeOrResults.type.results,
           loc=loc,
           ip=ip)
       return
@@ -289,12 +289,12 @@
     if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
       super().__init__(
-          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
+          argumentsOrCallee, arguments, results=calleeOrResults, loc=loc, ip=ip)
     elif isinstance(argumentsOrCallee, str):
       super().__init__(
-          calleeOrResults,
           FlatSymbolRefAttr.get(
               argumentsOrCallee, context=_get_default_loc_context(loc)),
           arguments,
+          results=calleeOrResults,
           loc=loc,
           ip=ip)
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
@@ -34,9 +34,9 @@
                ip=None,
                loc=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
         num_loops=_get_int64_attr(num_loops, default_value=1),
+        results=[pdl.OperationType.get()],
         ip=ip,
         loc=loc)
@@ -51,10 +51,10 @@
                ip=None,
                loc=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
         func_name=(func_name if isinstance(func_name, StringAttr) else
                    StringAttr.get(func_name)),
+        results=[pdl.OperationType.get()],
         ip=ip,
         loc=loc)
@@ -69,11 +69,11 @@
                ip=None,
                loc=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
         fail_if_already_divisible=(fail_if_already_divisible if isinstance(
             fail_if_already_divisible, BoolAttr) else
                                    BoolAttr.get(fail_if_already_divisible)),
+        results=[pdl.OperationType.get()],
         ip=ip,
         loc=loc)
@@ -89,10 +89,10 @@
                ip=None,
                loc=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
         iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
         read_latency=_get_int64_attr(read_latency, default_value=10),
+        results=[pdl.OperationType.get()],
         ip=ip,
         loc=loc)
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -34,4 +34,4 @@
     indices_resolved = [] if indices is None else _get_op_results_or_values(
         indices)
     return_type = MemRefType(memref_resolved.type).element_type
-    super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
+    super().__init__(memref, indices_resolved, results=[return_type], loc=loc, ip=ip)
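The three extension modules above make the same mechanical change: result types leave the positional argument list of the generated `__init__` and travel through the new `results=` keyword, while the public signatures of the hand-written mixins stay as they were. As a rough sketch of what this means at a call site (hedged: `buffer` and `idx` are placeholder Values assumed to exist under an active InsertionPoint, not values from this patch):

    from mlir.dialects import memref

    # The LoadOp mixin still derives the element type from the memref itself,
    # so user-facing code is unchanged:
    elem = memref.LoadOp(buffer, [idx]).result
    # Internally, however, the mixin now forwards that type as a keyword,
    #   super().__init__(memref, indices_resolved, results=[return_type], ...)
    # instead of passing it as the leading positional argument.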
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py
@@ -79,7 +79,7 @@
                ip=None):
     name = _get_str_attr(name)
     args = _get_values(args)
-    super().__init__(results, name, args, loc=loc, ip=ip)
+    super().__init__(name, args, results=results, loc=loc, ip=ip)


 class AttributeOp:
@@ -93,7 +93,7 @@
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.AttributeType.get()
-    super().__init__(result, type=type, value=value, loc=loc, ip=ip)
+    super().__init__(value=value, type=type, results=[result], loc=loc, ip=ip)


 class EraseOp:
@@ -118,7 +118,7 @@
                ip=None):
     type = type if type is None else _get_value(type)
     result = pdl.ValueType.get()
-    super().__init__(result, type=type, loc=loc, ip=ip)
+    super().__init__(type=type, results=[result], loc=loc, ip=ip)


 class OperandsOp:
@@ -131,7 +131,7 @@
                ip=None):
     types = types if types is None else _get_value(types)
     result = pdl.RangeType.get(pdl.ValueType.get())
-    super().__init__(result, type=types, loc=loc, ip=ip)
+    super().__init__(type=types, results=[result], loc=loc, ip=ip)


 class OperationOp:
@@ -155,7 +155,7 @@
     attributeNames = ArrayAttr.get(attributeNames)
     types = _get_values(types)
     result = pdl.OperationType.get()
-    super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip)
+    super().__init__(args, attributeValues, attributeNames, types, name=name, results=[result], loc=loc, ip=ip)


 class PatternOp:
@@ -207,7 +207,7 @@
     index = _get_int_attr(32, index)
     parent = _get_value(parent)
     result = pdl.ValueType.get()
-    super().__init__(result, parent, index, loc=loc, ip=ip)
+    super().__init__(parent, index, results=[result], loc=loc, ip=ip)


 class ResultsOp:
@@ -222,7 +222,7 @@
                ip=None):
     parent = _get_value(parent)
     index = index if index is None else _get_int_attr(32, index)
-    super().__init__(result, parent, index=index, loc=loc, ip=ip)
+    super().__init__(parent, index=index, results=[result], loc=loc, ip=ip)


 class RewriteOp:
@@ -261,7 +261,7 @@
                ip=None):
     type = type if type is None else _get_type_attr(type)
     result = pdl.TypeType.get()
-    super().__init__(result, type=type, loc=loc, ip=ip)
+    super().__init__(type=type, results=[result], loc=loc, ip=ip)


 class TypesOp:
@@ -275,4 +275,4 @@
     types = _get_array_attr([_get_type_attr(ty) for ty in types])
     types = None if not types else types
     result = pdl.RangeType.get(pdl.TypeType.get())
-    super().__init__(result, types=types, loc=loc, ip=ip)
+    super().__init__(types=types, results=[result], loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -74,8 +74,8 @@
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
+        results=[pdl.OperationType.get()],
         loc=loc,
         ip=ip)
@@ -85,8 +85,8 @@
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
+        results=[pdl.OperationType.get()],
         loc=loc,
         ip=ip)
@@ -103,9 +103,9 @@
     pdl_operation_type = pdl.OperationType.get()
     interchange_attr = _get_int_array_attr(iterator_interchange)
     super().__init__(
-        pdl_operation_type,
         _get_op_result_or_value(target),
         iterator_interchange=interchange_attr,
+        results=[pdl_operation_type],
         loc=loc,
         ip=ip)
@@ -132,13 +132,13 @@
     hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
     transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
     super().__init__(
-        pdl_operation_type,
         _get_op_result_or_value(target),
         padding_values=padding_values_attr,
         padding_dimensions=padding_dimensions_attr,
         pack_paddings=pack_paddings_attr,
         hoist_paddings=hoist_paddings_attr,
         transpose_paddings=transpose_paddings_attr,
+        results=[pdl_operation_type],
         loc=loc,
         ip=ip)
@@ -149,7 +149,7 @@
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
     pdl_operation_type = pdl.OperationType.get()
     super().__init__(
-        pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+        _get_op_result_or_value(target), results=[pdl_operation_type], loc=loc, ip=ip)


 class TileOp:
@@ -167,10 +167,10 @@
     num_loops = sum(
         v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
     super().__init__(
-        pdl_operation_type, [pdl_operation_type] * num_loops,
         _get_op_result_or_value(target),
         sizes=sizes_attr,
         interchange=_get_int_array_attr(interchange) if interchange else None,
+        results=[pdl_operation_type] * (num_loops + 1),
         loc=loc,
         ip=ip)
@@ -193,8 +193,8 @@
     if isinstance(vectorize_padding, bool):
       vectorize_padding = BoolAttr.get(vectorize_padding)
     super().__init__(
-        pdl_operation_type,
         _get_op_result_or_value(target),
         vectorize_padding=vectorize_padding,
+        results=[pdl_operation_type],
         loc=loc,
         ip=ip)
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -22,8 +22,8 @@
   def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
+        results=[pdl.OperationType.get()],
         loc=loc,
         ip=ip)
@@ -37,9 +37,9 @@
                loc=None,
                ip=None):
     super().__init__(
-        pdl.OperationType.get(),
         _get_op_result_or_value(target),
         _get_symbol_ref_attr(pattern_name),
+        results=[pdl.OperationType.get()],
         loc=loc,
         ip=ip)
@@ -62,7 +62,7 @@
         resultsOrRoot if not isinstance(resultsOrRoot, Sequence) else optionalRoot)
     root = _get_op_result_or_value(root) if root else None
-    super().__init__(results_=results, root=root)
+    super().__init__(results=results, root=root)
     self.regions[0].blocks.append(pdl.OperationType.get())

   @property
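One detail worth calling out in the TileOp hunk above: the old builder passed the tiled-op handle type and the per-loop handle types as two positional groups, whereas the new form folds them into a single `results` list of length `num_loops + 1`. A sketch with made-up tile sizes (the import path and exact mixin signature are abbreviated assumptions, not taken from this patch):

    # Assumes the structured-transform mixins are importable as `structured`
    # and `target` is a pdl.OperationType-typed handle already in scope.
    tile = structured.TileOp(target, sizes=[32, 0, 16])
    # Two of the sizes are non-zero, so num_loops == 2 and the op is built
    # with results=[pdl.OperationType.get()] * 3: the tiled-op handle
    # followed by one handle per generated loop.
    tiled, loop0, loop1 = tile.results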
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -197,7 +197,7 @@
       [AffineMapAttr.get(am) for am in indexing_maps])
   generic_op = linalg.GenericOp(
-      result_tensors=result_types,
+      results=result_types,
       inputs=ins,
       outputs=outs,
       indexing_maps=indexing_maps_attr,
@@ -343,18 +343,18 @@
   operand_type = operand.type
   if _is_floating_point_type(operand_type):
     if is_unsigned_cast:
-      return arith.FPToUIOp(to_type, operand).result
-    return arith.FPToSIOp(to_type, operand).result
+      return arith.FPToUIOp(operand, results=[to_type]).result
+    return arith.FPToSIOp(operand, results=[to_type]).result
   if _is_index_type(operand_type):
-    return arith.IndexCastOp(to_type, operand).result
+    return arith.IndexCastOp(operand, results=[to_type]).result
   # Assume integer.
   from_width = IntegerType(operand_type).width
   if to_width > from_width:
     if is_unsigned_cast:
-      return arith.ExtUIOp(to_type, operand).result
-    return arith.ExtSIOp(to_type, operand).result
+      return arith.ExtUIOp(operand, results=[to_type]).result
+    return arith.ExtSIOp(operand, results=[to_type]).result
   elif to_width < from_width:
-    return arith.TruncIOp(to_type, operand).result
+    return arith.TruncIOp(operand, results=[to_type]).result
   raise ValueError(f"Unable to cast body expression from {operand_type} to "
                    f"{to_type}")
@@ -363,15 +363,15 @@
   operand_type = operand.type
   if _is_integer_type(operand_type):
     if is_unsigned_cast:
-      return arith.UIToFPOp(to_type, operand).result
-    return arith.SIToFPOp(to_type, operand).result
+      return arith.UIToFPOp(operand, results=[to_type]).result
+    return arith.SIToFPOp(operand, results=[to_type]).result
   # Assume FloatType.
   to_width = _get_floating_point_width(to_type)
   from_width = _get_floating_point_width(operand_type)
   if to_width > from_width:
-    return arith.ExtFOp(to_type, operand).result
+    return arith.ExtFOp(operand, results=[to_type]).result
   elif to_width < from_width:
-    return arith.TruncFOp(to_type, operand).result
+    return arith.TruncFOp(operand, results=[to_type]).result
   raise ValueError(f"Unable to cast body expression from {operand_type} to "
                    f"{to_type}")
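The emitter changes are the same migration applied to the generated `arith` cast builders: the destination type moves from the first positional argument into `results=[...]`. Restated as a standalone sketch (placeholder values; this mirrors the lines changed above rather than introducing any new API):

    from mlir.ir import IntegerType
    from mlir.dialects import arith

    # Hedged sketch: `v` is an i32-typed Value under an active InsertionPoint.
    i64 = IntegerType.get_signless(64)
    # Before this patch: arith.ExtSIOp(i64, v).result
    widened = arith.ExtSIOp(v, results=[i64]).result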
not None, "attribute unsupported must be specified" // CHECK: _ods_successors = None + // CHECK: results = [] // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) @@ -202,22 +195,21 @@ // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { - // CHECK: def __init__(self, type, *, loc=None, ip=None): + // CHECK: def __init__(self, type, *, loc=None, ip=None, results=[]): // CHECK: operands = [] - // CHECK: results = [] // CHECK: _ods_result_type_source_attr = attributes["type"] // CHECK: _ods_derived_result_type = ( // CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value // CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else // CHECK: _ods_result_type_source_attr.type) - // CHECK: results.extend([_ods_derived_result_type] * 2) + // CHECK: results = [_ods_derived_result_type] * 2 let arguments = (ins TypeAttr:$type); let results = (outs AnyType:$res, AnyType); } // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { - // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None): + // CHECK: def __init__(self, type, *, loc=None, ip=None, results): let arguments = (ins TypeAttr:$type); let results = (outs AnyType:$res, Variadic); } @@ -228,7 +220,6 @@ def EmptyOp : TestOp<"empty">; // CHECK: def __init__(self, *, loc=None, ip=None): // CHECK: operands = [] - // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None // CHECK: _ods_successors = None @@ -239,9 +230,8 @@ // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { - // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: def __init__(self, *, loc=None, ip=None, results=[]): // CHECK: operands = [] - // CHECK: results = [] // CHECK: _ods_context = _ods_get_default_loc_context(loc) // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes( // CHECK: operands=operands, @@ -253,9 +243,8 @@ // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { - // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: def __init__(self, *, loc=None, ip=None, results=[]): // CHECK: operands = [] - // CHECK: results = [] // CHECK: _ods_context = _ods_get_default_loc_context(loc) // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes( // CHECK: operands=operands, @@ -269,17 +258,13 @@ // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" def MissingNamesOp : TestOp<"missing_names"> { - // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None): + // CHECK: def __init__(self, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None, results): // CHECK: operands = [] - // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) // CHECK: operands.append(_get_op_result_or_value(f32)) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) - // CHECK: results.append(i32) - // CHECK: results.append(_gen_res_1) - // CHECK: 
@@ -269,17 +258,13 @@
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.missing_names"
 def MissingNamesOp : TestOp<"missing_names"> {
-  // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None, results):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
   // CHECK:   operands.append(_get_op_result_or_value(f32))
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
-  // CHECK:   results.append(i32)
-  // CHECK:   results.append(_gen_res_1)
-  // CHECK:   results.append(i64)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -309,7 +294,6 @@
   let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
   // CHECK: def __init__(self, non_optional, *, optional=None, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(non_optional))
@@ -337,7 +321,6 @@
 def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
@@ -364,13 +347,10 @@
 // CHECK-NOT: _ODS_OPERAND_SEGMENTS
 // CHECK-NOT: _ODS_RESULT_SEGMENTS
 def OneVariadicResultOp : TestOp<"one_variadic_result"> {
-  // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None):
+  // CHECK: def __init__(self, *, loc=None, ip=None, results):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   results.extend(variadic)
-  // CHECK:   results.append(non_variadic)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -394,7 +374,6 @@
 def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK: def __init__(self, in_, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(in_))
@@ -410,18 +389,17 @@
 }
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
-  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None, results=[]):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   operands.append
-  // CHECK:   results.extend([operands[0].type] * 1)
+  // CHECK:   results = [operands[0].type] * 1
   let arguments = (ins AnyType:$in1, AnyType:$in2);
   let results = (outs AnyType:$res);
 }

 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
-  // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None):
+  // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None, results):
   let arguments = (ins AnyType:$in1, AnyType:$in2);
   let results = (outs Variadic<AnyType>:$res);
 }
@@ -477,15 +455,12 @@
 // CHECK: class SimpleOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.simple"
 def SimpleOp : TestOp<"simple"> {
-  // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None):
+  // CHECK: def __init__(self, i32, f32, *, loc=None, ip=None, results):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_result_or_value(i32))
   // CHECK:   operands.append(_get_op_result_or_value(f32))
-  // CHECK:   results.append(i64)
-  // CHECK:   results.append(f64)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -515,7 +490,6 @@
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
   // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   _ods_successors = None
@@ -539,7 +513,6 @@
 def VariadicRegionOp : TestOp<"variadic_region"> {
   // CHECK: def __init__(self, num_variadic, *, loc=None, ip=None):
   // CHECK:   operands = []
-  // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   _ods_successors = None
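Taken together, the rewritten CHECK lines describe the new builder surface: result types vanish from the positional list and come back as a `results` keyword that is required when types cannot be inferred (`results`) and optional when they can (`results=[]`). A hypothetical call site for `test.simple` under the new scheme (the test dialect is only exercised through tblgen here, so this is illustrative only):

    # i32_val and f32_val are Values; i64 and f64 are Types, all assumed in scope.
    op = test.SimpleOp(i32_val, f32_val, results=[i64, f64])
    # By contrast, an op implementing InferTypeOpInterface needs no explicit
    # result types at all:
    inferred = test.InferResultTypesOp()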
diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -45,10 +45,10 @@
                           F32Type.get(), mask_type], []))
   with InsertionPoint(f.add_entry_block()):
     A, zero, padding, mask = f.arguments
-    vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                          padding, mask=mask)
-    vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
-                          padding)
+    vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+                          padding, mask=mask, results=[vector_type])
+    vector.TransferReadOp(A, [zero, zero], identity_map_attr,
+                          padding, results=[vector_type])
     func.ReturnOp([])

 # CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
- name = "result"; - } else { - name = llvm::formatv("_gen_res_{0}", i); - } - } - name = sanitizeName(name); - builderArgs.push_back(name); - } -} - /// Populates `builderArgs` with the Python-compatible names of builder function /// arguments using intermixed attributes and operands in the same order as they /// appear in the `arguments` field of the op definition. Additionally, @@ -789,25 +760,25 @@ /// attribute: /// - {0} is the name of the attribute from which to derive the types. constexpr const char *deriveTypeFromAttrTemplate = - R"PY(_ods_result_type_source_attr = attributes["{0}"] -_ods_derived_result_type = ( + R"PY( _ods_result_type_source_attr = attributes["{0}"] + _ods_derived_result_type = ( _ods_ir.TypeAttr(_ods_result_type_source_attr).value if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else _ods_result_type_source_attr.type))PY"; /// Python code template appending {0} type {1} times to the results list. -constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; +constexpr const char *setSameResultsTemplate = " results = [{0}] * {1}"; /// Python code template for inferring the operation results using the /// corresponding interface: /// - {0} is the name of the class for which the types are inferred. constexpr const char *inferTypeInterfaceTemplate = - R"PY(_ods_context = _ods_get_default_loc_context(loc) -results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( - operands=operands, - attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - context=_ods_context, - loc=loc) + R"PY( _ods_context = _ods_get_default_loc_context(loc) + results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( + operands=operands, + attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + context=_ods_context, + loc=loc) )PY"; /// Appends the given multiline string as individual strings into @@ -826,13 +797,11 @@ /// builder to set up op results. static void populateBuilderLinesResult(const Operator &op, - llvm::ArrayRef names, llvm::SmallVectorImpl &builderLines) { - bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; - + builderLines.push_back("if not results:"); if (hasSameArgumentAndResultTypes(op)) { builderLines.push_back(llvm::formatv( - appendSameResultsTemplate, "operands[0].type", op.getNumResults())); + setSameResultsTemplate, "operands[0].type", op.getNumResults())); return; } @@ -843,7 +812,7 @@ appendLineByLine( llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), builderLines); - builderLines.push_back(llvm::formatv(appendSameResultsTemplate, + builderLines.push_back(llvm::formatv(setSameResultsTemplate, "_ods_derived_result_type", op.getNumResults())); return; @@ -855,31 +824,6 @@ builderLines); return; } - - // For each element, find or generate a name. - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - const NamedTypeConstraint &element = op.getResult(i); - std::string name = names[i]; - - // Choose the formatting string based on the element kind. - llvm::StringRef formatString; - if (!element.isVariableLength()) { - formatString = singleResultAppendTemplate; - } else if (element.isOptional()) { - formatString = optionalAppendResultTemplate; - } else { - assert(element.isVariadic() && "unhandled element group type"); - // If emitting with sizedSegments, then we add the actual list-typed - // element. Otherwise, we extend the actual operands. 
@@ -855,31 +824,6 @@
                      builderLines);
     return;
   }
-
-  // For each element, find or generate a name.
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    const NamedTypeConstraint &element = op.getResult(i);
-    std::string name = names[i];
-
-    // Choose the formatting string based on the element kind.
-    llvm::StringRef formatString;
-    if (!element.isVariableLength()) {
-      formatString = singleResultAppendTemplate;
-    } else if (element.isOptional()) {
-      formatString = optionalAppendResultTemplate;
-    } else {
-      assert(element.isVariadic() && "unhandled element group type");
-      // If emitting with sizedSegments, then we add the actual list-typed
-      // element. Otherwise, we extend the actual operands.
-      if (sizedSegments) {
-        formatString = singleResultAppendTemplate;
-      } else {
-        formatString = multiResultAppendTemplate;
-      }
-    }
-
-    builderLines.push_back(llvm::formatv(formatString.data(), name));
-  }
 }

 /// If the operation has variadic regions, adds a builder argument to specify
@@ -919,37 +863,29 @@
   llvm::SmallVector<std::string> successorArgNames;
   builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                       op.getNumNativeAttributes() + op.getNumSuccessors());
-  populateBuilderArgsResults(op, builderArgs);
-  size_t numResultArgs = builderArgs.size();
   populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
-  size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
+  size_t numOperandAttrArgs = builderArgs.size();
   populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

   populateBuilderLinesOperand(op, operandArgNames, builderLines);
-  populateBuilderLinesAttr(
-      op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs),
-      builderLines);
-  populateBuilderLinesResult(
-      op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs),
-      builderLines);
+  populateBuilderLinesAttr(op, llvm::makeArrayRef(builderArgs), builderLines);
   populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
   populateBuilderRegions(op, builderArgs, builderLines);

   // Layout of builderArgs vector elements:
-  // [ result_args operand_attr_args successor_args regions ]
+  // [ operand_attr_args successor_args regions ]

   // Determine whether the argument corresponding to a given index into the
   // builderArgs vector is a python keyword argument or not.
   auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
-    // All result, successor, and region arguments are positional arguments.
-    if ((builderArgIndex < numResultArgs) ||
-        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
+    // All successor and region arguments are positional arguments.
+    if (builderArgIndex >= numOperandAttrArgs)
       return false;
     // Keyword arguments:
     //   - optional named attributes (including unit attributes)
     //   - default-valued named attributes
     //   - optional operands
-    Argument a = op.getArg(builderArgIndex - numResultArgs);
+    Argument a = op.getArg(builderArgIndex);
     if (auto *nattr = a.dyn_cast<NamedAttribute *>())
       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
     if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
@@ -980,6 +916,19 @@
   }
   functionArgs.push_back("loc=None");
   functionArgs.push_back("ip=None");
+
+  if (op.getNumResults()) {
+    if (canInferType(op)) {
+      builderArgs.push_back("results=[]");
+      populateBuilderLinesResult(op, builderLines);
+    } else {
+      builderArgs.push_back("results");
+    }
+    functionArgs.push_back(builderArgs.back());
+  } else {
+    builderLines.push_back("results = []");
+  }
+
   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                       llvm::join(builderLines, "\n    "));
 }
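Reassembling the generator changes by hand, the emitted `__init__` for a type-inferring op now reads approximately as follows (stitched together from `initTemplate`, the `if not results:` guard, and `inferTypeInterfaceTemplate`; this is a sketch, not literal tblgen output):

    def __init__(self, *, loc=None, ip=None, results=[]):
      operands = []
      attributes = {}
      regions = None
      _ods_successors = None
      if not results:
        _ods_context = _ods_get_default_loc_context(loc)
        results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
            operands=operands,
            attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
            context=_ods_context,
            loc=loc)
      super().__init__(self.build_generic(
        attributes=attributes, results=results, operands=operands,
        successors=_ods_successors, regions=regions, loc=loc, ip=ip))

The mutable default `results=[]` is safe here only because the generated body rebinds the name rather than mutating the list in place; callers who do pass a list see it used as-is.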