diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -38,9 +38,11 @@
 /// Wrapper around MlirModule.
 class PyModule {
 public:
-  PyModule(MlirModule module) : module(module) {}
+  PyModule(MlirContext context, MlirModule module)
+      : context(context), module(module) {}
   PyModule(PyModule &) = delete;
   PyModule(PyModule &&other) {
+    context = other.context;
     module = other.module;
     other.module.ptr = nullptr;
   }
@@ -49,6 +51,7 @@
     mlirModuleDestroy(module);
   }
 
+  MlirContext context;
   MlirModule module;
 };
 
@@ -64,6 +67,7 @@
 public:
   PyRegion(MlirContext context, MlirRegion region, bool detached)
       : context(context), region(region), detached(detached) {}
+  PyRegion(const PyRegion &) = delete;
   PyRegion(PyRegion &&other)
       : context(other.context), region(other.region), detached(other.detached) {
     other.detached = false;
@@ -76,7 +80,9 @@
   // Call prior to attaching the region to a parent.
   // This will transition to the attached state and will throw an exception
   // if already attached.
-  void attachToParent();
+  // If checkOnly is true, the check is performed but the detached state is not
+  // modified.
+  void attachToParent(bool checkOnly = false);
 
   MlirContext context;
   MlirRegion region;
@@ -97,6 +103,7 @@
 public:
   PyBlock(MlirContext context, MlirBlock block, bool detached)
       : context(context), block(block), detached(detached) {}
+  PyBlock(const PyBlock &) = delete;
   PyBlock(PyBlock &&other)
       : context(other.context), block(other.block), detached(other.detached) {
     other.detached = false;
@@ -109,7 +116,9 @@
   // Call prior to attaching the block to a parent.
   // This will transition to the attached state and will throw an exception
   // if already attached.
-  void attachToParent();
+  // If checkOnly is true, the check is performed but the detached state is not
+  // modified.
+  void attachToParent(bool checkOnly = false);
 
   MlirContext context;
   MlirBlock block;
@@ -118,6 +127,43 @@
   bool detached;
 };
 
+/// Wrapper around an MlirOperation.
+/// Note that operations can exist in a detached state (where this instance is
+/// responsible for destroying it) or an attached state (where its owner is
+/// responsible).
+///
+/// This python wrapper retains a redundant reference to its creating context
+/// in order to facilitate checking that parts of the operation hierarchy
+/// are only assembled from the same context.
+class PyOperation {
+public:
+  PyOperation(MlirContext context, MlirOperation operation, bool detached)
+      : context(context), operation(operation), detached(detached) {}
+  PyOperation(const PyOperation &) = delete;
+  PyOperation(PyOperation &&other)
+      : context(other.context), operation(other.operation),
+        detached(other.detached) {
+    other.detached = false;
+  }
+  ~PyOperation() {
+    if (detached)
+      mlirOperationDestroy(operation);
+  }
+
+  // Call prior to attaching the operation to a parent.
+  // This will transition to the attached state and will throw an exception
+  // if already attached.
+  // If checkOnly is true, the check is performed but the detached state is not
+  // modified.
+  void attachToParent(bool checkOnly = false);
+
+  MlirContext context;
+  MlirOperation operation;
+
+private:
+  bool detached;
+};
+
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
 class PyAttribute {
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -52,6 +52,18 @@
 static const char kContextCreateRegionDocstring[] =
     R"(Creates a detached region)";
 
+static const char kContextCreateOperationDocstring[] =
+    R"(Creates a detached operation.
+
+Args:
+  name: An operation name (as "dialect.operation").
+  location: The location to associate with the operation.
+  results: Optional list of result types the operation returns.
+  regions: Optional regions to add to the operation.
+  successors: Optional references to successor blocks.
+  attributes: Optional list of named attributes (as from Attribute.get_named()).
+)";
+
 static const char kRegionAppendBlockDocstring[] =
     R"(Appends a block to a region.
 
@@ -174,18 +186,29 @@
 // PyBlock, PyRegion, and PyOperation.
 //------------------------------------------------------------------------------
 
-void PyRegion::attachToParent() {
+void PyRegion::attachToParent(bool checkOnly) {
   if (!detached) {
     throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
   }
-  detached = false;
+  if (!checkOnly)
+    detached = false;
 }
 
-void PyBlock::attachToParent() {
+void PyBlock::attachToParent(bool checkOnly) {
   if (!detached) {
     throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
   }
-  detached = false;
+  if (!checkOnly)
+    detached = false;
+}
+
+void PyOperation::attachToParent(bool checkOnly) {
+  if (!detached) {
+    throw SetPyError(PyExc_ValueError,
+                     "Operation is already attached to a block");
+  }
+  if (!checkOnly)
+    detached = false;
 }
 
 //------------------------------------------------------------------------------
@@ -828,6 +851,12 @@
   // Mapping of MlirContext
   py::class_<PyMlirContext>(m, "Context")
       .def(py::init<>())
+      .def(
+          "create_module",
+          [](PyMlirContext &self, PyLocation location) {
+            return PyModule(self.context, mlirModuleCreateEmpty(location.loc));
+          },
+          py::keep_alive<0, 1>(), py::arg("location"))
       .def(
           "parse_module",
           [](PyMlirContext &self, const std::string module) {
@@ -840,7 +869,7 @@
                   PyExc_ValueError,
                   "Unable to parse module assembly (see diagnostics)");
             }
-            return PyModule(moduleRef);
+            return PyModule(self.context, moduleRef);
           },
           py::keep_alive<0, 1>(), kContextParseDocstring)
       .def(
@@ -913,10 +942,110 @@
             llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
                                                  pyTypes.end());
             return PyBlock(self.context,
-                           mlirBlockCreate(types.size(), &types[0]),
+                           mlirBlockCreate(types.size(), types.data()),
                            /*detached=*/true);
           },
-          py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
+          py::keep_alive<0, 1>(), kContextCreateBlockDocstring)
+      .def(
+          "create_operation",
+          // TODO: Add operands once Value is mapped and usable.
+          [](PyMlirContext &self, std::string name, PyLocation location,
+             llvm::Optional<std::vector<PyType>> results,
+             llvm::Optional<std::vector<PyRegion *>> regions,
+             llvm::Optional<std::vector<PyBlock *>> successors,
+             llvm::Optional<std::vector<PyNamedAttribute *>> attributes) {
+            // Validate every argument before building the operation state so
+            // that a rejected argument cannot abandon a partially populated
+            // state or a partially transferred set of regions.
+            // pybind11 converts a Python None list element to nullptr for
+            // pointer element types; reject it with a Python error (an assert
+            // would be compiled out of release builds).
+            auto rejectNone = [](const auto &items, const char *message) {
+              for (const auto *item : items) {
+                if (!item)
+                  throw SetPyError(PyExc_ValueError, message);
+              }
+            };
+            if (regions) {
+              rejectNone(*regions, "Regions must not contain None");
+              for (auto it = regions->begin(), e = regions->end(); it != e;
+                   ++it) {
+                // TODO: Verify that regions originate from the same context.
+                (*it)->attachToParent(/*checkOnly=*/true);
+                // The operation takes ownership of each region, so a region
+                // listed twice would end up owned (and destroyed) twice.
+                if (std::find(regions->begin(), it, *it) != it) {
+                  throw SetPyError(PyExc_ValueError,
+                                   "A region cannot be added to an operation "
+                                   "more than once");
+                }
+              }
+            }
+            if (successors)
+              rejectNone(*successors, "Successors must not contain None");
+            if (attributes)
+              rejectNone(*attributes, "Attributes must not contain None");
+
+            MlirOperationState state =
+                mlirOperationStateGet(name.c_str(), location.loc);
+            // Add results (value type).
+            // TODO: Verify that types originate from the same context.
+            if (results) {
+              llvm::SmallVector<MlirType, 4> mlirResults(results->begin(),
+                                                         results->end());
+              mlirOperationStateAddResults(&state, mlirResults.size(),
+                                           mlirResults.data());
+            }
+            // Add owned regions (does not yet transfer ownership).
+            if (regions) {
+              llvm::SmallVector<MlirRegion, 4> mlirRegions;
+              mlirRegions.reserve(regions->size());
+              for (auto *region : *regions)
+                mlirRegions.push_back(region->region);
+              mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
+                                                mlirRegions.data());
+            }
+            // Add non-owned successor references.
+            if (successors) {
+              // TODO: Verify that successors originate from the same context.
+              llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
+              mlirSuccessors.reserve(successors->size());
+              for (auto *successor : *successors)
+                mlirSuccessors.push_back(successor->block);
+              mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
+                                              mlirSuccessors.data());
+            }
+            // Add attributes.
+            if (attributes) {
+              // TODO: Verify that attributes originate from the same context.
+              llvm::SmallVector<MlirNamedAttribute, 4> mlirAttributes;
+              mlirAttributes.reserve(attributes->size());
+              for (auto *attr : *attributes)
+                mlirAttributes.push_back(attr->namedAttr);
+              mlirOperationStateAddAttributes(&state, mlirAttributes.size(),
+                                              mlirAttributes.data());
+            }
+
+            // Construct the operation.
+            MlirOperation operation = mlirOperationCreate(&state);
+            // Transfer ownership of children. These transfers cannot fail:
+            // every region was verified above to be detached and unique.
+            if (regions) {
+              for (auto it : llvm::enumerate(*regions)) {
+                it.value()->attachToParent();
+                // Reset to the ownership-transferred region.
+                // TODO: Clean up this attachToParent() method to do exactly
+                // what is needed here.
+                it.value()->region =
+                    mlirOperationGetRegion(operation, it.index());
+              }
+            }
+            return PyOperation(self.context, operation, /*detached=*/true);
+          },
+          py::keep_alive<0, 1>(), py::arg("name"), py::arg("location"),
+          py::arg("results") = py::none(), py::arg("regions") = py::none(),
+          py::arg("successors") = py::none(),
+          py::arg("attributes") = py::none(), kContextCreateOperationDocstring);
 
   py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
     PyPrintAccumulator printAccum;
@@ -927,6 +1056,14 @@
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
+      .def_property_readonly(
+          "operation",
+          [](PyModule &self) {
+            return PyOperation(self.context,
+                               mlirModuleGetOperation(self.module),
+                               /*detached=*/false);
+          },
+          py::keep_alive<0, 1>(), "Accesses the module as an operation")
       .def(
           "dump",
           [](PyModule &self) {
@@ -998,6 +1135,20 @@
             return PyBlock(self.context, block, /*detached=*/false);
           },
           py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
+      // TODO: Remove prepend_operation in favor of a pseudo-list. It is just
+      // here to test the ownership model while bootstrapping the API.
+      .def("prepend_operation",
+           [](PyBlock &self, PyOperation &operation) {
+             if (!mlirContextEqual(self.context, operation.context)) {
+               throw SetPyError(
+                   PyExc_ValueError,
+                   "Operation must have been created from the same context as "
+                   "this block");
+             }
+             operation.attachToParent();
+             mlirBlockInsertOwnedOperation(self.block, /*pos=*/0,
+                                           operation.operation);
+           })
       .def(
           "__str__",
           [](PyBlock &self) {
@@ -1008,6 +1159,29 @@
           },
           kTypeStrDunderDocstring);
 
+  // Mapping of Operation.
+  py::class_<PyOperation>(m, "Operation")
+      .def_property_readonly(
+          "first_region",
+          [](PyOperation &self) {
+            if (mlirOperationGetNumRegions(self.operation) == 0) {
+              throw SetPyError(PyExc_IndexError, "Operation has no regions");
+            }
+            return PyRegion(self.context,
+                            mlirOperationGetRegion(self.operation, 0),
+                            /*detached=*/false);
+          },
+          py::keep_alive<0, 1>(), "Gets the operation's first region")
+      .def(
+          "__str__",
+          [](PyOperation &self) {
+            PyPrintAccumulator printAccum;
+            mlirOperationPrint(self.operation, printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          kTypeStrDunderDocstring);
+
   // Mapping of Type.
py::class_(m, "Attribute") .def( diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -12,8 +12,16 @@ #include #include +#include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" +namespace pybind11 { +namespace detail { +template +struct type_caster> : optional_caster> {}; +} // namespace detail +} // namespace pybind11 + namespace mlir { namespace python { diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -12,6 +12,7 @@ ctx = mlir.ir.Context() t = mlir.ir.F32Type(ctx) region = ctx.create_region() + # TODO: Add a test for a block with no types (was causing an abort) block = ctx.create_block([t, t]) # CHECK: <> print(block) @@ -69,3 +70,68 @@ raise RuntimeError("Expected exception not raised") run(testBlockAppend) + +# CHECK-LABEL: TEST: testOperationCreate +def testOperationCreate(): + ctx = mlir.ir.Context() + loc = ctx.get_unknown_location() + # Empty op. + op = ctx.create_operation("unnamed.empty", loc) + # CHECK: empty op: "unnamed.empty"() : () -> () + print("empty op:", op) + # Attribute containing op. + op = ctx.create_operation("unnamed.with_attr", loc, attributes=[ + mlir.ir.StringAttr.get(ctx, "attrvalue").get_named("attr1"), + mlir.ir.StringAttr.get(ctx, "attrvalue2").get_named("attr2"), + ]) + # CHECK: attr op: "unnamed.with_attr"() {attr1 = "attrvalue", attr2 = "attrvalue2"} : () -> () + print("attr op:", op) + # Op with results. + op = ctx.create_operation("unnamed.with_results", loc, results=[ + mlir.ir.F32Type(ctx), + mlir.ir.F64Type(ctx), + ]) + # CHECK: %0:2 = "unnamed.with_results"() : () -> (f32, f64) + print("results op:", op) + + # Op with regions. + # TODO: DO NOT SUBMIT. Refactor this test to be a more principled test of + # ownership transfer and validity. 
+ region1 = ctx.create_region() + region2 = ctx.create_region() + op = ctx.create_operation("unnamed.with_regions", loc, regions=[ + region1, + region2, + ]) + # CHECK: region op: "unnamed.with_regions"() ( { + # CHECK: }, { + # CHECK: }) : () -> () + print("region op:", op) + block1 = ctx.create_block([]) + region1.append_block(block1) # Should not crash + # CHECK: block1: ^bb0: // no predecessors + print("block1:", block1) # Should not crash + block1.prepend_operation(ctx.create_operation("unnamed.inner", loc)) + # CHECK: region op with block: "unnamed.with_regions"() ( { + # CHECK: "unnamed.inner"() : () -> () + print("region op with block:", op) + + # TODO: Test operands once it is possible to create them. + # TODO: Test successor blocks once it is possible to create them. + +run(testOperationCreate) + +# CHECK-LABEL: TEST: testOperationInModule +def testOperationInModule(): + ctx = mlir.ir.Context() + loc = ctx.get_unknown_location() + m = ctx.create_module(loc) + block = m.operation.first_region.first_block + op = ctx.create_operation("unnamed.empty", loc) + block.prepend_operation(op) + # CHECK: module { + # CHECK: "unnamed.empty"() : () -> () + # CHECK: } + print(m) + +run(testOperationInModule)