[mlir] Propagate interrupts from nested walks in the generic Region/Block walkers

The Region- and Block-callback walkers in mlir/lib/IR/Visitors.cpp recursed
into nested operations with `walk(&nestedOp, callback, order)` but discarded
the returned WalkResult. A WalkResult::interrupt() produced while visiting a
nested region or block was therefore lost and the outer walk kept going.
Both walkers now return WalkResult::interrupt() as soon as a nested walk
reports one.

Adds two test passes, -test-generic-ir-block-visitors-interrupt and
-test-generic-ir-region-visitors-interrupt, plus FileCheck tests expecting
the top-level walk to report an interruption triggered inside a func.func.

diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -114,7 +114,8 @@
     }
     for (auto &block : region) {
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
     }
     if (order == WalkOrder::PostOrder) {
       if (callback(&region).wasInterrupted())
@@ -140,7 +141,8 @@
           return WalkResult::interrupt();
       }
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
       if (order == WalkOrder::PostOrder) {
         if (callback(&block).wasInterrupted())
           return WalkResult::interrupt();
diff --git a/mlir/test/IR/generic-block-visitors-interrupt.mlir b/mlir/test/IR/generic-block-visitors-interrupt.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/generic-block-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted
diff --git a/mlir/test/IR/generic-region-visitors-interrupt.mlir b/mlir/test/IR/generic-region-visitors-interrupt.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/generic-region-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -113,6 +113,77 @@
   }
 };
 
+/// Walks Blocks; interrupts at the first visited block containing an op that
+/// has results and an `interrupt` BoolAttr, otherwise prints "step N".
+struct TestGenericIRBlockVisitorInterruptPass
+    : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
+                         OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRBlockVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-block-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Blocks.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Block *block) {
+      for (Operation &op : *block)
+        for (OpResult result : op.getResults())
+          if (Operation *definingOp = result.getDefiningOp())
+            if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
+              return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+/// Walks Regions; interrupts at the first visited region containing an op that
+/// has results and an `interrupt` BoolAttr, otherwise prints "step N".
+struct TestGenericIRRegionVisitorInterruptPass
+    : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
+                         OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRRegionVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-region-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Regions.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Region *region) {
+      for (Block &block : *region)
+        for (Operation &op : block)
+          for (OpResult result : op.getResults())
+            if (Operation *definingOp = result.getDefiningOp())
+              if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
+                return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -120,6 +191,8 @@
 void registerTestGenericIRVisitorsPass() {
   PassRegistration<TestGenericIRVisitorPass>();
   PassRegistration<TestGenericIRVisitorInterruptPass>();
+  PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
+  PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
 }
 
 } // namespace test