diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -138,6 +138,9 @@ // Op attribute accessors. NamedAttribute &getAttribute(int index) { return attributes[index]; } + const NamedAttribute &getAttribute(int index) const { + return attributes[index]; + } // Op operand iterators. value_iterator operand_begin(); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -601,6 +601,71 @@ llvm::Optional refOperation; PyBlock block; }; +/// Wrapper around the generic MlirType. +/// The lifetime of a type is bound by the PyContext that created it. +class PyType : public BaseContextObject { +public: + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} + bool operator==(const PyType &other); + operator MlirType() const { return type; } + MlirType get() const { return type; } + + /// Gets a capsule wrapping the void* within the MlirType. + pybind11::object getCapsule(); + + /// Creates a PyType from the MlirType wrapped by a capsule. + /// Note that PyType instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirType + /// is taken by calling this function. + static PyType createFromCapsule(pybind11::object capsule); + +private: + MlirType type; +}; + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = pybind11::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. @@ -685,71 +750,8 @@ cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -/// Wrapper around the generic MlirType. -/// The lifetime of a type is bound by the PyContext that created it. -class PyType : public BaseContextObject { -public: - PyType(PyMlirContextRef contextRef, MlirType type) - : BaseContextObject(std::move(contextRef)), type(type) {} - bool operator==(const PyType &other); - operator MlirType() const { return type; } - MlirType get() const { return type; } - - /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); - - /// Creates a PyType from the MlirType wrapped by a capsule. - /// Note that PyType instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirType - /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); - -private: - MlirType type; -}; - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. -template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = pybind11::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); + cls.def_property_readonly("type", [](PyAttribute &attr) { + return PyType(attr.getContext(), mlirAttributeGetType(attr)); }); DerivedTy::bindDerived(cls); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -221,7 +221,7 @@ elif expr.scalar_index: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) - return linalg.IndexOp(IndexType.get(), dim_attr).result + return linalg.IndexOp(dim_attr).result elif expr.scalar_apply: try: fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") @@ -303,61 +303,61 @@ def _eval_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs.type, lhs, rhs).result + return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs.type, lhs, rhs).result + return arith.AddIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") def _eval_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.ExpOp(x.type, x).result + return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") def _eval_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.LogOp(x.type, x).result + return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") def _eval_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.SubFOp(lhs.type, lhs, rhs).result + return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.SubIOp(lhs.type, lhs, rhs).result + return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") def _eval_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MulFOp(lhs.type, lhs, rhs).result + return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MulIOp(lhs.type, lhs, rhs).result + return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxSIOp(lhs.type, lhs, rhs).result + return std.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxUIOp(lhs.type, lhs, rhs).result + return std.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinSIOp(lhs.type, lhs, rhs).result + return std.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinUIOp(lhs.type, lhs, rhs).result + return std.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -1,6 +1,7 @@ // RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Bindings/Python/Attributes.td" // CHECK: @_ods_cext.register_dialect @@ -176,6 +177,27 @@ let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr:$is); } +// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" +def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { + // CHECK: def __init__(self, type, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_result_type_source_attr = attributes["type"] + // CHECK: _ods_derived_result_type = ( + // CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value + // CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + // CHECK: _ods_result_type_source_attr.type) + // CHECK: results.extend([_ods_derived_result_type] * 2) + let arguments = (ins TypeAttr:$type); + let results = (outs AnyType:$res, AnyType); +} + +// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" +def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { + // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None): + let arguments = (ins TypeAttr:$type); + let results = (outs AnyType:$res, Variadic); +} // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): @@ -191,6 +213,35 @@ // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + +// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" +def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { + // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_context = _ods_get_default_loc_context(loc) + // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes( + // CHECK: operands=operands, + // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + // CHECK: context=_ods_context, + // CHECK: loc=loc) + let results = (outs I32:$i32, F32:$f32); +} + +// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" +def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { + // CHECK: def __init__(self, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: _ods_context = _ods_get_default_loc_context(loc) + // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes( + // CHECK: operands=operands, + // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + // CHECK: context=_ods_context, + // CHECK: loc=loc) + let results = (outs AnyType, AnyType, AnyType); +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" @@ -200,12 +251,12 @@ // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: results.append(i32) - // CHECK: results.append(_gen_res_1) - // CHECK: results.append(i64) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0)) // CHECK: operands.append(_get_op_result_or_value(f32)) // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2)) + // CHECK: results.append(i32) + // CHECK: results.append(_gen_res_1) + // CHECK: results.append(i64) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -223,7 +274,7 @@ // CHECK: @builtins.property // CHECK: def i64(self): // CHECK: return self.operation.results[2] - let results = (outs I32:$i32, F32, I64:$i64); + let results = (outs I32:$i32, AnyFloat, I64:$i64); } // CHECK: @_ods_cext.register_operation(_Dialect) @@ -305,6 +356,24 @@ // CHECK: return self.operation.operands[0] let arguments = (ins AnyType:$in); } +// CHECK-LABEL: OPERATION_NAME = "test.same_results" +def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { + // CHECK: def __init__(self, in1, in2, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: operands.append + // CHECK: results.extend([operands[0].type] * 1) + let arguments = (ins AnyType:$in1, AnyType:$in2); + let results = (outs AnyType:$res); +} + +// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" +def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { + // CHECK: def __init__(self, res, in1, in2, *, loc=None, ip=None): + let arguments = (ins AnyType:$in1, AnyType:$in2); + let results = (outs Variadic:$res); +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicOperandSizeOp(_ods_ir.OpView): @@ -361,10 +430,10 @@ // CHECK: results = [] // CHECK: attributes = {} // CHECK: regions = None - // CHECK: results.append(i64) - // CHECK: results.append(f64) // CHECK: operands.append(_get_op_result_or_value(i32)) // CHECK: operands.append(_get_op_result_or_value(f32)) + // CHECK: results.append(i64) + // CHECK: results.append(f64) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -386,7 +455,7 @@ // CHECK: @builtins.property // CHECK: def f64(self): // CHECK: return self.operation.results[1] - let results = (outs I64:$i64, F64:$f64); + let results = (outs I64:$i64, AnyFloat:$f64); } // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py --- a/mlir/test/python/dialects/math.py +++ b/mlir/test/python/dialects/math.py @@ -16,7 +16,7 @@ with InsertionPoint(module.body): @builtin.FuncOp.from_py_func(F32Type.get()) def emit_sqrt(arg): - return mlir_math.SqrtOp(F32Type.get(), arg) + return mlir_math.SqrtOp(arg) # CHECK-LABEL: func @emit_sqrt( # CHECK-SAME: %[[ARG:.*]]: f32) { diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -137,8 +137,7 @@ test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): - op = test.InferResultsOp( - IntegerType.get_signless(32), IntegerType.get_signless(64)) + op = test.InferResultsOp() dummy = test.DummyOp() # CHECK: [Type(i32), Type(i64)] @@ -173,3 +172,38 @@ pass else: assert False, "not expected dummy op class to implement the interface" + + +# CHECK-LABEL: TEST: resultTypesDefinedByTraits +@run +def resultTypesDefinedByTraits(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + inferred = test.InferResultsOp() + same = test.SameOperandAndResultTypeOp([inferred.results[0]]) + # CHECK-COUNT-2: i32 + print(same.one.type) + print(same.two.type) + + first_type_attr = test.FirstAttrDeriveTypeAttrOp( + inferred.results[1], TypeAttr.get(IndexType.get())) + # CHECK-COUNT-2: index + print(first_type_attr.one.type) + print(first_type_attr.two.type) + + first_attr = test.FirstAttrDeriveAttrOp( + FloatAttr.get(F32Type.get(), 3.14)) + # CHECK-COUNT-3: f32 + print(first_attr.one.type) + print(first_attr.two.type) + print(first_attr.three.type) + + implied = test.InferResultsImpliedOp() + # CHECK: i32 + print(implied.integer.type) + # CHECK: f64 + print(implied.flt.type) + # CHECK: index + print(implied.index.type) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -18,15 +18,12 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() - indexT = IndexType.get() with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( RankedTensorType.get((12, -1), f32)) def const_shape_tensor(arg): - return shape.ConstShapeOp(RankedTensorType.get((2,), indexT), - DenseElementsAttr.get(np.array([10, 20]))) + return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20]))) # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) # CHECK: shape.const_shape [10, 20] : tensor<2xindex> print(module) - diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -82,11 +82,11 @@ # Create via dialects context collection. input1 = createInput() input2 = createInput() - op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2) + op1 = ctx.dialects.arith.AddFOp(input1, input2) # Create via an import from mlir.dialects.arith import AddFOp - AddFOp(input1.type, input1, op1.result) + AddFOp(input1, op1.result) # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -52,4 +52,28 @@ }]; } +// If all result types are buildable, the InferTypeOpInterface is implied and is +// autogenerated by C++ ODS. +def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> { + let results = (outs I32:$integer, F64:$flt, Index:$index); +} + +def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op", + [SameOperandsAndResultType]> { + let arguments = (ins Variadic); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyType:$input, TypeAttr:$type); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyAttr:$iattr); + let results = (outs AnyType:$one, AnyType:$two, AnyType:$three); +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -541,16 +541,42 @@ /// {1} is the value to add. constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; -/// Populates `builderArgs` with the Python-compatible names of builder function -/// arguments, first the results, then the intermixed attributes and operands in -/// the same order as they appear in the `arguments` field of the op definition. -/// Additionally, `operandNames` is populated with names of operands in their -/// order of appearance. +/// Returns true if the SameArgumentAndResultTypes trait can be used to infer +/// result types of the given operation. +static bool hasSameArgumentAndResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && + op.getNumVariableLengthResults() == 0; +} + +/// Returns true if the FirstAttrDerivedResultType trait can be used to infer +/// result types of the given operation. +static bool hasFirstAttrDerivedResultTypes(const Operator &op) { + return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && + op.getNumVariableLengthResults() == 0; +} + +/// Returns true if the InferTypeOpInterface can be used to infer result types +/// of the given operation. +static bool hasInferTypeInterface(const Operator &op) { + return op.getTrait("::mlir::InferTypeOpInterface::Trait") && + op.getNumRegions() == 0; +} + +/// Returns true if there is a trait or interface that can be used to infer +/// result types of the given operation. +static bool canInferType(const Operator &op) { + return hasSameArgumentAndResultTypes(op) || + hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); +} + +/// Populates `builderArgs` with result names if the builder is expected to +/// accept them as arguments. static void -populateBuilderArgs(const Operator &op, - llvm::SmallVectorImpl &builderArgs, - llvm::SmallVectorImpl &operandNames, - llvm::SmallVectorImpl &successorArgNames) { +populateBuilderArgsResults(const Operator &op, + llvm::SmallVectorImpl &builderArgs) { + if (canInferType(op)) + return; + for (int i = 0, e = op.getNumResults(); i < e; ++i) { std::string name = op.getResultName(i).str(); if (name.empty()) { @@ -565,6 +591,19 @@ name = sanitizeName(name); builderArgs.push_back(name); } +} + +/// Populates `builderArgs` with the Python-compatible names of builder function +/// arguments using intermixed attributes and operands in the same order as they +/// appear in the `arguments` field of the op definition. Additionally, +/// `operandNames` is populated with names of operands in their order of +/// appearance. +static void +populateBuilderArgs(const Operator &op, + llvm::SmallVectorImpl &builderArgs, + llvm::SmallVectorImpl &operandNames, + llvm::SmallVectorImpl &successorArgNames) { + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { std::string name = op.getArgName(i).str(); if (name.empty()) @@ -670,6 +709,43 @@ } } +/// Python code template for deriving the operation result types from its +/// attribute: +/// - {0} is the name of the attribute from which to derive the types. +constexpr const char *deriveTypeFromAttrTemplate = + R"PY(_ods_result_type_source_attr = attributes["{0}"] +_ods_derived_result_type = ( + _ods_ir.TypeAttr(_ods_result_type_source_attr).value + if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + _ods_result_type_source_attr.type))PY"; + +/// Python code template appending {0} type {1} times to the results list. +constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; + +/// Python code template for inferring the operation results using the +/// corresponding interface: +/// - {0} is the name of the class for which the types are inferred. +constexpr const char *inferTypeInterfaceTemplate = + R"PY(_ods_context = _ods_get_default_loc_context(loc) +results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( + operands=operands, + attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), + context=_ods_context, + loc=loc) +)PY"; + +/// Appends the given multiline string as individual strings into +/// `builderLines`. +static void appendLineByLine(StringRef string, + llvm::SmallVectorImpl &builderLines) { + + std::pair split = std::make_pair(string, string); + do { + split = split.second.split('\n'); + builderLines.push_back(split.first.str()); + } while (!split.second.empty()); +} + /// Populates `builderLines` with additional lines that are required in the /// builder to set up op results. static void @@ -678,6 +754,32 @@ llvm::SmallVectorImpl &builderLines) { bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; + if (hasSameArgumentAndResultTypes(op)) { + builderLines.push_back(llvm::formatv( + appendSameResultsTemplate, "operands[0].type", op.getNumResults())); + return; + } + + if (hasFirstAttrDerivedResultTypes(op)) { + const NamedAttribute &firstAttr = op.getAttribute(0); + assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " + "from which the type is derived"); + appendLineByLine( + llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), + builderLines); + builderLines.push_back(llvm::formatv(appendSameResultsTemplate, + "_ods_derived_result_type", + op.getNumResults())); + return; + } + + if (hasInferTypeInterface(op)) { + appendLineByLine( + llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), + builderLines); + return; + } + // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { const NamedTypeConstraint &element = op.getResult(i); @@ -741,14 +843,16 @@ llvm::SmallVector successorArgNames; builderArgs.reserve(op.getNumOperands() + op.getNumResults() + op.getNumNativeAttributes() + op.getNumSuccessors()); + populateBuilderArgsResults(op, builderArgs); + size_t numResultArgs = builderArgs.size(); populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); - populateBuilderLinesResult( - op, llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), - builderLines); populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesAttr( - op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), + op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), + builderLines); + populateBuilderLinesResult( + op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderRegions(op, builderArgs, builderLines);