diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -595,6 +595,24 @@
 /// ownership is transferred to the block of the other operation.
 MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
                                                 MlirOperation other);
+
+/// Traversal order requested from `mlirOperationWalk`.
+// NOTE(review): these names are unprefixed in a C header and `WalkOrder` also
+// exists as `mlir::WalkOrder`; consider `Mlir`-prefixed names.
+enum WalkOrder { PreOrder, PostOrder };
+
+/// Callback invoked by `mlirOperationWalk` for each visited operation. It
+/// receives the visited operation and the `userData` pointer that was passed
+/// to `mlirOperationWalk`.
+typedef void (*MlirOperationWalkCallback)(MlirOperation, void *);
+
+/// Walks `op` in the order given by `walkOrder`, calling `callback` on each
+/// visited operation. `userData` is forwarded unchanged to every `callback`
+/// invocation.
+MLIR_CAPI_EXPORTED
+void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
+                       void *userData, enum WalkOrder walkOrder);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -622,6 +622,16 @@
   return numInvalidated;
 }
 
+void PyMlirContext::setOperationInvalid(MlirOperation op) {
+  // Look the operation up once; operations that are not tracked in
+  // `liveOperations` are left untouched.
+  // NOTE(review): the entry itself stays in `liveOperations`; confirm that no
+  // stale entry can outlive the invalidated PyOperation.
+  auto it = liveOperations.find(op.ptr);
+  if (it != liveOperations.end())
+    it->second.second->setInvalid();
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -208,6 +208,7 @@
   /// corrupt by holding references they shouldn't have accessed in the first
   /// place.
   size_t clearLiveOperations();
+  void setOperationInvalid(MlirOperation op);
 
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
@@ -262,6 +263,29 @@
   friend class PyOperation;
 };
 
+/// Helper for invalidating Python-tracked operations through
+/// `mlirOperationWalk`: the callback marks every visited operation as invalid
+/// in the `PyMlirContext` given at construction.
+class PyOperationWalker {
+public:
+  explicit PyOperationWalker(PyMlirContext *context) : context(context) {}
+
+  /// Returns the pointer to pass as `userData` alongside `getCallback()`.
+  void *getUserData() { return this; }
+
+  /// Returns a callback that expects `userData` to point to a
+  /// `PyOperationWalker`.
+  MlirOperationWalkCallback getCallback() {
+    return [](MlirOperation op, void *userData) {
+      auto *walker = static_cast<PyOperationWalker *>(userData);
+      walker->context->setOperationInvalid(op);
+    };
+  }
+
+private:
+  PyMlirContext *context;
+};
+
 /// Used in function arguments when None should resolve to the current context
 /// manager set instance.
 class DefaultingPyMlirContext
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -118,6 +118,28 @@
           "run",
           [](PyPassManager &passManager, PyOperationBase &op) {
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
+            // Conservatively invalidate every Python-tracked operation nested
+            // under `op` before running the pipeline, since passes may erase
+            // or replace them. `op` itself is not walked, so it stays valid.
+            MlirOperation rootOp = op.getOperation().get();
+            PyOperationWalker operationWalker(
+                op.getOperation().getContext().get());
+            MlirOperationWalkCallback invalidate =
+                operationWalker.getCallback();
+            intptr_t numRegions = mlirOperationGetNumRegions(rootOp);
+            for (intptr_t i = 0; i < numRegions; ++i) {
+              MlirRegion region = mlirOperationGetRegion(rootOp, i);
+              for (MlirBlock block = mlirRegionGetFirstBlock(region);
+                   !mlirBlockIsNull(block);
+                   block = mlirBlockGetNextInRegion(block)) {
+                for (MlirOperation childOp = mlirBlockGetFirstOperation(block);
+                     !mlirOperationIsNull(childOp);
+                     childOp = mlirOperationGetNextInBlock(childOp)) {
+                  mlirOperationWalk(childOp, invalidate,
+                                    operationWalker.getUserData(), ::PostOrder);
+                }
+              }
+            }
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
             if (mlirLogicalResultIsFailure(status))
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
 
@@ -554,6 +555,24 @@
   return unwrap(op)->moveBefore(unwrap(other));
 }
 
+void mlirOperationWalk(MlirOperation op,
+                       const MlirOperationWalkCallback callback, void *userData,
+                       ::WalkOrder walkOrder) {
+  // The traversal order is a template argument of `Operation::walk`, so each
+  // order needs its own instantiation. `mlir::WalkOrder` is spelled out in
+  // full to keep it apart from the C API enum `::WalkOrder`.
+  if (walkOrder == ::WalkOrder::PreOrder)
+    unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
+        [callback, userData](Operation *nested) {
+          callback(wrap(nested), userData);
+        });
+  else
+    unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
+        [callback, userData](Operation *nested) {
+          callback(wrap(nested), userData);
+        });
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -4,6 +4,8 @@
 from mlir.ir import *
 from mlir.passmanager import *
 from mlir.dialects.func import FuncOp
+from mlir.dialects.builtin import ModuleOp
+
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@
 
 run(testCapsule)
 
+
 # CHECK-LABEL: TEST: testConstruct
 @run
 def testConstruct():
@@ -68,6 +71,7 @@
 
 run(testParseSuccess)
 
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSpacedPipeline
 def testParseSpacedPipeline():
@@ -84,6 +88,7 @@
 
 run(testParseSpacedPipeline)
 
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():
@@ -102,6 +107,7 @@
 
 run(testParseFail)
 
+
 # Check that adding to a pass manager works
 # CHECK-LABEL: TEST: testAdd
 @run
@@ -147,6 +153,31 @@
 # CHECK: func.return , 1
 run(testRunPipeline)
 
+
+# CHECK-LABEL: TEST: testPostPassOpInvalidation
+@run
+def testPostPassOpInvalidation():
+    with Context():
+        module = ModuleOp.parse(
+            """
+          module {
+            arith.constant 10
+          }
+        """
+        )
+        const_op = module.body.operations[0]
+        # CHECK: arith.constant 10 : i64
+        log(const_op)
+        PassManager.parse("builtin.module(canonicalize)").run(module)
+        # Running the pipeline invalidates the reference held in `const_op`, so
+        # using it must raise instead of touching a possibly erased operation.
+        try:
+            log(const_op)
+        except RuntimeError as e:
+            # CHECK: the operation has been invalidated
+            log(e)
+
+
 # CHECK-LABEL: TEST: testRunPipelineError
 @run
 def testRunPipelineError():