diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -116,13 +116,15 @@
 /// or via pattern rewrites. If more far reaching simplification is desired,
 /// applyPatternsAndFoldGreedily should be used.
 ///
-/// `changed` is set to true if the IR was modified at all. Returns success if
 /// Returns success if the iterative process converged and no more patterns can
-/// be matched.
+/// be matched. `changed` is set to true if the IR was modified at all.
+/// `allOpsErased` is set to true if all ops in `ops` were erased (this is
+/// trivially the case if `ops` is empty).
 LogicalResult applyOpPatternsAndFold(ArrayRef ops,
                                      const FrozenRewritePatternSet &patterns,
                                      GreedyRewriteStrictness strictMode,
-                                     bool *changed = nullptr);
+                                     bool *changed = nullptr,
+                                     bool *allOpsErased = nullptr);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -579,8 +579,11 @@
   ///
   /// Note that ops in `ops` could be erased as a result of folding, becoming
   /// dead, or via pattern rewrites. The return value indicates convergence.
+  ///
+  /// If `erased` is non-null, all ops erased during the rewrite are added to it.
   LogicalResult simplifyLocally(ArrayRef op,
-                                bool *changed = nullptr);
+                                bool *changed = nullptr,
+                                DenseSet *erased = nullptr);
 
   void addToWorklist(Operation *op) override {
     if (strictMode == GreedyRewriteStrictness::AnyOp ||
@@ -597,6 +600,8 @@
 
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+    if (erasedOps)
+      erasedOps->insert(op);
     if (strictMode != GreedyRewriteStrictness::AnyOp)
       strictModeFilteredOps.erase(op);
   }
@@ -610,13 +615,18 @@
   /// depending on `strictMode`. This set is not maintained when `strictMode`
   /// is GreedyRewriteStrictness::AnyOp.
   llvm::SmallDenseSet strictModeFilteredOps;
+
+  /// An optional set of ops that were erased. It is owned by the caller of
+  /// `simplifyLocally` and is only set for the duration of that call.
+  DenseSet *erasedOps = nullptr;
 };
 
 } // namespace
 
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops,
-                                             bool *changed) {
+LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
+    ArrayRef ops, bool *changed, DenseSet *erased) {
+  erasedOps = erased;
+
   if (strictMode != GreedyRewriteStrictness::AnyOp) {
     strictModeFilteredOps.clear();
     strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -699,6 +709,8 @@
     }
   }
 
+  // Do not keep a pointer to the caller-owned set after returning.
+  erasedOps = nullptr;
   return success(worklist.empty());
 }
 
@@ -721,15 +733,23 @@
 
 LogicalResult mlir::applyOpPatternsAndFold(
     ArrayRef ops, const FrozenRewritePatternSet &patterns,
-    GreedyRewriteStrictness strictMode, bool *changed) {
+    GreedyRewriteStrictness strictMode, bool *changed, bool *allOpsErased) {
   if (ops.empty()) {
     if (changed)
       *changed = false;
+    if (allOpsErased)
+      *allOpsErased = true;
     return success();
   }
 
   // Start the pattern driver.
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                      strictMode);
-  return driver.simplifyLocally(ops, changed);
+  DenseSet erased;
+  LogicalResult converged =
+      driver.simplifyLocally(ops, changed, allOpsErased ? &erased : nullptr);
+  if (allOpsErased)
+    *allOpsErased =
+        llvm::all_of(ops, [&](Operation *op) { return erased.contains(op); });
+  return converged;
 }
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -7,6 +7,7 @@
 // RUN: --split-input-file %s | FileCheck %s --check-prefix=CHECK-EX
 
 // CHECK-EN-LABEL: func @test_erase
+// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
 // CHECK-EN: test.arg0
 // CHECK-EN: test.arg1
 // CHECK-EN-NOT: test.erase_op
@@ -20,6 +21,7 @@
 // -----
 
 // CHECK-EN-LABEL: func @test_insert_same_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 // CHECK-EN: "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
@@ -30,6 +32,7 @@
 // -----
 
 // CHECK-EN-LABEL: func @test_replace_with_new_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
 // CHECK-EN: %[[n:.*]] = "test.new_op"
 // CHECK-EN: "test.dummy_user"(%[[n]])
 // CHECK-EN: "test.dummy_user"(%[[n]])
@@ -43,10 +46,12 @@
 // -----
 
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
+// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
 // CHECK-EN-NOT: test.replace_with_new_op
 // CHECK-EN-NOT: test.erase_op
 
 // CHECK-EX-LABEL: func @test_replace_with_erase_op
+// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
 // CHECK-EX-NOT: test.replace_with_new_op
 // CHECK-EX: test.erase_op
 func.func @test_replace_with_erase_op() {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -254,8 +254,9 @@
   }
 
   void runOnOperation() override {
-    mlir::RewritePatternSet patterns(&getContext());
-    patterns.add(&getContext());
+    MLIRContext *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+    patterns.add(ctx);
     SmallVector ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
@@ -279,7 +280,14 @@
     // Check if these transformations introduce visiting of operations that
     // are not in the `ops` set (The new created ops are valid). An invalid
     // operation will trigger the assertion while processing.
-    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode);
+    bool changed = false;
+    bool allErased = false;
+    (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
+                                 &changed, &allErased);
+    Builder b(ctx);
+    getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
+    getOperation()->setAttr("pattern_driver_all_erased",
+                            b.getBoolAttr(allErased));
   }
 
   Option strictMode{