diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -10,7 +10,6 @@ #ifndef MLIR_C_AFFINEEXPR_H #define MLIR_C_AFFINEEXPR_H -#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ -10,6 +10,7 @@ #ifndef MLIR_C_AFFINEMAP_H #define MLIR_C_AFFINEMAP_H +#include "mlir-c/AffineExpr.h" #include "mlir-c/IR.h" #ifdef __cplusplus @@ -67,9 +68,18 @@ /** Creates a zero result affine map of the given dimensions and symbols in the * context. The affine map is owned by the context. */ +MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapZeroResultGet( + MlirContext ctx, intptr_t dimCount, intptr_t symbolCount); + +/** Creates an affine map with results defined by the given list of affine + * expressions. The resulting map also has the requested number of input + * dimensions and symbols, regardless of whether they are used in the results. + */ MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount, - intptr_t symbolCount); + intptr_t symbolCount, + intptr_t nAffineExprs, + MlirAffineExpr *affineExprs); /** Creates a single constant result affine map in the context. The affine map * is owned by the context. */ @@ -124,6 +134,10 @@ /// Returns the number of results of the given affine map. MLIR_CAPI_EXPORTED intptr_t mlirAffineMapGetNumResults(MlirAffineMap affineMap); +/// Returns the result at position `pos`, which must be in [0, num results). +MLIR_CAPI_EXPORTED MlirAffineExpr +mlirAffineMapGetResult(MlirAffineMap affineMap, intptr_t pos); + /** Returns the number of inputs (dimensions + symbols) of the given affine * map.
*/ MLIR_CAPI_EXPORTED intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap); diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -11,6 +11,7 @@ #include "Globals.h" #include "PybindUtils.h" +#include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" @@ -2943,9 +2944,43 @@ } //------------------------------------------------------------------------------ -// PyAffineMap. +// PyAffineMap and utilities. //------------------------------------------------------------------------------ +namespace { +/// A list of expressions contained in an affine map. Internally these are +/// stored as a consecutive array leading to inexpensive random access. Both +/// the map and the expressions are owned by the context so we need not bother +/// with lifetime extension. +class PyAffineMapExprList + : public Sliceable<PyAffineMapExprList, PyAffineExpr> { +public: + static constexpr const char *pyClassName = "AffineExprList"; + + PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ?
mlirAffineMapGetNumResults(map) : length, + step), + affineMap(map) {} + + intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } + + PyAffineExpr getElement(intptr_t pos) { + return PyAffineExpr(affineMap.getContext(), + mlirAffineMapGetResult(affineMap, pos)); + } + + PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyAffineMapExprList(affineMap, startIndex, length, step); + } + +private: + PyAffineMap affineMap; +}; +} // end namespace + bool PyAffineMap::operator==(const PyAffineMap &other) { return mlirAffineMapEqual(affineMap, other.affineMap); } @@ -3741,6 +3776,72 @@ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) + .def("__eq__", + [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) + .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def("__str__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineMap("); + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyAffineMap &self) { return self.getContext().getObject(); }, + "Context that owns the Affine Map") + .def( + "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, + DefaultingPyMlirContext context) { + SmallVector<MlirAffineExpr, 4> affineExprs; + affineExprs.reserve(py::len(exprs)); + for (py::handle expr : exprs) { + try { + affineExprs.push_back(expr.cast<PyAffineExpr>()); + } catch (py::cast_error &err) { + std::string msg = + std::string("Invalid
expression when attempting to create " + "an AffineMap (") + + err.what() + ")"; + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { + std::string msg = + std::string("Invalid expression (None?) when attempting to " + "create an AffineMap (") + + err.what() + ")"; + throw py::cast_error(msg); + } + } + MlirAffineMap map = + mlirAffineMapGet(context->get(), dimCount, symbolCount, + affineExprs.size(), affineExprs.data()); + return PyAffineMap(context->getRef(), map); + }, + py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), + py::arg("context") = py::none(), + "Gets a map with the given expressions as results.") + .def_static( + "get_constant", + [](intptr_t value, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapConstantGet(context->get(), value); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an affine map with a single constant result") .def_static( "get_empty", [](DefaultingPyMlirContext context) { @@ -3748,14 +3849,84 @@ return PyAffineMap(context->getRef(), affineMap); }, py::arg("context") = py::none(), "Gets an empty affine map.") + .def_static( + "get_identity", + [](intptr_t nDims, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMultiDimIdentityGet(context->get(), nDims); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("n_dims"), py::arg("context") = py::none(), + "Gets an identity map with the given number of dimensions.") + .def_static( + "get_minor_identity", + [](intptr_t nDims, intptr_t nResults, + DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("n_dims"), py::arg("n_results"), + py::arg("context") = py::none(), + "Gets a minor identity map with the given number of dimensions and " + "results.") + .def_static( +
"get_permutation", + [](std::vector<unsigned> permutation, + DefaultingPyMlirContext context) { + MlirAffineMap affineMap = mlirAffineMapPermutationGet( + context->get(), permutation.size(), permutation.data()); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("permutation"), py::arg("context") = py::none(), + "Gets an affine map that permutes its inputs.") + .def("get_submap", + [](PyAffineMap &self, std::vector<intptr_t> &resultPos) { + intptr_t numResults = mlirAffineMapGetNumResults(self); + for (intptr_t pos : resultPos) { + if (pos < 0 || pos >= numResults) + throw py::value_error("result position out of bounds"); + } + MlirAffineMap affineMap = mlirAffineMapGetSubMap( + self, resultPos.size(), resultPos.data()); + return PyAffineMap(self.getContext(), affineMap); + }) + .def("get_major_submap", + [](PyAffineMap &self, intptr_t nResults) { + // n == #results is the whole map; n <= 0 gives no usable submap. + if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMajorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }) + .def("get_minor_submap", + [](PyAffineMap &self, intptr_t nResults) { + // n == #results is the whole map; n <= 0 gives no usable submap. + if (nResults <= 0 || nResults > mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMinorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }) .def_property_readonly( - "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, - "Context that owns the Affine Map") - .def("__eq__", - [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) - .def( - "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, - kDumpDocstring); + "is_permutation", + [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) + .def_property_readonly("is_projected_permutation", + [](PyAffineMap &self) { + return
mlirAffineMapIsProjectedPermutation(self); + }) + .def_property_readonly( + "n_dims", + [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) + .def_property_readonly( + "n_inputs", + [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) + .def_property_readonly( + "n_symbols", + [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) + .def_property_readonly("results", [](PyAffineMap &self) { + return PyAffineMapExprList(self); + }); + PyAffineMapExprList::bind(m); } diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -8,6 +8,7 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir/CAPI/AffineExpr.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Utils.h" @@ -37,11 +38,19 @@ return wrap(AffineMap::get(unwrap(ctx))); } -MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount, - intptr_t symbolCount) { +MlirAffineMap mlirAffineMapZeroResultGet(MlirContext ctx, intptr_t dimCount, + intptr_t symbolCount) { return wrap(AffineMap::get(dimCount, symbolCount, unwrap(ctx))); } +MlirAffineMap mlirAffineMapGet(MlirContext ctx, intptr_t dimCount, + intptr_t symbolCount, intptr_t nAffineExprs, + MlirAffineExpr *affineExprs) { + SmallVector<AffineExpr, 4> exprs; + ArrayRef<AffineExpr> exprList = unwrapList(nAffineExprs, affineExprs, exprs); + return wrap(AffineMap::get(dimCount, symbolCount, exprList, unwrap(ctx))); +} + MlirAffineMap mlirAffineMapConstantGet(MlirContext ctx, int64_t val) { return wrap(AffineMap::getConstantMap(val, unwrap(ctx))); } @@ -94,6 +103,10 @@ return unwrap(affineMap).getNumResults(); } +MlirAffineExpr mlirAffineMapGetResult(MlirAffineMap affineMap, intptr_t pos) { + return wrap(unwrap(affineMap).getResult(static_cast<unsigned>(pos))); +} + intptr_t mlirAffineMapGetNumInputs(MlirAffineMap affineMap) { return unwrap(affineMap).getNumInputs(); } diff --git
a/mlir/test/Bindings/Python/ir_affine_map.py b/mlir/test/Bindings/Python/ir_affine_map.py --- a/mlir/test/Bindings/Python/ir_affine_map.py +++ b/mlir/test/Bindings/Python/ir_affine_map.py @@ -22,3 +22,151 @@ assert am2.context is ctx run(testAffineMapCapsule) + + +# CHECK-LABEL: TEST: testAffineMapGet +def testAffineMapGet(): + with Context() as ctx: + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + + # CHECK: (d0, d1)[s0, s1, s2] -> () + map0 = AffineMap.get(2, 3, []) + print(map0) + + # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2) + map1 = AffineMap.get(2, 3, [d1, c2]) + print(map1) + + # CHECK: () -> (2) + map2 = AffineMap.get(0, 0, [c2]) + print(map2) + + # CHECK: (d0, d1) -> (d0, d1) + map3 = AffineMap.get(2, 0, [d0, d1]) + print(map3) + + # CHECK: (d0, d1) -> (d1) + map4 = AffineMap.get(2, 0, [d1]) + print(map4) + + # CHECK: (d0, d1, d2) -> (d2, d0, d1) + map5 = AffineMap.get_permutation([2, 0, 1]) + print(map5) + + assert map1 == AffineMap.get(2, 3, [d1, c2]) + assert AffineMap.get(0, 0, []) == AffineMap.get_empty() + assert map2 == AffineMap.get_constant(2) + assert map3 == AffineMap.get_identity(2) + assert map4 == AffineMap.get_minor_identity(2, 1) + + try: + AffineMap.get(1, 1, [1])  # An int is not an AffineExpr. + except RuntimeError as e: + # CHECK: Invalid expression when attempting to create an AffineMap + print(e) + + try: + AffineMap.get(1, 1, [None])  # None gets its own error message. + except RuntimeError as e: + # CHECK: Invalid expression (None?)
when attempting to create an AffineMap + print(e) + + try: + map3.get_submap([42])  # map3 has only 2 results. + except ValueError as e: + # CHECK: result position out of bounds + print(e) + + try: + map3.get_minor_submap(42) + except ValueError as e: + # CHECK: number of results out of bounds + print(e) + + try: + map3.get_major_submap(42) + except ValueError as e: + # CHECK: number of results out of bounds + print(e) + +run(testAffineMapGet) + + +# CHECK-LABEL: TEST: testAffineMapDerive +def testAffineMapDerive(): + with Context() as ctx: + map5 = AffineMap.get_identity(5) + + # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3) + map123 = map5.get_submap([1,2,3]) + print(map123) + + # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1) + map01 = map5.get_major_submap(2) + print(map01) + + # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4) + map34 = map5.get_minor_submap(2) + print(map34) + +run(testAffineMapDerive) + + +# CHECK-LABEL: TEST: testAffineMapProperties +def testAffineMapProperties(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + map1 = AffineMap.get(3, 0, [d2, d0]) + map2 = AffineMap.get(3, 0, [d2, d0, d1]) + map3 = AffineMap.get(3, 1, [d2, d0, d1]) + # CHECK: False + print(map1.is_permutation) + # CHECK: True + print(map1.is_projected_permutation) + # CHECK: True + print(map2.is_permutation) + # CHECK: True + print(map2.is_projected_permutation) + # CHECK: False + print(map3.is_permutation) + # CHECK: False + print(map3.is_projected_permutation) + +run(testAffineMapProperties) + + +# CHECK-LABEL: TEST: testAffineMapExprs +def testAffineMapExprs(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + map3 = AffineMap.get(3, 1, [d2, d0, d1]) + + # CHECK: 3 + print(map3.n_dims) + # CHECK: 4 + print(map3.n_inputs) + # CHECK: 1 + print(map3.n_symbols) + assert map3.n_inputs == map3.n_dims + map3.n_symbols + + # CHECK: 3 + print(len(map3.results)) + for expr in map3.results: + # CHECK: d2 + # CHECK: d0
+ # CHECK: d1 + print(expr) + for expr in map3.results[-1:-4:-1]:  # All three results, reversed. + # CHECK: d1 + # CHECK: d0 + # CHECK: d2 + print(expr) + assert list(map3.results) == [d2, d0, d1] + +run(testAffineMapExprs) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1007,7 +1007,7 @@ int printAffineMap(MlirContext ctx) { MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx); - MlirAffineMap affineMap = mlirAffineMapGet(ctx, 3, 2); + MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2); MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2); MlirAffineMap multiDimIdentityAffineMap = mlirAffineMapMultiDimIdentityGet(ctx, 3); @@ -1275,6 +1275,29 @@ return 0; } +int affineMapFromExprs(MlirContext ctx) { + MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0); + MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1); + MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr}; + MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs); + + // CHECK-LABEL: @affineMapFromExprs + fprintf(stderr, "@affineMapFromExprs"); + // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1) + mlirAffineMapDump(map); + + if (mlirAffineMapGetNumResults(map) != 2) + return 1; + + if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr)) + return 2; + + if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr)) + return 3; + + return 0; +} + int registerOnlyStd() { MlirContext ctx = mlirContextCreate(); // The built-in dialect is always loaded. @@ -1375,8 +1398,10 @@ return 4; if (printAffineExpr(ctx)) return 5; - if (registerOnlyStd()) + if (affineMapFromExprs(ctx)) return 6; + if (registerOnlyStd()) + return 7; mlirContextDestroy(ctx);