diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -525,7 +525,9 @@ SmallVector<bool, 4> bools(xferOp.getTransferRank(), true); auto inBoundsAttr = b.getBoolArrayAttr(bools); if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) { - xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); + b.updateRootInPlace(xferOp, [&]() { + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); + }); return success(); } @@ -596,7 +598,9 @@ for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) xferReadOp.setOperand(i, fullPartialIfOp.getResult(i)); - xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); + b.updateRootInPlace(xferOp, [&]() { + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); + }); return success(); } @@ -623,7 +627,7 @@ else createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc); - xferOp->erase(); + b.eraseOp(xferOp); return success(); } @@ -634,11 +638,5 @@ if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) || failed(filter(xferOp))) return failure(); - rewriter.startRootUpdate(xferOp); - if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) { - rewriter.finalizeRootUpdate(xferOp); - return success(); - } - rewriter.cancelRootUpdate(xferOp); - return failure(); + return splitFullAndPartialTransfer(rewriter, xferOp, options); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -51,6 +51,10 @@ /// If the specified operation is in the worklist, remove it.
void removeFromWorklist(Operation *op); + /// Notifies the driver that the specified operation may have been modified + /// in-place. + void finalizeRootUpdate(Operation *op) override; + protected: // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. @@ -326,6 +330,14 @@ addToWorklist(op); } +void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Modified: '" << op->getName() << "'(" << op + << ")\n"; + }); + addToWorklist(op); +} + void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) { for (Value operand : operands) { // If the use count of this operand is now < 2, we re-add the defining diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s + +// CHECK-LABEL: func @add_to_worklist_after_inplace_update() +func.func @add_to_worklist_after_inplace_update() { + // The following op is updated in-place and should be added back to the + // worklist of the GreedyPatternRewriteDriver (regardless of the value of + // config.maxIterations). + + // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> () + "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -147,6 +147,26 @@ } }; +/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer +/// attribute with value smaller than MaxVal, it increments the value by 1.
+template <int MaxVal> +struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { + using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(AnyAttrOfOp op, + PatternRewriter &rewriter) const override { + auto intAttr = op.getAttr().dyn_cast<IntegerAttr>(); + if (!intAttr) + return failure(); + int64_t val = intAttr.getInt(); + if (val >= MaxVal) + return failure(); + rewriter.updateRootInPlace( + op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); + return success(); + } +}; + struct TestPatternDriver : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) @@ -165,8 +185,12 @@ FolderInsertBeforePreviouslyFoldedConstantPattern, FolderCommutativeOp2WithConstant>(&getContext()); + // Additional patterns for testing the GreedyPatternRewriteDriver. + patterns.insert<IncrementIntAttribute<3>>(&getContext()); + GreedyRewriteConfig config; config.useTopDownTraversal = this->useTopDownTraversal; + config.maxIterations = this->maxIterations; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config); } @@ -175,6 +199,10 @@ *this, "top-down", llvm::cl::desc("Seed the worklist in general top-down order"), llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; + Option<int> maxIterations{ + *this, "max-iterations", + llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), + llvm::cl::init(GreedyRewriteConfig().maxIterations)}; }; struct TestStrictPatternDriver