diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -525,7 +525,9 @@
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
   auto inBoundsAttr = b.getBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    b.updateRootInPlace(xferOp, [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
     return success();
   }
 
@@ -596,7 +598,10 @@
-    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
-      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
-
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    // Rewiring the operands and setting in_bounds both modify the transfer op
+    // in place, so both must happen inside the rewriter notification scope.
+    b.updateRootInPlace(xferOp, [&]() {
+      for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+        xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
 
     return success();
   }
@@ -623,7 +628,7 @@
   else
     createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
 
-  xferOp->erase();
+  b.eraseOp(xferOp);
 
   return success();
 }
@@ -634,11 +639,5 @@
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
       failed(filter(xferOp)))
     return failure();
-  rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
-    rewriter.finalizeRootUpdate(xferOp);
-    return success();
-  }
-  rewriter.cancelRootUpdate(xferOp);
-  return failure();
+  return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -51,6 +51,11 @@
   /// If the specified operation is in the worklist, remove it.
   void removeFromWorklist(Operation *op);
 
+  /// Notifies the driver that the specified operation may have been modified
+  /// in-place. The operation is re-added to the worklist so that patterns get
+  /// a chance to apply to it again.
+  void finalizeRootUpdate(Operation *op) override;
+
 protected:
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -326,6 +331,14 @@
   addToWorklist(op);
 }
 
+void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  addToWorklist(op);
+}
+
 void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
   for (Value operand : operands) {
     // If the use count of this operand is now < 2, we re-add the defining
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,6 +1,12 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
+  GreedyPatternRewriteDriver.cpp
 )
+# GreedyPatternRewriteDriver.cpp uses the test dialect: it needs the dialect's
+# generated headers (build tree) on the include path and its library to link.
+target_include_directories(MLIRTransformsTests
+  PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
+target_link_libraries(MLIRTransformsTests PRIVATE MLIRTestDialect)
 target_link_libraries(MLIRTransformsTests
   PRIVATE
diff --git a/mlir/unittests/Transforms/GreedyPatternRewriteDriver.cpp b/mlir/unittests/Transforms/GreedyPatternRewriteDriver.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/GreedyPatternRewriteDriver.cpp
@@ -0,0 +1,75 @@
+//===- GreedyPatternRewriteDriver.cpp - Greedy pattern rewriter tests -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/IR/Builders.h"
+#include "gtest/gtest.h"
+
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites the integer attribute of a test::AnyAttrOfOp from FromCst to ToCst
+/// in place. Fails to match if the attribute holds any other value.
+template <int64_t FromCst, int64_t ToCst>
+struct RewriteAttribute : public OpRewritePattern<test::AnyAttrOfOp> {
+  using OpRewritePattern<test::AnyAttrOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::AnyAttrOfOp op,
+                                PatternRewriter &rewriter) const override {
+    int64_t constVal = op.getAttr().cast<IntegerAttr>().getInt();
+    if (constVal != FromCst)
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(ToCst)); });
+    return success();
+  }
+};
+
+/// Ops modified in place through the rewriter must be re-enqueued by the
+/// driver, so the whole 0 -> 1 -> 2 -> 3 chain applies with maxIterations = 1.
+TEST(GreedyPatternRewriteDriver, EnqueueUpdatedOp) {
+  MLIRContext context;
+  context.allowUnregisteredDialects();
+  context.loadDialect<test::TestDialect>();
+
+  OpBuilder builder(&context);
+  Location loc = builder.getUnknownLoc();
+
+  // Build an op with a region. IsolatedRegionOp has an operand but it is not
+  // used in this test.
+  test::TestOpConstant dummyIndex = builder.create<test::TestOpConstant>(
+      loc, builder.getIndexType(), builder.getIndexAttr(0));
+  auto regionOp = builder.create<test::IsolatedRegionOp>(loc, dummyIndex);
+  Block &block = regionOp.getRegion().emplaceBlock();
+  block.addArgument(builder.getIndexType(), loc);
+  builder.setInsertionPointToStart(&block);
+
+  // Add a test op to the body.
+  test::AnyAttrOfOp testOp =
+      builder.create<test::AnyAttrOfOp>(loc, builder.getI32IntegerAttr(0));
+  builder.create<test::TerminatorOp>(loc);
+
+  // Patterns change the attribute of testOp: 0 -> 1 -> 2 -> 3
+  RewritePatternSet patterns(&context);
+  patterns.insert<RewriteAttribute<0, 1>, RewriteAttribute<1, 2>,
+                  RewriteAttribute<2, 3>>(&context);
+  GreedyRewriteConfig config;
+  config.maxIterations = 1;
+  LogicalResult status =
+      applyPatternsAndFoldGreedily(regionOp, std::move(patterns), config);
+  EXPECT_TRUE(status.succeeded());
+  EXPECT_EQ(testOp.getAttr().cast<IntegerAttr>().getInt(), 3);
+
+  // dummyIndex and regionOp were created without an insertion point, so they
+  // are not owned by any block; erase them (user first) to avoid a leak.
+  regionOp->erase();
+  dummyIndex->erase();
+}
+} // namespace
diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
@@ -281,6 +281,7 @@
         "//mlir:Pass",
         "//mlir:TransformUtils",
         "//mlir:Transforms",
+        "//mlir/test:TestDialect",
         "//third-party/unittest:gtest_main",
     ],
 )