Strict-mode pattern driver: record inserted ops before notifying the base driver.

What this patch changes (as shown by the hunks below):
  * GreedyPatternRewriteDriver.cpp: in notifyOperationInserted, the op is now
    added to strictModeFilteredOps (when strictMode is set) *before*
    GreedyPatternRewriteDriver::notifyOperationInserted(op) is called; the
    base-class call used to come first.
  * TestPatterns.cpp: the ReplaceWithSameOp test pattern becomes
    ReplaceWithNewOp. It now creates a "test.new_op" with the original
    operands/result types, or, when the matched op carries the
    "create_erase_op" attribute, a "test.erase_op" with no operands/results.
  * test-strict-pattern-driver.mlir: CHECK lines are tightened and a
    test_replace_with_erase_op case is added.

NOTE(review): this patch was recovered from a copy whose line breaks and
angle-bracket contents had been stripped. Line structure was rebuilt to match
every hunk header's line counts. The template argument lists on
`patterns.add<...>` and `SmallVector<...>` and the leading indentation of
context lines are reconstructions -- confirm against the upstream commit
before applying.

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -558,9 +558,9 @@
 
 private:
   void notifyOperationInserted(Operation *op) override {
-    GreedyPatternRewriteDriver::notifyOperationInserted(op);
     if (strictMode)
       strictModeFilteredOps.insert(op);
+    GreedyPatternRewriteDriver::notifyOperationInserted(op);
   }
 
   void notifyOperationRemoved(Operation *op) override {
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -1,6 +1,9 @@
 // RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
 
-// CHECK-LABEL: @test_erase
+// CHECK-LABEL: func @test_erase
+// CHECK: test.arg0
+// CHECK: test.arg1
+// CHECK-NOT: test.erase_op
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
@@ -8,16 +11,29 @@
   return
 }
 
-// CHECK-LABEL: @test_insert_same_op
+// CHECK-LABEL: func @test_insert_same_op
+// CHECK: "test.insert_same_op"() {skip = true}
+// CHECK: "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
   %0 = "test.insert_same_op"() : () -> (i32)
   return
 }
 
-// CHECK-LABEL: @test_replace_with_same_op
-func.func @test_replace_with_same_op() {
-  %0 = "test.replace_with_same_op"() : () -> (i32)
+// CHECK-LABEL: func @test_replace_with_new_op
+// CHECK: %[[n:.*]] = "test.new_op"
+// CHECK: "test.dummy_user"(%[[n]])
+// CHECK: "test.dummy_user"(%[[n]])
+func.func @test_replace_with_new_op() {
+  %0 = "test.replace_with_new_op"() : () -> (i32)
   %1 = "test.dummy_user"(%0) : (i32) -> (i32)
   %2 = "test.dummy_user"(%0) : (i32) -> (i32)
   return
 }
+
+// CHECK-LABEL: func @test_replace_with_erase_op
+// CHECK-NOT: test.replace_with_new_op
+// CHECK-NOT: test.erase_op
+func.func @test_replace_with_erase_op() {
+  "test.replace_with_new_op"() {create_erase_op} : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -220,12 +220,12 @@
 
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
+    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" ||
-          opName == "test.replace_with_same_op" || opName == "test.erase_op") {
+          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
         ops.push_back(op);
       }
     });
@@ -260,16 +260,25 @@
   };
 
   // Replace an operation may introduce the re-visiting of its users.
-  class ReplaceWithSameOp : public RewritePattern {
+  class ReplaceWithNewOp : public RewritePattern {
   public:
-    ReplaceWithSameOp(MLIRContext *context)
-        : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
+    ReplaceWithNewOp(MLIRContext *context)
+        : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
 
     LogicalResult matchAndRewrite(Operation *op,
                                   PatternRewriter &rewriter) const override {
-      Operation *newOp =
-          rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                          op->getOperands(), op->getResultTypes());
+      Operation *newOp;
+      if (op->hasAttr("create_erase_op")) {
+        newOp = rewriter.create(
+            op->getLoc(),
+            OperationName("test.erase_op", op->getContext()).getIdentifier(),
+            ValueRange(), TypeRange());
+      } else {
+        newOp = rewriter.create(
+            op->getLoc(),
+            OperationName("test.new_op", op->getContext()).getIdentifier(),
+            op->getOperands(), op->getResultTypes());
+      }
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }