diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -118,10 +118,12 @@ /// /// Returns success if the iterative process converged and no more patterns can /// be matched. `changed` is set to true if the IR was modified at all. +/// `allErased` is set to true if all ops in `ops` were erased. LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, GreedyRewriteStrictness strictMode, - bool *changed = nullptr); + bool *changed = nullptr, + bool *allErased = nullptr); } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -17,6 +17,7 @@ #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ScopedPrinter.h" @@ -584,8 +585,11 @@ /// /// Note that ops in `ops` could be erased as a result of folding, becoming /// dead, or via pattern rewrites. The return value indicates convergence. - LogicalResult simplifyLocally(ArrayRef op, - bool *changed = nullptr); + /// + /// All `ops` that survived the rewrite are stored in `surviving`. 
+ LogicalResult + simplifyLocally(ArrayRef ops, bool *changed = nullptr, + llvm::SmallDenseSet *surviving = nullptr); void addToWorklist(Operation *op) override { if (strictMode == GreedyRewriteStrictness::AnyOp || @@ -602,6 +606,8 @@ void notifyOperationRemoved(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationRemoved(op); + if (survivingOps) + survivingOps->erase(op); if (strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.erase(op); } @@ -615,13 +621,25 @@ /// depending on `strictMode`. This set is not maintained when `strictMode` /// is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet strictModeFilteredOps; + + /// An optional set of ops that survived the rewrite. This set is populated + /// at the beginning of `simplifyLocally` with the initially provided list + /// of ops. + llvm::SmallDenseSet *survivingOps = nullptr; }; } // namespace -LogicalResult -MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, - bool *changed) { +LogicalResult MultiOpPatternRewriteDriver::simplifyLocally( + ArrayRef ops, bool *changed, + llvm::SmallDenseSet *surviving) { + auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; }); + if (surviving) { + survivingOps = surviving; + survivingOps->clear(); + survivingOps->insert(ops.begin(), ops.end()); + } + if (strictMode != GreedyRewriteStrictness::AnyOp) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); @@ -726,15 +744,22 @@ LogicalResult mlir::applyOpPatternsAndFold( ArrayRef ops, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, bool *changed) { + GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) { if (ops.empty()) { if (changed) *changed = false; + if (allErased) + *allErased = true; // Vacuously true: `ops` is empty. return success(); } // Start the pattern driver. 
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, strictMode); - return driver.simplifyLocally(ops, changed); + llvm::SmallDenseSet surviving; // Used only for `allErased`. + LogicalResult converged = + driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr); + if (allErased) + *allErased = surviving.empty(); + return converged; } diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir --- a/mlir/test/Transforms/test-strict-pattern-driver.mlir +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -7,6 +7,7 @@ // RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX // CHECK-EN-LABEL: func @test_erase +// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true} // CHECK-EN: test.arg0 // CHECK-EN: test.arg1 // CHECK-EN-NOT: test.erase_op @@ -20,6 +21,7 @@ // ----- // CHECK-EN-LABEL: func @test_insert_same_op +// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true} // CHECK-EN: "test.insert_same_op"() {skip = true} // CHECK-EN: "test.insert_same_op"() {skip = true} func.func @test_insert_same_op() { @@ -30,6 +32,7 @@ // ----- // CHECK-EN-LABEL: func @test_replace_with_new_op +// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} // CHECK-EN: %[[n:.*]] = "test.new_op" // CHECK-EN: "test.dummy_user"(%[[n]]) // CHECK-EN: "test.dummy_user"(%[[n]]) @@ -43,10 +46,12 @@ // ----- // CHECK-EN-LABEL: func @test_replace_with_erase_op +// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} // CHECK-EN-NOT: test.replace_with_new_op // CHECK-EN-NOT: test.erase_op // CHECK-EX-LABEL: func @test_replace_with_erase_op +// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true} // CHECK-EX-NOT: test.replace_with_new_op // CHECK-EX: test.erase_op func.func @test_replace_with_erase_op() { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp 
b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -254,8 +254,9 @@ } void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + MLIRContext *ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + patterns.add(ctx); SmallVector ops; getOperation()->walk([&](Operation *op) { StringRef opName = op->getName().getStringRef(); @@ -279,7 +280,14 @@ // Check if these transformations introduce visiting of operations that // are not in the `ops` set (The new created ops are valid). An invalid // operation will trigger the assertion while processing. - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode); + bool changed = false; + bool allErased = false; + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode, + &changed, &allErased); + Builder b(ctx); // Expose the driver results as attributes for FileCheck. + getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); + getOperation()->setAttr("pattern_driver_all_erased", + b.getBoolAttr(allErased)); } Option strictMode{