diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -595,6 +595,27 @@
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+/// Traversal order for `mlirOperationWalk`.
+typedef enum MlirWalkOrder {
+  /// Visit an operation before the operations nested in its regions.
+  MlirWalkPreOrder,
+  /// Visit an operation after the operations nested in its regions.
+  MlirWalkPostOrder
+} MlirWalkOrder;
+
+/// Callback invoked by `mlirOperationWalk` for every visited operation. It
+/// receives the operation and the `userData` pointer given to the walk.
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
+
+/// Walks `op` and all operations nested in its regions in the given
+/// `walkOrder`, invoking `callback` on each of them, `op` included.
+/// `userData` is forwarded unchanged to every invocation of `callback` and can
+/// be used to tunnel context into it.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -622,6 +622,17 @@
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  // Also drop the map entry so the context stops tracking this object: a
+  // stale pointer left in the map would let a new operation that reuses the
+  // same address resolve to this invalidated, possibly destroyed, object.
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end())
+    return;
+  it->second.second->setInvalid();
+  liveOperations.erase(it);
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,6 +209,11 @@
   /// place.
   size_t clearLiveOperations();
 
+  /// Invalidates the live Python object for `op` (if any) and stops tracking
+  /// it. Needed when code outside the bindings, such as a pass manager run,
+  /// may destroy the operation and the bindings need to be made aware of it.
+  void setOperationInvalid(MlirOperation op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -117,6 +117,26 @@
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op) {
+            // Invalidate the Python objects of every operation nested under
+            // the root op: the passes may erase or replace any of them and the
+            // bindings cannot tell afterwards which ones survived. The root op
+            // itself stays valid.
+            struct InvalidationData {
+              PyMlirContext *context;
+              MlirOperation root;
+            };
+            PyOperation &rootOp = op.getOperation();
+            InvalidationData invalidation{rootOp.getContext().get(),
+                                          rootOp.get()};
+            mlirOperationWalk(
+                rootOp.get(),
+                [](MlirOperation nested, void *userData) {
+                  auto *data = static_cast<InvalidationData *>(userData);
+                  if (!mlirOperationEqual(nested, data->root))
+                    data->context->setOperationInvalid(nested);
+                },
+                &invalidation, MlirWalkPostOrder);
+            // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -554,6 +555,24 @@
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, MlirWalkOrder walkOrder) {
+  // Dispatch on the C enum explicitly instead of casting it to
+  // mlir::WalkOrder, so the C API does not rely on the two enums sharing the
+  // same numeric values.
+  auto visit = [callback, userData](Operation *nested) {
+    callback(wrap(nested), userData);
+  };
+  switch (walkOrder) {
+  case MlirWalkPreOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(visit);
+    return;
+  case MlirWalkPostOrder:
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(visit);
+    return;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,7 @@
 # CHECK: func.return , 1
 run(testRunPipeline)
 
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():
@@ -162,4 +169,31 @@
             # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
             # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
             # CHECK: >
-            print(f"Exception: <{e}>")
+            log(f"Exception: <{e}>")
+
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context() as ctx:
+        module = ModuleOp.parse(
+            """
+            module {
+              arith.constant 10
+            }
+            """
+        )
+        const_op = module.body.operations[0]
+        # CHECK: arith.constant 10 : i64
+        log(const_op)
+        # The pass manager run may erase or replace any nested op (canonicalize
+        # can drop the unused constant), so its Python object is invalidated.
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        try:
+            log(const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+        # The root op itself is not invalidated by the run.
+        # CHECK: module {
+        log(module)