diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -45,6 +45,16 @@ MLIR_CAPI_EXPORTED MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr); +/// Returns `true` if the two affine expressions are equal. +MLIR_CAPI_EXPORTED bool mlirAffineExprEqual(MlirAffineExpr lhs, + MlirAffineExpr rhs); + +/// Returns `true` if the given affine expression is a null expression. Note +/// constant zero is not a null expression. +static inline bool mlirAffineExprIsNull(MlirAffineExpr affineExpr) { + return affineExpr.ptr == NULL; +} + /** Prints an affine expression by sending chunks of the string representation * and forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ @@ -82,6 +92,9 @@ // Affine Dimension Expression. //===----------------------------------------------------------------------===// +/// Checks whether the given affine expression is a dimension expression. +MLIR_CAPI_EXPORTED bool mlirAffineExprIsADim(MlirAffineExpr affineExpr); + /// Creates an affine dimension expression with 'position' in the context. MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position); @@ -94,6 +107,9 @@ // Affine Symbol Expression. //===----------------------------------------------------------------------===// +/// Checks whether the given affine expression is a symbol expression. +MLIR_CAPI_EXPORTED bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr); + /// Creates an affine symbol expression with 'position' in the context. MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position); @@ -106,6 +122,9 @@ // Affine Constant Expression. //===----------------------------------------------------------------------===// +/// Checks whether the given affine expression is a constant expression.
+MLIR_CAPI_EXPORTED bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr); + /// Creates an affine constant expression with 'constant' in the context. MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant); @@ -173,6 +192,9 @@ // Affine Binary Operation Expression. //===----------------------------------------------------------------------===// +/// Checks whether the given affine expression is a binary operation expression. +MLIR_CAPI_EXPORTED bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr); + /** Returns the left hand side affine expression of the given affine binary * operation expression. */ MLIR_CAPI_EXPORTED MlirAffineExpr diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -23,10 +23,12 @@ #include <Python.h> +#include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Pass.h" +#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" #define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" @@ -72,6 +74,25 @@ extern "C" { #endif +/** Creates a capsule object encapsulating the raw C-API MlirAffineExpr. The + * returned capsule does not extend or affect ownership of any Python objects + * that reference the expression in any way. + */ +static inline PyObject *mlirPythonAffineExprToCapsule(MlirAffineExpr expr) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(expr), + MLIR_PYTHON_CAPSULE_AFFINE_EXPR, NULL); +} + +/** Extracts an MlirAffineExpr from a capsule as produced from + * mlirPythonAffineExprToCapsule. If the capsule is not of the right type, then + * a null expression is returned (as checked via mlirAffineExprIsNull).
In such + * a case, the Python APIs will have already set an error. */ +static inline MlirAffineExpr mlirPythonCapsuleToAffineExpr(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_AFFINE_EXPR); + MlirAffineExpr expr = {ptr}; + return expr; +} + /** Creates a capsule object encapsulating the raw C-API MlirAttribute. * The returned capsule does not extend or affect ownership of any Python * objects that reference the attribute in any way. diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -13,6 +13,7 @@ #include "PybindUtils.h" +#include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "llvm/ADT/DenseMap.h" @@ -668,6 +669,28 @@ MlirValue value; }; +/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. +class PyAffineExpr : public BaseContextObject { +public: + PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) + : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} + bool operator==(const PyAffineExpr &other); + operator MlirAffineExpr() const { return affineExpr; } + MlirAffineExpr get() const { return affineExpr; } + + /// Gets a capsule wrapping the void* within the MlirAffineExpr. + pybind11::object getCapsule(); + + /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. + /// A new wrapper is returned on each call; the underlying expression + /// remains owned by its context, so no ownership is transferred. Throws + /// (with the Python error already set) if the capsule is of the wrong type.
+ static PyAffineExpr createFromCapsule(pybind11::object capsule); + +private: + MlirAffineExpr affineExpr; +}; + class PyAffineMap : public BaseContextObject { public: PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2710,6 +2710,238 @@ } // namespace +//------------------------------------------------------------------------------ +// PyAffineExpr and subclasses. +//------------------------------------------------------------------------------ + +namespace { +/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr +/// and should be castable from it. Intermediate hierarchy classes can be +/// modeled by specifying BaseTy. +template <typename DerivedTy, typename BaseTy = PyAffineExpr> +class PyConcreteAffineExpr : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived.
+ using ClassTy = py::class_<DerivedTy, BaseTy>; + using IsAFunctionTy = bool (*)(MlirAffineExpr); + + PyConcreteAffineExpr() = default; + PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) + : BaseTy(std::move(contextRef), affineExpr) {} + explicit PyConcreteAffineExpr(PyAffineExpr &orig) + : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} + + static MlirAffineExpr castFrom(PyAffineExpr &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); + throw SetPyError(PyExc_ValueError, + Twine("Cannot cast affine expression to ") + + DerivedTy::pyClassName + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(py::init<PyAffineExpr &>()); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; + static constexpr const char *pyClassName = "AffineConstantExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineConstantExpr get(int64_t value, + DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = + mlirAffineConstantExprGet(context->get(), value); + return PyAffineConstantExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), + py::arg("context") = py::none()); + c.def_property_readonly("value", [](PyAffineConstantExpr &self) { + return mlirAffineConstantExprGetValue(self); + }); + } +}; + +class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; + static constexpr const char *pyClassName = "AffineDimExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineDimExpr
get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); + return PyAffineDimExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineDimExpr &self) { + return mlirAffineDimExprGetPosition(self); + }); + } +}; + +class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; + static constexpr const char *pyClassName = "AffineSymbolExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); + return PyAffineSymbolExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { + return mlirAffineSymbolExprGetPosition(self); + }); + } +}; + +class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; + static constexpr const char *pyClassName = "AffineBinaryExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + PyAffineExpr lhs() { + MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); + return PyAffineExpr(getContext(), lhsExpr); + } + + PyAffineExpr rhs() { + MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); + return PyAffineExpr(getContext(), rhsExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); + c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); + } +}; + +class PyAffineAddExpr + : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> { +public: + static
constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; + static constexpr const char *pyClassName = "AffineAddExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineAddExpr::get); + } +}; + +class PyAffineMulExpr + : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; + static constexpr const char *pyClassName = "AffineMulExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineMulExpr::get); + } +}; + +class PyAffineModExpr + : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; + static constexpr const char *pyClassName = "AffineModExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineModExpr::get); + } +}; + +class PyAffineFloorDivExpr + : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; + static constexpr const char *pyClassName = "AffineFloorDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy
&c) { + c.def_static("get", &PyAffineFloorDivExpr::get); + } +}; + +class PyAffineCeilDivExpr + : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; + static constexpr const char *pyClassName = "AffineCeilDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineCeilDivExpr::get); + } +}; +} // namespace + +bool PyAffineExpr::operator==(const PyAffineExpr &other) { + return mlirAffineExprEqual(affineExpr, other.affineExpr); +} + +py::object PyAffineExpr::getCapsule() { + return py::reinterpret_steal<py::object>( + mlirPythonAffineExprToCapsule(*this)); +} + +PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { + MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); + if (mlirAffineExprIsNull(rawAffineExpr)) + throw py::error_already_set(); + return PyAffineExpr( + PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), + rawAffineExpr); +} + //------------------------------------------------------------------------------ // PyAffineMap. //------------------------------------------------------------------------------ @@ -3414,6 +3646,94 @@ PyRegionIterator::bind(m); PyRegionList::bind(m); + //---------------------------------------------------------------------------- + // Mapping of PyAffineExpr and derived classes.
+ //---------------------------------------------------------------------------- + py::class_<PyAffineExpr>(m, "AffineExpr") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineExpr::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) + .def("__add__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineAddExpr::get(self, other); + }) + .def("__mul__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineMulExpr::get(self, other); + }) + .def("__mod__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineModExpr::get(self, other); + }) + .def("__sub__", + [](PyAffineExpr &self, PyAffineExpr &other) { + auto negOne = + PyAffineConstantExpr::get(-1, *self.getContext().get()); + return PyAffineAddExpr::get(self, + PyAffineMulExpr::get(negOne, other)); + }) + .def("__eq__", [](PyAffineExpr &self, + PyAffineExpr &other) { return self == other; }) + .def("__eq__", + [](PyAffineExpr &self, py::object &other) { return false; }) + .def("__str__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineExpr("); + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def_static( + "get_add", &PyAffineAddExpr::get, + "Gets an affine expression containing a sum of two expressions.") + .def_static( + "get_mul", &PyAffineMulExpr::get, + "Gets an affine expression containing a product of two expressions.") + .def_static("get_mod", &PyAffineModExpr::get, + "Gets an affine expression containing the modulo of dividing " + "one expression by another.") + .def_static("get_floor_div",
&PyAffineFloorDivExpr::get, + "Gets an affine expression containing the rounded-down " + "result of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, + "Gets an affine expression containing the rounded-up result " + "of dividing one expression by another.") + .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), + py::arg("context") = py::none(), + "Gets a constant affine expression with the given value.") + .def_static( + "get_dim", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none(), + "Gets an affine expression of a dimension at the given position.") + .def_static( + "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none(), + "Gets an affine expression of a symbol at the given position.") + .def( + "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, + kDumpDocstring); + PyAffineConstantExpr::bind(m); + PyAffineDimExpr::bind(m); + PyAffineSymbolExpr::bind(m); + PyAffineBinaryExpr::bind(m); + PyAffineAddExpr::bind(m); + PyAffineMulExpr::bind(m); + PyAffineModExpr::bind(m); + PyAffineFloorDivExpr::bind(m); + PyAffineCeilDivExpr::bind(m); + //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -21,6 +21,10 @@ return wrap(unwrap(affineExpr).getContext()); } +bool mlirAffineExprEqual(MlirAffineExpr lhs, MlirAffineExpr rhs) { + return unwrap(lhs) == unwrap(rhs); +} + void mlirAffineExprPrint(MlirAffineExpr affineExpr, MlirStringCallback callback, void *userData) { mlir::detail::CallbackOstream stream(callback, userData); @@ -56,6 +60,10 @@ // Affine Dimension Expression.
//===----------------------------------------------------------------------===// +bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) { + return unwrap(affineExpr).isa<AffineDimExpr>(); +} + MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) { return wrap(getAffineDimExpr(position, unwrap(ctx))); } @@ -68,6 +76,10 @@ // Affine Symbol Expression. //===----------------------------------------------------------------------===// +bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) { + return unwrap(affineExpr).isa<AffineSymbolExpr>(); +} + MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) { return wrap(getAffineSymbolExpr(position, unwrap(ctx))); } @@ -80,6 +92,10 @@ // Affine Constant Expression. //===----------------------------------------------------------------------===// +bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) { + return unwrap(affineExpr).isa<AffineConstantExpr>(); +} + MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) { return wrap(getAffineConstantExpr(constant, unwrap(ctx))); } @@ -159,6 +175,10 @@ // Affine Binary Operation Expression.
//===----------------------------------------------------------------------===// +bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) { + return unwrap(affineExpr).isa<AffineBinaryOpExpr>(); +} + MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) { return wrap(unwrap(affineExpr).cast<AffineBinaryOpExpr>().getLHS()); } diff --git a/mlir/test/Bindings/Python/ir_affine_expr.py b/mlir/test/Bindings/Python/ir_affine_expr.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_affine_expr.py @@ -0,0 +1,276 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testAffineExprCapsule +def testAffineExprCapsule(): + with Context() as ctx: + affine_expr = AffineExpr.get_constant(42) + + affine_expr_capsule = affine_expr._CAPIPtr + # CHECK: capsule object + # CHECK: mlir.ir.AffineExpr._CAPIPtr + print(affine_expr_capsule) + + affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule) + assert affine_expr == affine_expr_2 + assert affine_expr_2.context == ctx + +run(testAffineExprCapsule) + + +# CHECK-LABEL: TEST: testAffineExprEq +def testAffineExprEq(): + with Context(): + a1 = AffineExpr.get_constant(42) + a2 = AffineExpr.get_constant(42) + a3 = AffineExpr.get_constant(43) + # CHECK: True + print(a1 == a1) + # CHECK: True + print(a1 == a2) + # CHECK: False + print(a1 == a3) + # CHECK: False + print(a1 == None) + # CHECK: False + print(a1 == "foo") + +run(testAffineExprEq) + + +# CHECK-LABEL: TEST: testAffineExprContext +def testAffineExprContext(): + with Context(): + a1 = AffineExpr.get_constant(42) + with Context(): + a2 = AffineExpr.get_constant(42) + + # CHECK: False + print(a1 == a2) + +run(testAffineExprContext) + + +# CHECK-LABEL: TEST: testAffineExprConstant +def testAffineExprConstant(): + with Context(): + a1 = AffineExpr.get_constant(42) + # CHECK: 42 + print(a1.value) + # CHECK: 42 + print(a1) + + a2 =
AffineConstantExpr.get(42) + # CHECK: 42 + print(a2.value) + # CHECK: 42 + print(a2) + + assert a1 == a2 + +run(testAffineExprConstant) + + +# CHECK-LABEL: TEST: testAffineExprDim +def testAffineExprDim(): + with Context(): + d1 = AffineExpr.get_dim(1) + d11 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + + # CHECK: 1 + print(d1.position) + # CHECK: d1 + print(d1) + + # CHECK: 2 + print(d2.position) + # CHECK: d2 + print(d2) + + assert d1 == d11 + assert d1 != d2 + +run(testAffineExprDim) + + +# CHECK-LABEL: TEST: testAffineExprSymbol +def testAffineExprSymbol(): + with Context(): + s1 = AffineExpr.get_symbol(1) + s11 = AffineSymbolExpr.get(1) + s2 = AffineSymbolExpr.get(2) + + # CHECK: 1 + print(s1.position) + # CHECK: s1 + print(s1) + + # CHECK: 2 + print(s2.position) + # CHECK: s2 + print(s2) + + assert s1 == s11 + assert s1 != s2 + +run(testAffineExprSymbol) + + +# CHECK-LABEL: TEST: testAffineAddExpr +def testAffineAddExpr(): + with Context(): + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + d12 = AffineExpr.get_add(d1, d2) + # CHECK: d1 + d2 + print(d12) + + d12op = d1 + d2 + # CHECK: d1 + d2 + print(d12op) + + assert d12 == d12op + assert d12.lhs == d1 + assert d12.rhs == d2 + +run(testAffineAddExpr) + + +# CHECK-LABEL: TEST: testAffineMulExpr +def testAffineMulExpr(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_mul(d1, c2) + # CHECK: d1 * 2 + print(expr) + + # CHECK: d1 * 2 + op = d1 * c2 + print(op) + + assert expr == op + assert expr.lhs == d1 + assert expr.rhs == c2 + +run(testAffineMulExpr) + + +# CHECK-LABEL: TEST: testAffineModExpr +def testAffineModExpr(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_mod(d1, c2) + # CHECK: d1 mod 2 + print(expr) + + # CHECK: d1 mod 2 + op = d1 % c2 + print(op) + + assert expr == op + assert expr.lhs == d1 + assert expr.rhs == c2 + +run(testAffineModExpr) + + +# CHECK-LABEL: TEST:
testAffineFloorDivExpr +def testAffineFloorDivExpr(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_floor_div(d1, c2) + # CHECK: d1 floordiv 2 + print(expr) + + assert expr.lhs == d1 + assert expr.rhs == c2 + +run(testAffineFloorDivExpr) + + +# CHECK-LABEL: TEST: testAffineCeilDivExpr +def testAffineCeilDivExpr(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + expr = AffineExpr.get_ceil_div(d1, c2) + # CHECK: d1 ceildiv 2 + print(expr) + + assert expr.lhs == d1 + assert expr.rhs == c2 + +run(testAffineCeilDivExpr) + + +# CHECK-LABEL: TEST: testAffineExprSub +def testAffineExprSub(): + with Context(): + d1 = AffineDimExpr.get(1) + d2 = AffineDimExpr.get(2) + expr = d1 - d2 + # CHECK: d1 - d2 + print(expr) + + assert expr.lhs == d1 + rhs = AffineMulExpr(expr.rhs) + # CHECK: d2 + print(rhs.lhs) + # CHECK: -1 + print(rhs.rhs) + +run(testAffineExprSub) + + +# CHECK-LABEL: TEST: testClassHierarchy +def testClassHierarchy(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + add = AffineAddExpr.get(d1, c2) + mul = AffineMulExpr.get(d1, c2) + mod = AffineModExpr.get(d1, c2) + floor_div = AffineFloorDivExpr.get(d1, c2) + ceil_div = AffineCeilDivExpr.get(d1, c2) + + # CHECK: False + print(isinstance(d1, AffineBinaryExpr)) + # CHECK: False + print(isinstance(c2, AffineBinaryExpr)) + # CHECK: True + print(isinstance(add, AffineBinaryExpr)) + # CHECK: True + print(isinstance(mul, AffineBinaryExpr)) + # CHECK: True + print(isinstance(mod, AffineBinaryExpr)) + # CHECK: True + print(isinstance(floor_div, AffineBinaryExpr)) + # CHECK: True + print(isinstance(ceil_div, AffineBinaryExpr)) + + try: + AffineBinaryExpr(d1) + except ValueError as e: + # CHECK: Cannot cast affine expression to AffineBinaryExpr + print(e) + + try: + AffineBinaryExpr(c2) + except ValueError as e: + # CHECK: Cannot cast affine expression to AffineBinaryExpr + print(e) + +run(testClassHierarchy) diff --git a/mlir/test/CAPI/ir.c
b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1251,6 +1251,27 @@ if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr)) return 13; + if (!mlirAffineExprIsABinary(affineAddExpr)) + return 14; + + // Test other 'IsA' methods on affine expressions. + if (!mlirAffineExprIsAConstant(affineConstantExpr)) + return 15; + + if (!mlirAffineExprIsADim(affineDimExpr)) + return 16; + + if (!mlirAffineExprIsASymbol(affineSymbolExpr)) + return 17; + + // Test equality and nullity. + MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5); + if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr)) + return 18; + + if (mlirAffineExprIsNull(affineDimExpr)) + return 19; + return 0; }