diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -56,11 +56,13 @@
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
   /// before it is replaced. 'processGeneratedConstants' is invoked for any new
   /// operations generated when folding. If the op was completely folded it is
-  /// erased.
+  /// erased. If `inPlaceUpdate` is non-null, `*inPlaceUpdate` is set to true
+  /// if the op was only updated in place and to false otherwise.
   LogicalResult
   tryToFold(Operation *op,
             function_ref<void(Operation *)> processGeneratedConstants = nullptr,
-            function_ref<void(Operation *)> preReplaceAction = nullptr);
+            function_ref<void(Operation *)> preReplaceAction = nullptr,
+            bool *inPlaceUpdate = nullptr);
 
   /// Notifies that the given constant `op` should be remove from this
   /// OperationFolder's internal bookkeeping.
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -74,7 +74,10 @@
 
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
-    function_ref<void(Operation *)> preReplaceAction) {
+    function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
+  if (inPlaceUpdate)
+    *inPlaceUpdate = false;
+
   // If this is a unique'd constant, return failure as we know that it has
   // already been folded.
   if (referencedDialects.count(op))
@@ -87,8 +90,11 @@
     return failure();
 
   // Check to see if the operation was just updated in place.
-  if (results.empty())
+  if (results.empty()) {
+    if (inPlaceUpdate)
+      *inPlaceUpdate = true;
     return success();
+  }
 
   // Constant folding succeeded. We will start replacing this op's uses and
   // erase this op. Invoke the callback provided by the caller to perform any
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -107,7 +107,8 @@
   // be re-added to the worklist. This function should be called when an
   // operation is modified or removed, as it may trigger further
   // simplifications.
-  template <typename Operands> void addToWorklist(Operands &&operands) {
+  template <typename Operands>
+  void addToWorklist(Operands &&operands) {
     for (Value operand : operands) {
       // If the use count of this operand is now < 2, we re-add the defining
       // operation to the worklist.
@@ -136,7 +137,8 @@
 };
 } // end anonymous namespace
 
-/// Perform the rewrites while folding and erasing any dead ops.
+/// Performs the rewrites while folding and erasing any dead ops. Returns true
+/// if the rewrite converges within `maxIterations` iterations.
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
                                           int maxIterations) {
   // Add the given operation to the worklist.
@@ -186,9 +188,14 @@
     };
 
     // Try to fold this op.
-    if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
+    bool inPlaceUpdate = false;
+    if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+                                   &inPlaceUpdate))) {
       changed = true;
-      continue;
+      // An op that was only updated in place still exists, so fall through
+      // to the rest of this iteration instead of skipping it.
+      if (!inPlaceUpdate)
+        continue;
     }
 
     // Make sure that any new operations are inserted at this point.