diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -39,6 +39,8 @@ #undef DEFINE_C_API_STRUCT +struct MlirAffineMap; + /// Gets the context that owns the affine expression. MLIR_CAPI_EXPORTED MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr); @@ -86,6 +88,10 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, intptr_t position); +/// Composes the expression with the map by substituting d_i with result i. +MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose( + MlirAffineExpr affineExpr, struct MlirAffineMap affineMap); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -205,6 +205,18 @@ return PyAffineAddExpr(lhs.getContext(), expr); } + static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineAddExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineAddExpr::get); } @@ -222,6 +234,18 @@ return PyAffineMulExpr(lhs.getContext(), expr); } + static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineMulExpr(lhs.getContext(), expr); +
} + + static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineMulExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineMulExpr::get); } @@ -239,6 +263,18 @@ return PyAffineModExpr(lhs.getContext(), expr); } + static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineModExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineModExpr::get); } @@ -256,6 +292,18 @@ return PyAffineFloorDivExpr(lhs.getContext(), expr); } + static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineFloorDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineFloorDivExpr::get); } @@ -273,6 +321,18 @@ return PyAffineCeilDivExpr(lhs.getContext(), expr); } + static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineCeilDivExpr(lhs.getContext(),
expr); + } + + static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineCeilDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineCeilDivExpr::get); } @@ -435,17 +495,19 @@ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); + .def("__add__", &PyAffineAddExpr::get) + .def("__add__", &PyAffineAddExpr::getRHSConstant) + .def("__radd__", &PyAffineAddExpr::getRHSConstant) + .def("__mul__", &PyAffineMulExpr::get) + .def("__mul__", &PyAffineMulExpr::getRHSConstant) + .def("__rmul__", &PyAffineMulExpr::getRHSConstant) + .def("__mod__", &PyAffineModExpr::get) + .def("__mod__", &PyAffineModExpr::getRHSConstant) + .def("__rmod__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineModExpr::get( + PyAffineConstantExpr::get(other, *self.getContext().get()), + self); }) .def("__sub__", [](PyAffineExpr &self, PyAffineExpr &other) { @@ -454,6 +516,17 @@ return PyAffineAddExpr::get(self, PyAffineMulExpr::get(negOne, other)); }) + .def("__sub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::get( + self, + PyAffineConstantExpr::get(-other, *self.getContext().get())); + }) + .def("__rsub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::getLHSConstant( + other, PyAffineMulExpr::getLHSConstant(-1, self)); + }) .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; })
.def("__eq__", @@ -474,24 +547,63 @@ printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineExpr &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def("compose", + [](PyAffineExpr &self, PyAffineMap &other) { + return PyAffineExpr(self.getContext(), + mlirAffineExprCompose(self, other)); + }) .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") + .def_static("get_add", &PyAffineAddExpr::getLHSConstant, + "Gets an affine expression containing a sum of a constant " + "and another expression.") + .def_static("get_add", &PyAffineAddExpr::getRHSConstant, + "Gets an affine expression containing a sum of an expression " + "and a constant.") .def_static( "get_mul", &PyAffineMulExpr::get, "Gets an affine expression containing a product of two expressions.") + .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, + "Gets an affine expression containing a product of a " + "constant and another expression.") + .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, + "Gets an affine expression containing a product of an " + "expression and a constant.") .def_static("get_mod", &PyAffineModExpr::get, "Gets an affine expression containing the modulo of dividing " "one expression by another.") + .def_static("get_mod", &PyAffineModExpr::getLHSConstant, + "Gets a semi-affine expression containing the modulo of " + "dividing a constant by an expression.") + .def_static("get_mod", &PyAffineModExpr::getRHSConstant, + "Gets an affine expression containing the modulo of dividing " + "an expression by a constant.") .def_static("get_floor_div", &PyAffineFloorDivExpr::get, "Gets an affine expression containing the rounded-down " "result of dividing one expression by another.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, + "Gets a semi-affine
expression containing the rounded-down " + "result of dividing a constant by an expression.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-down " + "result of dividing an expression by a constant.") .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, "Gets an affine expression containing the rounded-up result " "of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-up " + "result of dividing a constant by an expression.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-up result " + "of dividing an expression by a constant.") .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), py::arg("context") = py::none(), "Gets a constant affine expression with the given value.") @@ -542,6 +654,10 @@ printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineMap &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }) .def_static("compress_unused_symbols", [](py::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; @@ -714,6 +830,10 @@ printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyIntegerSet &self) { + return static_cast<size_t>(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -56,6 +56,11 @@ return unwrap(affineExpr).isFunctionOfDim(position); } +MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, + MlirAffineMap affineMap) { + return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); +} +
//===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1393,6 +1393,13 @@ if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr)) return 3; + MlirAffineExpr affineDim2Expr = mlirAffineDimExprGet(ctx, 1); + MlirAffineExpr composed = mlirAffineExprCompose(affineDim2Expr, map); + // CHECK: s1 + mlirAffineExprDump(composed); + if (!mlirAffineExprEqual(composed, affineSymbolExpr)) + return 4; + return 0; } diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -137,6 +137,14 @@ # CHECK: d1 + d2 print(d12op) + d1cst_op = d1 + 2 + # CHECK: d1 + 2 + print(d1cst_op) + + d1cst_op2 = 2 + d1 + # CHECK: d1 + 2 + print(d1cst_op2) + assert d12 == d12op assert d12.lhs == d1 assert d12.rhs == d2 @@ -156,7 +164,16 @@ op = d1 * c2 print(op) + # CHECK: d1 * 2 + op_cst = d1 * 2 + print(op_cst) + + # CHECK: d1 * 2 + op_cst2 = 2 * d1 + print(op_cst2) + assert expr == op + assert expr == op_cst assert expr.lhs == d1 assert expr.rhs == c2 @@ -175,10 +192,32 @@ op = d1 % c2 print(op) + # CHECK: d1 mod 2 + op_cst = d1 % 2 + print(op_cst) + + # CHECK: 2 mod d1 + print(2 % d1) + assert expr == op + assert expr == op_cst assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_mod(c2, d1) + expr3 = AffineExpr.get_mod(2, d1) + expr4 = AffineExpr.get_mod(d1, 2) + + # CHECK: 2 mod d1 + print(expr2) + # CHECK: 2 mod d1 + print(expr3) + # CHECK: d1 mod 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineFloorDivExpr @run @@ -193,6 +232,20 @@ assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_floor_div(c2, d1) + expr3 = AffineExpr.get_floor_div(2,
d1) + expr4 = AffineExpr.get_floor_div(d1, 2) + + # CHECK: 2 floordiv d1 + print(expr2) + # CHECK: 2 floordiv d1 + print(expr3) + # CHECK: d1 floordiv 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineCeilDivExpr @run @@ -207,6 +260,20 @@ assert expr.lhs == d1 assert expr.rhs == c2 + expr2 = AffineExpr.get_ceil_div(c2, d1) + expr3 = AffineExpr.get_ceil_div(2, d1) + expr4 = AffineExpr.get_ceil_div(d1, 2) + + # CHECK: 2 ceildiv d1 + print(expr2) + # CHECK: 2 ceildiv d1 + print(expr3) + # CHECK: d1 ceildiv 2 + print(expr4) + + assert expr2 == expr3 + assert expr4 == expr + # CHECK-LABEL: TEST: testAffineExprSub @run @@ -225,6 +292,15 @@ # CHECK: -1 print(rhs.rhs) + # CHECK: d1 - 42 + print(d1 - 42) + # CHECK: -d1 + 42 + print(42 - d1) + + c42 = AffineConstantExpr.get(42) + assert d1 - 42 == d1 - c42 + assert 42 - d1 == c42 - d1 + # CHECK-LABEL: TEST: testClassHierarchy @run def testClassHierarchy(): @@ -289,3 +365,38 @@ print(AffineMulExpr.isinstance(mul)) # CHECK: False print(AffineAddExpr.isinstance(mul)) + + +# CHECK-LABEL: TEST: testCompose +@run +def testCompose(): + with Context(): + # d0 + d2.
+ expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2)) + + # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2) + map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1)) + map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0)) + map3 = AffineAddExpr.get( + AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)), + AffineDimExpr.get(2)) + affine_map = AffineMap.get(3, 2, [map1, map2, map3]) + + # CHECK: d0 + s1 + d0 + d1 + d2 + print(expr.compose(affine_map)) + + +# CHECK-LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0 = AffineDimExpr.get(0) + s1 = AffineSymbolExpr.get(1) + assert hash(d0) == hash(AffineDimExpr.get(0)) + assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1)) + + dictionary = dict() + dictionary[d0] = 0 + dictionary[s1] = 1 + assert d0 in dictionary + assert s1 in dictionary diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py --- a/mlir/test/python/ir/affine_map.py +++ b/mlir/test/python/ir/affine_map.py @@ -9,9 +9,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testAffineMapCapsule +@run def testAffineMapCapsule(): with Context() as ctx: am1 = AffineMap.get_empty(ctx) @@ -23,10 +25,8 @@ assert am2.context is ctx -run(testAffineMapCapsule) - - # CHECK-LABEL: TEST: testAffineMapGet +@run def testAffineMapGet(): with Context() as ctx: d0 = AffineDimExpr.get(0) @@ -100,10 +100,8 @@ print(e) -run(testAffineMapGet) - - # CHECK-LABEL: TEST: testAffineMapDerive +@run def testAffineMapDerive(): with Context() as ctx: map5 = AffineMap.get_identity(5) @@ -121,10 +119,8 @@ print(map34) -run(testAffineMapDerive) - - # CHECK-LABEL: TEST: testAffineMapProperties +@run def testAffineMapProperties(): with Context(): d0 = AffineDimExpr.get(0) @@ -147,10 +143,8 @@ print(map3.is_projected_permutation) -run(testAffineMapProperties) - - # CHECK-LABEL: TEST: testAffineMapExprs +@run def testAffineMapExprs(): with Context(): d0 =
AffineDimExpr.get(0) @@ -181,10 +175,8 @@ assert list(map3.results) == [d2, d0, d1] -run(testAffineMapExprs) - - # CHECK-LABEL: TEST: testCompressUnusedSymbols +@run def testCompressUnusedSymbols(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), @@ -210,10 +202,8 @@ print(compressed_maps) -run(testCompressUnusedSymbols) - - # CHECK-LABEL: TEST: testReplace +@run def testReplace(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), @@ -236,4 +226,16 @@ print(replace3) -run(testReplace) +# CHECK-LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1) + m1 = AffineMap.get(2, 0, [d0, d1]) + m2 = AffineMap.get(2, 0, [d1, d0]) + assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1])) + + dictionary = dict() + dictionary[m1] = 1 + dictionary[m2] = 2 + assert m1 in dictionary diff --git a/mlir/test/python/ir/integer_set.py b/mlir/test/python/ir/integer_set.py --- a/mlir/test/python/ir/integer_set.py +++ b/mlir/test/python/ir/integer_set.py @@ -8,9 +8,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testIntegerSetCapsule +@run def testIntegerSetCapsule(): with Context() as ctx: is1 = IntegerSet.get_empty(1, 1, ctx) @@ -21,10 +23,9 @@ assert is1 == is2 assert is2.context is ctx -run(testIntegerSetCapsule) - # CHECK-LABEL: TEST: testIntegerSetGet +@run def testIntegerSetGet(): with Context(): d0 = AffineDimExpr.get(0) @@ -92,10 +93,9 @@ # CHECK: Invalid expression (None?)
when attempting to create an IntegerSet by replacing symbols print(e) -run(testIntegerSetGet) - # CHECK-LABEL: TEST: testIntegerSetProperties +@run def testIntegerSetProperties(): with Context(): d0 = AffineDimExpr.get(0) @@ -125,4 +125,17 @@ print(cstr.expr, end='') print(" == 0" if cstr.is_eq else " >= 0") -run(testIntegerSetProperties) + +# CHECK-LABEL: TEST: testHash +@run +def testHash(): + with Context(): + d0 = AffineDimExpr.get(0) + d1 = AffineDimExpr.get(1) + iset = IntegerSet.get(2, 0, [d0 + d1], [True]) + + assert hash(iset) == hash(IntegerSet.get(2, 0, [d0 + d1], [True])) + + dictionary = dict() + dictionary[iset] = 42 + assert iset in dictionary