diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -17,6 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FoldInterfaces.h" namespace mlir { @@ -31,19 +32,14 @@ /// generated along the way. class OperationFolder { public: - OperationFolder(MLIRContext *ctx) : interfaces(ctx) {} + OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr) + : interfaces(ctx), listener(listener) {} /// Tries to perform folding on the given `op`, including unifying /// deduplicated constants. If successful, replaces `op`'s uses with - /// folded results, and returns success. `preReplaceAction` is invoked on `op` - /// before it is replaced. 'processGeneratedConstants' is invoked for any new - /// operations generated when folding. If the op was completely folded it is + /// folded results, and returns success. If the op was completely folded it is /// erased. If it is just updated in place, `inPlaceUpdate` is set to true. - LogicalResult - tryToFold(Operation *op, - function_ref processGeneratedConstants = nullptr, - function_ref preReplaceAction = nullptr, - bool *inPlaceUpdate = nullptr); + LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr); /// Tries to fold a pre-existing constant operation. `constValue` represents /// the value of the constant, and can be optionally passed if the value is @@ -122,23 +118,23 @@ using ConstantMap = DenseMap, Operation *>; + /// Erase the given operation and notify the listener. + void eraseOp(Operation *op); + /// Returns true if the given operation is an already folded constant that is /// owned by this folder. bool isFolderOwnedConstant(Operation *op) const; /// Tries to perform folding on the given `op`. 
If successful, populates /// `results` with the results of the folding. - LogicalResult tryToFold( - OpBuilder &builder, Operation *op, SmallVectorImpl &results, - function_ref processGeneratedConstants = nullptr); + LogicalResult tryToFold(OpBuilder &builder, Operation *op, + SmallVectorImpl &results); /// Try to process a set of fold results, generating constants as necessary. /// Populates `results` on success, otherwise leaves it unchanged. - LogicalResult - processFoldResults(OpBuilder &builder, Operation *op, - SmallVectorImpl &results, - ArrayRef foldResults, - function_ref processGeneratedConstants); + LogicalResult processFoldResults(OpBuilder &builder, Operation *op, + SmallVectorImpl &results, + ArrayRef foldResults); /// Try to get or create a new constant entry. On success this returns the /// constant operation, nullptr otherwise. @@ -156,6 +152,9 @@ /// A collection of dialect folder interfaces. DialectInterfaceCollection interfaces; + + /// An optional listener that is notified of all IR changes. + RewriterBase::Listener *listener = nullptr; }; } // namespace mlir diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -67,9 +67,7 @@ // OperationFolder //===----------------------------------------------------------------------===// -LogicalResult OperationFolder::tryToFold( - Operation *op, function_ref processGeneratedConstants, - function_ref preReplaceAction, bool *inPlaceUpdate) { +LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) { if (inPlaceUpdate) *inPlaceUpdate = false; @@ -86,27 +84,26 @@ // Try to fold the operation. 
SmallVector results; - OpBuilder builder(op); - if (failed(tryToFold(builder, op, results, processGeneratedConstants))) + OpBuilder builder(op, listener); + if (failed(tryToFold(builder, op, results))) return failure(); // Check to see if the operation was just updated in place. if (results.empty()) { if (inPlaceUpdate) *inPlaceUpdate = true; + if (listener) + listener->notifyOperationModified(op); return success(); } - // Constant folding succeeded. We will start replacing this op's uses and - // erase this op. Invoke the callback provided by the caller to perform any - // pre-replacement action. - if (preReplaceAction) - preReplaceAction(op); - - // Replace all of the result values and erase the operation. + // Constant folding succeeded. Replace all of the result values and erase the + // operation. + if (listener) + listener->notifyOperationReplaced(op, results); for (unsigned i = 0, e = results.size(); i != e; ++i) op->getResult(i).replaceAllUsesWith(results[i]); - op->erase(); + eraseOp(op); return success(); } @@ -144,8 +141,10 @@ // If there is an existing constant, replace `op`. if (folderConstOp) { + if (listener) + listener->notifyOperationReplaced(op, folderConstOp->getResults()); op->replaceAllUsesWith(folderConstOp); - op->erase(); + eraseOp(op); return false; } @@ -163,6 +162,13 @@ return true; } +void OperationFolder::eraseOp(Operation *op) { + notifyRemoval(op); + if (listener) + listener->notifyOperationRemoved(op); + op->erase(); +} + /// Notifies that the given constant `op` should be remove from this /// OperationFolder's internal bookkeeping. void OperationFolder::notifyRemoval(Operation *op) { @@ -221,9 +227,8 @@ /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. 
-LogicalResult OperationFolder::tryToFold( - OpBuilder &builder, Operation *op, SmallVectorImpl &results, - function_ref processGeneratedConstants) { +LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op, + SmallVectorImpl &results) { SmallVector operandConstants; // If this is a commutative operation, move constants to be trailing operands. @@ -252,16 +257,15 @@ // fold. SmallVector foldResults; if (failed(op->fold(operandConstants, foldResults)) || - failed(processFoldResults(builder, op, results, foldResults, - processGeneratedConstants))) + failed(processFoldResults(builder, op, results, foldResults))) return success(updatedOpOperands); return success(); } -LogicalResult OperationFolder::processFoldResults( - OpBuilder &builder, Operation *op, SmallVectorImpl &results, - ArrayRef foldResults, - function_ref processGeneratedConstants) { +LogicalResult +OperationFolder::processFoldResults(OpBuilder &builder, Operation *op, + SmallVectorImpl &results, + ArrayRef foldResults) { // Check to see if the operation was just updated in place. if (foldResults.empty()) return success(); @@ -312,20 +316,13 @@ // If materialization fails, cleanup any operations generated for the // previous results and return failure. for (Operation &op : llvm::make_early_inc_range( - llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { - notifyRemoval(&op); - op.erase(); - } + llvm::make_range(entry.begin(), builder.getInsertionPoint()))) + eraseOp(&op); + results.clear(); return failure(); } - // Process any newly generated operations. - if (processGeneratedConstants) { - for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) - processGeneratedConstants(&*i); - } - return success(); } @@ -358,7 +355,7 @@ // If an existing operation in the new dialect already exists, delete the // materialized operation in favor of the existing one. 
if (auto *existingOp = uniquedConstants.lookup(newKey)) { - constOp->erase(); + eraseOp(constOp); referencedDialects[existingOp].push_back(dialect); return constOp = existingOp; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -127,7 +127,8 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) { + : PatternRewriter(ctx), folder(ctx, this), config(config), + matcher(patterns) { worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. @@ -156,9 +157,6 @@ }; #endif - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands; - bool changed = false; int64_t numRewrites = 0; while (!worklist.empty() && @@ -197,34 +195,12 @@ continue; } - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (auto result : op->getResults()) - for (auto *userOp : result.getUsers()) - addToWorklist(userOp); - - notifyOperationRemoved(op); - }; - - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; - // Try to fold this op.
- bool inPlaceUpdate; - if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, - &inPlaceUpdate)))) { + if (succeeded(folder.tryToFold(op))) { LLVM_DEBUG(logResultWithLine("success", "operation was folded")); changed = true; - if (!inPlaceUpdate) - continue; + continue; } // Try to match one of the patterns. The rewriter is automatically @@ -465,7 +441,7 @@ // Add all nested operations to the worklist in preorder. region.walk([&](Operation *op) { if (!insertKnownConstant(op)) { - worklist.push_back(op); + addToWorklist(op); return WalkResult::advance(); } return WalkResult::skip(); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1116,7 +1116,8 @@ } OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { - if (adaptor.getOp()) { + if (adaptor.getOp() && !(*this)->hasAttr("attr")) { + // The folder adds "attr" if not present. (*this)->setAttr("attr", adaptor.getOp()); return getResult(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1280,7 +1280,7 @@ } def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> { - let arguments = (ins I32:$op, I32Attr:$attr); + let arguments = (ins I32:$op, OptionalAttr:$attr); let results = (outs I32); let hasFolder = 1; } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -93,8 +93,7 @@ // (unchanged) operation result.
OperationFolder folder(op->getContext()); Value result = folder.create( - rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), - rewriter.getI32IntegerAttr(0)); + rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); assert(result); rewriter.replaceOp(op, result); return success(); diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp --- a/mlir/test/lib/Transforms/TestConstantFold.cpp +++ b/mlir/test/lib/Transforms/TestConstantFold.cpp @@ -13,8 +13,8 @@ namespace { /// Simple constant folding pass. -struct TestConstantFold - : public PassWrapper> { +struct TestConstantFold : public PassWrapper>, + public RewriterBase::Listener { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold) StringRef getArgument() const final { return "test-constant-fold"; } @@ -26,17 +26,22 @@ void foldOperation(Operation *op, OperationFolder &helper); void runOnOperation() override; + + void notifyOperationInserted(Operation *op) override { + existingConstants.push_back(op); + } + void notifyOperationRemoved(Operation *op) override { + auto it = llvm::find(existingConstants, op); + if (it != existingConstants.end()) + existingConstants.erase(it); + } }; } // namespace void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) { - auto processGeneratedConstants = [this](Operation *op) { - existingConstants.push_back(op); - }; - // Attempt to fold the specified operation, including handling unused or // duplicated constants. - (void)helper.tryToFold(op, processGeneratedConstants); + (void)helper.tryToFold(op); } void TestConstantFold::runOnOperation() { @@ -50,7 +55,7 @@ // folding are at the beginning. This creates somewhat of a linear ordering to // the newly generated constants that matches the operation order and improves // the readability of test cases. 
- OperationFolder helper(&getContext()); + OperationFolder helper(&getContext(), /*listener=*/this); for (Operation *op : llvm::reverse(ops)) foldOperation(op, helper);