diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -595,6 +595,23 @@
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+/// Traversal order used by mlirOperationWalk.
+typedef enum MlirWalkOrder {
+  MlirWalkPreOrder,
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+/// Callback for mlirOperationWalk; receives each visited op and `userData`.
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *);
+
+/// Walks `op` and every operation nested within it in `walkOrder`, calling
+/// `callback` on each visited operation. `userData` is passed to the callback
+/// and can be used to tunnel context or other data into the callback.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -622,6 +622,15 @@
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  // Operations that have no live Python wrapper need no bookkeeping.
+  auto it = liveOperations.find(op.ptr);
+  if (it != liveOperations.end())
+    it->second.second->setInvalid();
+  // NOTE(review): the map entry is kept, so an op later allocated at the same
+  // address may resolve to this invalidated wrapper; confirm this is intended.
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@
   /// place.
   size_t clearLiveOperations();
 
+  /// Sets an operation invalid. This is useful for when some non-bindings
+  /// code destroys the operation and the bindings need to be made aware. For
+  /// example, in the case when the pass manager is run.
+  void setOperationInvalid(MlirOperation op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -13,6 +13,7 @@
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
+using namespace py::literals;
 using namespace mlir;
 using namespace mlir::python;
 
@@ -63,8 +64,7 @@
                  mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
              return new PyPassManager(passManager);
            }),
-           py::arg("anchor_op") = py::str("any"),
-           py::arg("context") = py::none(),
+           "anchor_op"_a = py::str("any"), "context"_a = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyPassManager::getCapsule)
@@ -82,7 +82,7 @@
           [](PyPassManager &passManager, bool enable) {
             mlirPassManagerEnableVerifier(passManager.get(), enable);
           },
-          py::arg("enable"), "Enable / disable verify-each.")
+          "enable"_a, "Enable / disable verify-each.")
       .def_static(
           "parse",
           [](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@
               throw py::value_error(std::string(errorMsg.join()));
             return new PyPassManager(passManager);
           },
-          py::arg("pipeline"), py::arg("context") = py::none(),
+          "pipeline"_a, "context"_a = py::none(),
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
@@ -111,12 +111,40 @@
             if (mlirLogicalResultIsFailure(status))
               throw py::value_error(std::string(errorMsg.join()));
           },
-          py::arg("pipeline"),
+          "pipeline"_a,
           "Add textual pipeline elements to the pass manager. Throws a "
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op) {
+          [](PyPassManager &passManager, PyOperationBase &op,
+             bool invalidateOps) {
+            if (invalidateOps) {
+              // Mark all ops nested below the op that the pass manager will
+              // be rooted at as invalid, since passes may erase or replace
+              // them. Only ops inside the root's regions are walked, so the
+              // root itself stays valid.
+              auto *context = op.getOperation().getContext().get();
+              MlirOperationWalkCallback invalidatingCallback =
+                  [](MlirOperation nestedOp, void *userData) {
+                    auto *ctx = static_cast<PyMlirContext *>(userData);
+                    ctx->setOperationInvalid(nestedOp);
+                  };
+              MlirOperation rootOp = op.getOperation().get();
+              intptr_t numRegions = mlirOperationGetNumRegions(rootOp);
+              for (intptr_t i = 0; i < numRegions; ++i) {
+                MlirRegion region = mlirOperationGetRegion(rootOp, i);
+                for (MlirBlock block = mlirRegionGetFirstBlock(region);
+                     !mlirBlockIsNull(block);
+                     block = mlirBlockGetNextInRegion(block))
+                  for (MlirOperation childOp =
+                           mlirBlockGetFirstOperation(block);
+                       !mlirOperationIsNull(childOp);
+                       childOp = mlirOperationGetNextInBlock(childOp))
+                    mlirOperationWalk(childOp, invalidatingCallback, context,
+                                      MlirWalkPostOrder);
+              }
+            }
+            // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
@@ -124,7 +152,9 @@
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          py::arg("operation"),
+          "operation"_a, "invalidate_ops"_a = true,
           "Run the pass manager on the provided operation, raising an "
-          "MLIRError on failure.")
+          "MLIRError on failure. If `invalidate_ops` is true, all Python "
+          "handles to operations nested in the provided operation are "
+          "invalidated before the run since passes may erase them.")
       .def(
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -554,6 +555,23 @@
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  // Adapt the C callback to the C++ walker, forwarding the opaque userData.
+  auto forward = [callback, userData](Operation *nested) {
+    callback(wrap(nested), userData);
+  };
+  // Operation::walk defaults to post-order, so the order is always explicit.
+  switch (walkOrder) {
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<WalkOrder::PreOrder>(forward);
+    break;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<WalkOrder::PostOrder>(forward);
+    break;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,7 @@
     # CHECK: func.return , 1
 run(testRunPipeline)
 
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
@@ -162,4 +169,47 @@
             # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
             # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
             # CHECK: >
-            print(f"Exception: <{e}>")
+            log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+            func.func @foo() {
+              arith.constant 10
+              return
+            }
+          }
+        """
+        )
+        outer_const_op = module.body.operations[0]
+        # CHECK: %c10_i64 = arith.constant 10 : i64
+        log(outer_const_op)
+        inner_const_op = module.body.operations[1].body.blocks[0].operations[0]
+        # CHECK: %c10_i64_0 = arith.constant 10 : i64
+        log(inner_const_op)
+
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        try:
+            log(outer_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        try:
+            log(inner_const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+        # The root module stays valid; CHECK-NEXT pins that both constants are gone.
+        # CHECK: module {
+        # CHECK-NEXT: func.func @foo() {
+        # CHECK-NEXT: return
+        # CHECK-NEXT: }
+        # CHECK-NEXT: }
+        log(module)