diff --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
--- a/mlir/include/mlir-c/StandardTypes.h
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -270,6 +270,30 @@
 /** Returns the pos-th type in the tuple type. */
 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
 
+/*============================================================================*/
+/* Function type.                                                             */
+/*============================================================================*/
+
+/** Checks whether the given type is a function type. */
+int mlirTypeIsAFunction(MlirType type);
+
+/** Creates a function type, mapping a list of input types to result types. */
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results);
+
+/** Returns the number of input types. */
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type);
+
+/** Returns the number of result types. */
+intptr_t mlirFunctionTypeGetNumResults(MlirType type);
+
+/** Returns the pos-th input type. */
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos);
+
+/** Returns the pos-th result type. */
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1278,6 +1278,56 @@
   }
 };
 
+/// Function type.
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, std::vector<PyType> inputs,
+           std::vector<PyType> results) {
+          SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
+          SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
+          MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(),
+                                           inputsRaw.data(), resultsRaw.size(),
+                                           resultsRaw.data());
+          return PyFunctionType(context.getRef(), t);
+        },
+        py::arg("context"), py::arg("inputs"), py::arg("results"),
+        "Gets a FunctionType from a list of input and result types");
+    c.def_property_readonly(
+        "inputs",
+        [](PyFunctionType &self) {
+          MlirType t = self.type;
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type);
+               i < e; ++i) {
+            types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+          }
+          return types;
+        },
+        "Returns the list of input types in the FunctionType.");
+    c.def_property_readonly(
+        "results",
+        [](PyFunctionType &self) {
+          MlirType t = self.type;
+          auto contextRef = self.getContext();
+          py::list types;
+          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type);
+               i < e; ++i) {
+            types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
+          }
+          return types;
+        },
+        "Returns the list of result types in the FunctionType.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -1613,6 +1663,7 @@
   PyMemRefType::bind(m);
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
+  PyFunctionType::bind(m);
 
   // Container bindings.
   PyBlockIterator::bind(m);
diff --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
--- a/mlir/lib/CAPI/IR/StandardTypes.cpp
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -13,6 +13,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -297,3 +298,41 @@
 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
   return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
 }
+
+/*============================================================================*/
+/* Function type.                                                             */
+/*============================================================================*/
+
+int mlirTypeIsAFunction(MlirType type) {
+  return unwrap(type).isa<FunctionType>();
+}
+
+MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
+                             MlirType *inputs, intptr_t numResults,
+                             MlirType *results) {
+  SmallVector<Type, 4> inputsList;
+  SmallVector<Type, 4> resultsList;
+  (void)unwrapList(numInputs, inputs, inputsList);
+  (void)unwrapList(numResults, results, resultsList);
+  return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
+}
+
+intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
+  return unwrap(type).cast<FunctionType>().getNumInputs();
+}
+
+intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
+  return unwrap(type).cast<FunctionType>().getNumResults();
+}
+
+MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
+  assert(pos >= 0 && "pos in array must be non-negative");
+  return wrap(
+      unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
+}
+
+MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
+  assert(pos >= 0 && "pos in array must be non-negative");
+  return wrap(
+      unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
+}
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -392,3 +392,19 @@
   print("pos-th type in the tuple type:", tuple_type.get_type(1))
 
 run(testTupleType)
+
+
+# CHECK-LABEL: TEST: testFunctionType
+def testFunctionType():
+  ctx = mlir.ir.Context()
+  input_types = [mlir.ir.IntegerType.get_signless(ctx, 32),
+                 mlir.ir.IntegerType.get_signless(ctx, 16)]
+  result_types = [mlir.ir.IndexType(ctx)]
+  func = mlir.ir.FunctionType.get(ctx, input_types, result_types)
+  # CHECK: INPUTS: [Type(i32), Type(i16)]
+  print("INPUTS:", func.inputs)
+  # CHECK: RESULTS: [Type(index)]
+  print("RESULTS:", func.results)
+
+
+run(testFunctionType)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -10,8 +10,8 @@
 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
  */
 
-#include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/AffineMap.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
@@ -443,6 +443,26 @@
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
 
+  // Function type.
+  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
+  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
+                             mlirIntegerTypeGet(ctx, 32),
+                             mlirIntegerTypeGet(ctx, 64)};
+  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
+  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
+    return 21;
+  if (mlirFunctionTypeGetNumResults(funcType) != 3)
+    return 22;
+  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
+      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
+    return 23;
+  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
+      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
+      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
+    return 24;
+  mlirTypeDump(funcType);
+  fprintf(stderr, "\n");
+
   return 0;
 }
 
@@ -691,8 +711,7 @@
     return 2;
 
   if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
-      mlirAffineMapIsEmpty(affineMap) ||
-      mlirAffineMapIsEmpty(constAffineMap) ||
+      mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
       mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
       mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
       mlirAffineMapIsEmpty(permutationAffineMap))
@@ -859,6 +878,7 @@
   // CHECK: memref<2x3xf32, 2>
   // CHECK: memref<*xf32, 4>
   // CHECK: tuple<memref<*xf32, 4>, f32>
+  // CHECK: (index, i1) -> (i16, i32, i64)
   // CHECK: 0
   // clang-format on
   fprintf(stderr, "@types\n");