diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -33,11 +33,6 @@ public: OperationFolder(MLIRContext *ctx) : interfaces(ctx) {} - /// Scan the specified region for constants that can be used in folding, - /// moving them to the entry block (or any custom insertion location specified - /// by shouldMaterializeInto), and add them to our known-constants table. - void processExistingConstants(Region ®ion); - /// Tries to perform folding on the given `op`, including unifying /// deduplicated constants. If successful, replaces `op`'s uses with /// folded results, and returns success. `preReplaceAction` is invoked on `op` diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -18,6 +18,24 @@ namespace mlir { +/// This struct allows control over how the GreedyPatternRewriteDriver works. +struct GreedyRewriteConfig { + /// This specifies the order of initial traversal that populates the rewriters + /// worklist. When set to true, it walks the operations top-down, which is + /// generally more efficient in compile time. When set to false, its initial + /// traversal of the region tree is bottom up on each block, which may match + /// larger patterns when given an ambiguous pattern set. + bool useTopDownTraversal = false; + + // Perform control flow optimizations to the region tree after applying all + // patterns. + bool enableRegionSimplification = true; + + /// This specifies the maximum number of times the rewriter will iterate + /// between applying patterns and simplifying regions. + unsigned maxIterations = 10; +}; + //===----------------------------------------------------------------------===// // applyPatternsGreedily //===----------------------------------------------------------------------===// @@ -37,33 +55,17 @@ /// These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -/// You may choose the order of initial traversal with the `useTopDownTraversal` -/// boolean. When set to true, it walks the operations top-down, which is -/// generally more efficient in compile time. When set to false, its initial -/// traversal of the region tree is post-order, which may match larger patterns -/// when given an ambiguous pattern set. -LogicalResult -applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns, - bool useTopDownTraversal = false); - -/// Rewrite the regions of the specified operation, with a user-provided limit -/// on iterations to attempt before reaching convergence. +/// You may configure several aspects of this with GreedyRewriteConfig. LogicalResult applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternSet &patterns, - unsigned maxIterations, bool useTopDownTraversal = false); + MutableArrayRef regions, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig()); /// Rewrite the given regions, which must be isolated from above. -LogicalResult -applyPatternsAndFoldGreedily(MutableArrayRef regions, - const FrozenRewritePatternSet &patterns, - bool useTopDownTraversal = false); - -/// Rewrite the given regions, with a user-provided limit on iterations to -/// attempt before reaching convergence. -LogicalResult applyPatternsAndFoldGreedily( - MutableArrayRef regions, const FrozenRewritePatternSet &patterns, - unsigned maxIterations, bool useTopDownTraversal = false); +inline LogicalResult applyPatternsAndFoldGreedily( + Operation *op, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig()) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config); +} /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -31,10 +31,10 @@ return success(); } void runOnOperation() override { - (void)applyPatternsAndFoldGreedily( - getOperation()->getRegions(), patterns, - /*maxIterations=*/10, /*useTopDownTraversal=*/ - topDownProcessingEnabled); + GreedyRewriteConfig config; + config.useTopDownTraversal = topDownProcessingEnabled; + (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns, + config); } FrozenRewritePatternSet patterns; diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -84,85 +84,6 @@ // OperationFolder //===----------------------------------------------------------------------===// -/// Scan the specified region for constants that can be used in folding, -/// moving them to the entry block (or any custom insertion location specified -/// by shouldMaterializeInto), and add them to our known-constants table. -void OperationFolder::processExistingConstants(Region ®ion) { - if (region.empty()) - return; - - // March the constant insertion point forward, moving all constants to the - // top of the block, but keeping them in their order of discovery. - Region *insertRegion = getInsertionRegion(interfaces, ®ion.front()); - auto &uniquedConstants = foldScopes[insertRegion]; - - Block &insertBlock = insertRegion->front(); - Block::iterator constantIterator = insertBlock.begin(); - - // Process each constant that we discover in this region. - auto processConstant = [&](Operation *op, Attribute value) { - assert(op->getNumResults() == 1 && "constants have one result"); - // Check to see if we already have an instance of this constant. - Operation *&constOp = uniquedConstants[std::make_tuple( - op->getDialect(), value, op->getResult(0).getType())]; - - // If we already have an instance of this constant, CSE/delete this one as - // we go. - if (constOp) { - if (constantIterator == Block::iterator(op)) - ++constantIterator; // Don't invalidate our iterator when scanning. - op->getResult(0).replaceAllUsesWith(constOp->getResult(0)); - op->erase(); - return; - } - - // Otherwise, remember that we have this constant. - constOp = op; - referencedDialects[op].push_back(op->getDialect()); - - // If the constant isn't already at the insertion point then move it up. - if (constantIterator != Block::iterator(op)) - op->moveBefore(&insertBlock, constantIterator); - else - ++constantIterator; // It was pointing at the constant. - }; - - // Collect all the constants for this region of isolation or insertion (as - // specified by the shouldMaterializeInto hook). Collect any subregions of - // isolation/constant insertion for subsequent processing. - SmallVector insertionSubregionOps; - region.walk([&](Operation *op) { - // If this is a constant, process it. - Attribute value; - if (matchPattern(op, m_Constant(&value))) { - processConstant(op, value); - // We may have deleted the operation, don't check it for regions. - return WalkResult::skip(); - } - - // If the operation has regions and is isolated, don't recurse into it. - if (op->getNumRegions() != 0) { - auto hasDifferentInsertRegion = [&](Region ®ion) { - return !region.empty() && - getInsertionRegion(interfaces, ®ion.front()) != insertRegion; - }; - if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) { - insertionSubregionOps.push_back(op); - return WalkResult::skip(); - } - } - - // Otherwise keep going. - return WalkResult::advance(); - }); - - // Process regions in any isolated ops separately. - for (Operation *subregionOps : insertionSubregionOps) { - for (Region ®ion : subregionOps->getRegions()) - processExistingConstants(region); - } -} - LogicalResult OperationFolder::tryToFold( Operation *op, function_ref processGeneratedConstants, function_ref preReplaceAction, bool *inPlaceUpdate) { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -24,9 +24,6 @@ #define DEBUG_TYPE "pattern-matcher" -/// The max number of iterations scanning for pattern match. -static unsigned maxPatternMatchIterations = 10; - //===----------------------------------------------------------------------===// // GreedyPatternRewriteDriver //===----------------------------------------------------------------------===// @@ -38,16 +35,15 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - bool useTopDownTraversal) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), - useTopDownTraversal(useTopDownTraversal) { + const GreedyRewriteConfig &config) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); } - bool simplify(MutableArrayRef regions, int maxIterations); + bool simplify(MutableArrayRef regions); void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. @@ -137,40 +133,30 @@ /// Non-pattern based folder for operations. OperationFolder folder; - /// Whether to use a top-down or bottom-up traversal to seed the initial - /// worklist. - bool useTopDownTraversal; + /// Configuration information for how to simplify. + GreedyRewriteConfig config; }; } // end anonymous namespace /// Performs the rewrites while folding and erasing any dead ops. Returns true /// if the rewrite converges in `maxIterations`. -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, - int maxIterations) { - // For maximum compatibility with existing passes, do not process existing - // constants unless we're performing a top-down traversal. - // TODO: This is just for compatibility with older MLIR, remove this. - if (useTopDownTraversal) { - // Perform a prepass over the IR to discover constants. - for (auto ®ion : regions) - folder.processExistingConstants(region); - } - +bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { bool changed = false; - int iteration = 0; + unsigned iteration = 0; do { worklist.clear(); worklistMap.clear(); - // Add all nested operations to the worklist in preorder. - for (auto ®ion : regions) - if (useTopDownTraversal) + if (!config.useTopDownTraversal) { + // Add operations to the worklist in postorder. + for (auto ®ion : regions) + region.walk([this](Operation *op) { addToWorklist(op); }); + } else { + // Add all nested operations to the worklist in preorder. + for (auto ®ion : regions) region.walk( [this](Operation *op) { worklist.push_back(op); }); - else - region.walk([this](Operation *op) { addToWorklist(op); }); - if (useTopDownTraversal) { // Reverse the list so our pop-back loop processes them in-order. std::reverse(worklist.begin(), worklist.end()); // Remember the reverse index. @@ -234,8 +220,9 @@ // After applying patterns, make sure that the CFG of each of the regions // is kept up to date. - changed |= succeeded(simplifyRegions(*this, regions)); - } while (changed && ++iteration < maxIterations); + if (config.enableRegionSimplification) + changed |= succeeded(simplifyRegions(*this, regions)); + } while (changed && ++iteration < config.maxIterations); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return !changed; @@ -248,29 +235,9 @@ /// top-level operation itself. /// LogicalResult -mlir::applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns, - bool useTopDownTraversal) { - return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations, - useTopDownTraversal); -} -LogicalResult mlir::applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternSet &patterns, - unsigned maxIterations, bool useTopDownTraversal) { - return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations, - useTopDownTraversal); -} -/// Rewrite the given regions, which must be isolated from above. -LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, const FrozenRewritePatternSet &patterns, - bool useTopDownTraversal) { - return applyPatternsAndFoldGreedily( - regions, patterns, maxPatternMatchIterations, useTopDownTraversal); -} -LogicalResult mlir::applyPatternsAndFoldGreedily( - MutableArrayRef regions, const FrozenRewritePatternSet &patterns, - unsigned maxIterations, bool useTopDownTraversal) { + GreedyRewriteConfig config) { if (regions.empty()) return success(); @@ -285,12 +252,11 @@ "patterns can only be applied to operations IsolatedFromAbove"); // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, - useTopDownTraversal); - bool converged = driver.simplify(regions, maxIterations); + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); + bool converged = driver.simplify(regions); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << maxIterations << " times\n"; + << config.maxIterations << " times\n"; }); return success(converged); } @@ -391,15 +357,16 @@ LogicalResult mlir::applyOpPatternsAndFold( Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { // Start the pattern driver. + GreedyRewriteConfig config; OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; LogicalResult converged = - driver.simplifyLocally(op, maxPatternMatchIterations, opErased); + driver.simplifyLocally(op, config.maxIterations, opErased); if (erased) *erased = opErased; LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " - << maxPatternMatchIterations << " times"; + << config.maxIterations << " times"; }); return converged; }