diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -85,6 +85,9 @@
 /** Creates an MLIR context and transfers its ownership to the caller. */
 MlirContext mlirContextCreate();
 
+/** Checks if two contexts are equal. */
+int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
+
 /** Takes an MLIR context owned by the caller and destroys it. */
 void mlirContextDestroy(MlirContext context);
 
@@ -315,6 +318,9 @@
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Gets the context that a type was created with. */
+MlirContext mlirTypeGetContext(MlirType type);
+
 /** Checks whether a type is null. */
 inline int mlirTypeIsNull(MlirType type) { return !type.ptr; }
 
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -28,6 +28,13 @@
   MlirContext context;
 };
 
+/// Wrapper around an MlirLocation.
+class PyLocation {
+public:
+  PyLocation(MlirLocation loc) : loc(loc) {}
+  MlirLocation loc;
+};
+
 /// Wrapper around MlirModule.
 class PyModule {
 public:
@@ -45,6 +52,72 @@
   MlirModule module;
 };
 
+/// Wrapper around an MlirRegion.
+/// Note that a region can exist in a detached state (where this instance is
+/// responsible for destroying it) or an attached state (where its owner is
+/// responsible).
+///
+/// This python wrapper retains a redundant reference to its creating context
+/// in order to facilitate checking that parts of the operation hierarchy
+/// are only assembled from the same context.
+class PyRegion {
+public:
+  PyRegion(MlirContext context, MlirRegion region, bool detached)
+      : context(context), region(region), detached(detached) {}
+  PyRegion(PyRegion &&other) noexcept
+      : context(other.context), region(other.region), detached(other.detached) {
+    other.detached = false;
+  }
+  ~PyRegion() {
+    if (detached)
+      mlirRegionDestroy(region);
+  }
+
+  // Call prior to attaching the region to a parent.
+  // This will transition to the attached state and will throw an exception
+  // if already attached.
+  void attachToParent();
+
+  MlirContext context;
+  MlirRegion region;
+
+private:
+  bool detached;
+};
+
+/// Wrapper around an MlirBlock.
+/// Note that a block can exist in a detached state (where this instance is
+/// responsible for destroying it) or an attached state (where its owner is
+/// responsible).
+///
+/// This python wrapper retains a redundant reference to its creating context
+/// in order to facilitate checking that parts of the operation hierarchy
+/// are only assembled from the same context.
+class PyBlock {
+public:
+  PyBlock(MlirContext context, MlirBlock block, bool detached)
+      : context(context), block(block), detached(detached) {}
+  PyBlock(PyBlock &&other) noexcept
+      : context(other.context), block(other.block), detached(other.detached) {
+    other.detached = false;
+  }
+  ~PyBlock() {
+    if (detached)
+      mlirBlockDestroy(block);
+  }
+
+  // Call prior to attaching the block to a parent.
+  // This will transition to the attached state and will throw an exception
+  // if already attached.
+  void attachToParent();
+
+  MlirContext context;
+  MlirBlock block;
+
+private:
+  bool detached;
+};
+
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
class PyAttribute { @@ -84,6 +157,7 @@ public: PyType(MlirType type) : type(type) {} bool operator==(const PyType &other); + operator MlirType() const { return type; } MlirType type; }; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -28,13 +28,59 @@ See also: https://mlir.llvm.org/docs/LangRef/ )"; -static const char kContextParseType[] = R"(Parses the assembly form of a type. +static const char kContextParseTypeDocstring[] = + R"(Parses the assembly form of a type. Returns a Type object or raises a ValueError if the type cannot be parsed. See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; +static const char kContextGetUnknownLocationDocstring[] = + R"(Gets a Location representing an unknown location)"; + +static const char kContextGetFileLocationDocstring[] = + R"(Gets a Location representing a file, line and column)"; + +static const char kContextCreateBlockDocstring[] = + R"(Creates a detached block)"; + +static const char kContextCreateRegionDocstring[] = + R"(Creates a detached region)"; + +static const char kRegionAppendBlockDocstring[] = + R"(Appends a block to a region. + +Raises: + ValueError: If the block is already attached to another region. +)"; + +static const char kRegionInsertBlockDocstring[] = + R"(Inserts a block at a postiion in a region. + +Raises: + ValueError: If the block is already attached to another region. +)"; + +static const char kRegionFirstBlockDocstring[] = + R"(Gets the first block in a region. + +Blocks can also be accessed via the `blocks` container. + +Raises: + IndexError: If the region has no blocks. +)"; + +static const char kBlockNextInRegionDocstring[] = + R"(Gets the next block in the enclosing region. + +Blocks can also be accessed via the `blocks` container of the owning region. +This method exists to mirror the lower level API and should not be preferred. 
+ +Raises: + IndexError: If there are no further blocks. +)"; + static const char kOperationStrDunderDocstring[] = R"(Prints the assembly form of the operation with default options. @@ -106,6 +152,24 @@ } // namespace +//------------------------------------------------------------------------------ +// PyBlock, PyRegion, and PyOperation. +//------------------------------------------------------------------------------ + +void PyRegion::attachToParent() { + if (!detached) { + throw SetPyError(PyExc_ValueError, "Region is already attached to an op"); + } + detached = false; +} + +void PyBlock::attachToParent() { + if (!detached) { + throw SetPyError(PyExc_ValueError, "Block is already attached to an op"); + } + detached = false; +} + //------------------------------------------------------------------------------ // PyAttribute. //------------------------------------------------------------------------------ @@ -454,7 +518,59 @@ } return PyType(type); }, - py::keep_alive<0, 1>(), kContextParseType); + py::keep_alive<0, 1>(), kContextParseTypeDocstring) + .def( + "get_unknown_location", + [](PyMlirContext &self) { + return PyLocation(mlirLocationUnknownGet(self.context)); + }, + py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring) + .def( + "get_file_location", + [](PyMlirContext &self, std::string filename, int line, int col) { + return PyLocation(mlirLocationFileLineColGet( + self.context, filename.c_str(), line, col)); + }, + py::keep_alive<0, 1>(), kContextGetFileLocationDocstring, + py::arg("filename"), py::arg("line"), py::arg("col")) + .def( + "create_region", + [](PyMlirContext &self) { + // The creating context is explicitly captured on regions to + // facilitate illegal assemblies of objects from multiple contexts + // that would invalidate the memory model. 
+ return PyRegion(self.context, mlirRegionCreate(), + /*detached=*/true); + }, + py::keep_alive<0, 1>(), kContextCreateRegionDocstring) + .def( + "create_block", + [](PyMlirContext &self, std::vector pyTypes) { + // In order for the keep_alive extend the proper lifetime, all + // types must be from the same context. + for (auto pyType : pyTypes) { + if (!mlirContextEqual(mlirTypeGetContext(pyType.type), + self.context)) { + throw SetPyError( + PyExc_ValueError, + "All types used to construct a block must be from " + "the same context as the block"); + } + } + llvm::SmallVector types(pyTypes.begin(), + pyTypes.end()); + return PyBlock(self.context, + mlirBlockCreate(types.size(), &types[0]), + /*detached=*/true); + }, + py::keep_alive<0, 1>(), kContextCreateBlockDocstring); + + py::class_(m, "Location").def("__repr__", [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self.loc, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }); // Mapping of Module py::class_(m, "Module") @@ -475,6 +591,70 @@ }, kOperationStrDunderDocstring); + // Mapping of PyRegion. + py::class_(m, "Region") + .def( + "append_block", + [](PyRegion &self, PyBlock &block) { + if (!mlirContextEqual(self.context, block.context)) { + throw SetPyError( + PyExc_ValueError, + "Block must have been created from the same context as " + "this region"); + } + + block.attachToParent(); + mlirRegionAppendOwnedBlock(self.region, block.block); + }, + kRegionAppendBlockDocstring) + .def( + "insert_block", + [](PyRegion &self, int pos, PyBlock &block) { + if (!mlirContextEqual(self.context, block.context)) { + throw SetPyError( + PyExc_ValueError, + "Block must have been created from the same context as " + "this region"); + } + block.attachToParent(); + // TODO: Make this return a failure and raise if out of bounds. 
+ mlirRegionInsertOwnedBlock(self.region, pos, block.block); + }, + kRegionInsertBlockDocstring) + .def_property_readonly( + "first_block", + [](PyRegion &self) { + MlirBlock block = mlirRegionGetFirstBlock(self.region); + if (mlirBlockIsNull(block)) { + throw SetPyError(PyExc_IndexError, "Region has no blocks"); + } + return PyBlock(self.context, block, /*detached=*/false); + }, + kRegionFirstBlockDocstring); + + // Mapping of PyBlock. + py::class_(m, "Block") + .def_property_readonly( + "next_in_region", + [](PyBlock &self) { + MlirBlock block = mlirBlockGetNextInRegion(self.block); + if (mlirBlockIsNull(block)) { + throw SetPyError(PyExc_IndexError, + "Attempt to read past last block"); + } + return PyBlock(self.context, block, /*detached=*/false); + }, + py::keep_alive<0, 1>(), kBlockNextInRegionDocstring) + .def( + "__str__", + [](PyBlock &self) { + PyPrintAccumulator printAccum; + mlirBlockPrint(self.block, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + kTypeStrDunderDocstring); + // Mapping of Type. 
py::class_(m, "Attribute") .def( diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -10,6 +10,7 @@ #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #include +#include #include "llvm/ADT/Twine.h" diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -55,6 +55,10 @@ return wrap(context); } +int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { + return unwrap(ctx1) == unwrap(ctx2); +} + void mlirContextDestroy(MlirContext context) { delete unwrap(context); } /* ========================================================================== */ @@ -350,6 +354,10 @@ return wrap(mlir::parseType(type, unwrap(context))); } +MlirContext mlirTypeGetContext(MlirType type) { + return wrap(unwrap(type).getContext()); +} + int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_location.py @@ -0,0 +1,31 @@ +# RUN: %PYTHON %s | FileCheck %s + +import mlir + +def run(f): + print("\nTEST:", f.__name__) + f() + +# CHECK-LABEL: TEST: testUnknown +def testUnknown(): + ctx = mlir.ir.Context() + loc = ctx.get_unknown_location() + # CHECK: unknown str: loc(unknown) + print("unknown str:", str(loc)) + # CHECK: unknown repr: loc(unknown) + print("unknown repr:", repr(loc)) + +run(testUnknown) + + +# CHECK-LABEL: TEST: testFileLineCol +def testFileLineCol(): + ctx = mlir.ir.Context() + loc = ctx.get_file_location("foo.txt", 123, 56) + # CHECK: file str: loc("foo.txt":123:56) + print("file str:", str(loc)) + # CHECK: file repr: loc("foo.txt":123:56) + print("file repr:", repr(loc)) + +run(testFileLineCol) + diff --git 
a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -0,0 +1,71 @@ +# RUN: %PYTHON %s | FileCheck %s + +import mlir + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testDetachedRegionBlock +def testDetachedRegionBlock(): + ctx = mlir.ir.Context() + t = mlir.ir.F32Type(ctx) + region = ctx.create_region() + block = ctx.create_block([t, t]) + # CHECK: <> + print(block) + +run(testDetachedRegionBlock) + + +# CHECK-LABEL: TEST: testBlockTypeContextMismatch +def testBlockTypeContextMismatch(): + c1 = mlir.ir.Context() + c2 = mlir.ir.Context() + t1 = mlir.ir.F32Type(c1) + t2 = mlir.ir.F32Type(c2) + try: + block = c1.create_block([t1, t2]) + except ValueError as e: + # CHECK: ERROR: All types used to construct a block must be from the same context as the block + print("ERROR:", e) + +run(testBlockTypeContextMismatch) + + +# CHECK-LABEL: TEST: testBlockAppend +def testBlockAppend(): + ctx = mlir.ir.Context() + t = mlir.ir.F32Type(ctx) + region = ctx.create_region() + try: + region.first_block + except IndexError: + pass + else: + raise RuntimeError("Expected exception not raised") + + block = ctx.create_block([t, t]) + region.append_block(block) + try: + region.append_block(block) + except ValueError: + pass + else: + raise RuntimeError("Expected exception not raised") + + block2 = ctx.create_block([t]) + region.insert_block(1, block2) + # CHECK: <> + block_first = region.first_block + print(block_first) + block_next = block_first.next_in_region + try: + block_next = block_next.next_in_region + except IndexError: + pass + else: + raise RuntimeError("Expected exception not raised") + +run(testBlockAppend)