diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -43,7 +43,7 @@ bool simplify(MutableArrayRef regions); /// Add the given operation to the worklist. - void addToWorklist(Operation *op); + virtual void addToWorklist(Operation *op); /// Pop the next operation from the worklist. Operation *popFromWorklist(); @@ -60,8 +60,7 @@ // be re-added to the worklist. This function should be called when an // operation is modified or removed, as it may trigger further // simplifications. - template - void addToWorklist(Operands &&operands); + void addOperandsToWorklist(ValueRange operands); // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. @@ -219,7 +218,7 @@ originalOperands.assign(op->operand_begin(), op->operand_end()); auto preReplaceAction = [&](Operation *op) { // Add the operands to the worklist for visitation. - addToWorklist(originalOperands); + addOperandsToWorklist(originalOperands); // Add all the users of the result to the worklist so we make sure // to revisit them. @@ -327,8 +326,7 @@ addToWorklist(op); } -template -void GreedyPatternRewriteDriver::addToWorklist(Operands &&operands) { +void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) { for (Value operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. 
@@ -343,7 +341,7 @@ } void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { - addToWorklist(op->getOperands()); + addOperandsToWorklist(op->getOperands()); op->walk([this](Operation *operation) { removeFromWorklist(operation); folder.notifyRemoval(operation); @@ -523,22 +521,12 @@ bool simplifyLocally(ArrayRef op); -private: - // Look over the provided operands for any defining operations that should - // be re-added to the worklist. This function should be called when an - // operation is modified or removed, as it may trigger further - // simplifications. If `strict` is set to true, only ops in - // `strictModeFilteredOps` are considered. - template - void addOperandsToWorklist(Operands &&operands) { - for (Value operand : operands) { - if (auto *defOp = operand.getDefiningOp()) { - if (!strictMode || strictModeFilteredOps.contains(defOp)) - addToWorklist(defOp); - } - } + void addToWorklist(Operation *op) override { + if (!strictMode || strictModeFilteredOps.contains(op)) + GreedyPatternRewriteDriver::addToWorklist(op); } +private: void notifyOperationInserted(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationInserted(op); if (strictMode) @@ -551,15 +539,6 @@ strictModeFilteredOps.erase(op); } - void notifyRootReplaced(Operation *op) override { - for (auto result : op->getResults()) { - for (auto *user : result.getUsers()) { - if (!strictMode || strictModeFilteredOps.contains(user)) - addToWorklist(user); - } - } - } - /// If `strictMode` is true, any pre-existing ops outside of /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. /// If `strictMode` is false, operations that use results of (or supply @@ -633,22 +612,17 @@ // Add all the users of the result to the worklist so we make sure // to revisit them. 
- for (Value result : op->getResults()) - for (Operation *userOp : result.getUsers()) { - if (!strictMode || strictModeFilteredOps.contains(userOp)) - addToWorklist(userOp); - } + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) + addToWorklist(userOp); + } + notifyOperationRemoved(op); }; // Add the given operation generated by the folder to the worklist. auto processGeneratedConstants = [this](Operation *op) { - // Newly created ops are also simplified -- these are also "local". - addToWorklist(op); - // When strict mode is off, we don't need to maintain - // strictModeFilteredOps. - if (strictMode) - strictModeFilteredOps.insert(op); + notifyOperationInserted(op); }; // Try to fold this op. diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s + +// CHECK-LABEL: @test_erase +func.func @test_erase() { + %0 = "test.arg0"() : () -> (i32) + %1 = "test.arg1"() : () -> (i32) + %erase = "test.erase_op"(%0, %1) : (i32, i32) -> (i32) + return +} + +// CHECK-LABEL: @test_insert_same_op +func.func @test_insert_same_op() { + %0 = "test.insert_same_op"() : () -> (i32) + return +} + +// CHECK-LABEL: @test_replace_with_same_op +func.func @test_replace_with_same_op() { + %0 = "test.replace_with_same_op"() : () -> (i32) + %1 = "test.dummy_user"(%0) : (i32) -> (i32) + %2 = "test.dummy_user"(%0) : (i32) -> (i32) + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -176,6 +176,91 @@ llvm::cl::desc("Seed the worklist in general top-down order"), llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; }; + +struct 
TestStrictPatternDriver + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) + + TestStrictPatternDriver() = default; + TestStrictPatternDriver(const TestStrictPatternDriver &other) + : PassWrapper(other) {} + + StringRef getArgument() const final { return "test-strict-pattern-driver"; } + StringRef getDescription() const final { + return "Run strict mode of pattern driver"; + } + + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + SmallVector ops; + getOperation()->walk([&](Operation *op) { + StringRef opName = op->getName().getStringRef(); + if (opName == "test.insert_same_op" || + opName == "test.replace_with_same_op" || opName == "test.erase_op") { + ops.push_back(op); + } + }); + + // Check if these transformations introduce visiting of operations that + // are not in the `ops` set (The newly created ops are valid). An invalid + // operation will trigger the assertion while processing. + (void)applyOpPatternsAndFold(makeArrayRef(ops), std::move(patterns), + /*strict=*/true); + } + +private: + // A newly inserted operation is valid for further transformation. + class InsertSameOp : public RewritePattern { + public: + InsertSameOp(MLIRContext *context) + : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->hasAttr("skip")) + return failure(); + + Operation *newOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + op->getOperands(), op->getResultTypes()); + op->setAttr("skip", rewriter.getBoolAttr(true)); + newOp->setAttr("skip", rewriter.getBoolAttr(true)); + + return success(); + } + }; + + // Replacing an operation may trigger the re-visiting of its users.
+ class ReplaceWithSameOp : public RewritePattern { + public: + ReplaceWithSameOp(MLIRContext *context) + : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + Operation *newOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + op->getOperands(), op->getResultTypes()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } + }; + + // Removing an operation may trigger the re-visiting of its operands. + class EraseOp : public RewritePattern { + public: + EraseOp(MLIRContext *context) + : RewritePattern("test.erase_op", /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } + }; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1471,6 +1556,7 @@ PassRegistration(); PassRegistration(); + PassRegistration(); PassRegistration([] { return std::make_unique(legalizerConversionMode);