diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -34,59 +34,44 @@ namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly -/// applies the locally optimal patterns in a roughly "bottom up" way. +/// applies the locally optimal patterns. +/// +/// This abstract class manages the worklist and contains helper methods for +/// rewriting ops on the worklist. Derived classes specify how ops are added +/// to the worklist in the beginning. class GreedyPatternRewriteDriver : public PatternRewriter { -public: +protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config); - /// Simplify the ops within the given region. - bool simplify(Region &region) &&; + /// Add the given operation to the worklist. + void addSingleOpToWorklist(Operation *op); /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); - /// Pop the next operation from the worklist. - Operation *popFromWorklist(); - - /// If the specified operation is in the worklist, remove it. - void removeFromWorklist(Operation *op); - - /// Notifies the driver that the specified operation may have been modified - /// in-place. + /// Notify the driver that the specified operation may have been modified + /// in-place. The operation is added to the worklist. void finalizeRootUpdate(Operation *op) override; -protected: - /// Add the given operation to the worklist. - void addSingleOpToWorklist(Operation *op); - - // Implement the hook for inserting operations, and make sure that newly - // inserted ops are added to the worklist for processing. 
+ /// Notify the driver that the specified operation was inserted.
Update the + /// worklist as needed: The operation is enqueued depending on scope and + /// strict mode. void notifyOperationInserted(Operation *op) override; - // Look over the provided operands for any defining operations that should - // be re-added to the worklist. This function should be called when an - // operation is modified or removed, as it may trigger further - // simplifications. - void addOperandsToWorklist(ValueRange operands); - - // If an operation is about to be removed, make sure it is not in our - // worklist anymore because we'd get dangling references to it. + /// Notify the driver that the specified operation was removed. Update the + /// worklist as needed: The operation and its children are removed from the + /// worklist. void notifyOperationRemoved(Operation *op) override; - // When the root of a pattern is about to be replaced, it can trigger - // simplifications to its users - make sure to add them to the worklist - // before the root is changed. + /// Notify the driver that the specified operation was replaced. Update the + /// worklist as needed: New users are enqueued. void notifyRootReplaced(Operation *op, ValueRange replacement) override; - /// PatternRewriter hook for notifying match failure reasons. - LogicalResult - notifyMatchFailure(Location loc, - function_ref<void(Diagnostic &)> reasonCallback) override; - - /// The low-level pattern applicator. - PatternApplicator matcher; + /// Process ops until the worklist is empty or `config.maxNumRewrites` is + /// reached. Return `true` if any IR was changed. + bool processWorklist(); /// The worklist for this transformation keeps track of the operations that /// need to be revisited, plus their index in the worklist. This allows us to @@ -98,7 +83,6 @@ /// Non-pattern based folder for operations. OperationFolder folder; -protected: /// Configuration information for how to simplify.
const GreedyRewriteConfig config; @@ -109,17 +93,37 @@ llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; private: + /// Look over the provided operands for any defining operations that should + /// be re-added to the worklist. This function should be called when an + /// operation is modified or removed, as it may trigger further + /// simplifications. + void addOperandsToWorklist(ValueRange operands); + + /// Pop the next operation from the worklist. + Operation *popFromWorklist(); + + /// For debugging only: Notify the driver of a pattern match failure. + LogicalResult + notifyMatchFailure(Location loc, + function_ref<void(Diagnostic &)> reasonCallback) override; + + /// If the specified operation is in the worklist, remove it. + void removeFromWorklist(Operation *op); + #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; #endif + + /// The low-level pattern applicator. + PatternApplicator matcher; }; } // namespace GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) { assert(config.scope && "scope is not specified"); worklist.reserve(64); @@ -127,7 +131,7 @@ matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(Region &region) && { +bool GreedyPatternRewriteDriver::processWorklist() { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -146,130 +150,80 @@ }; #endif - auto insertKnownConstant = [&](Operation *op) { - // Check for existing constants when populating the worklist. This avoids - // accidentally reversing the constant order during processing.
- Attribute constValue; - if (matchPattern(op, m_Constant(&constValue))) - if (!folder.insertKnownConstant(op, constValue)) - return true; - return false; - }; - - // Populate strict mode ops. - if (config.strictMode != GreedyRewriteStrictness::AnyOp) { - strictModeFilteredOps.clear(); - region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); - } + // These are scratch vectors used in the folding loop below. + SmallVector originalOperands; bool changed = false; - int64_t iteration = 0; - do { - // Check if the iteration limit was reached. - if (iteration++ >= config.maxIterations && - config.maxIterations != GreedyRewriteConfig::kNoLimit) - break; + int64_t numRewrites = 0; + while (!worklist.empty() && + (numRewrites < config.maxNumRewrites || + config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { + auto *op = popFromWorklist(); - worklist.clear(); - worklistMap.clear(); + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; - if (!config.useTopDownTraversal) { - // Add operations to the worklist in postorder. - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) - addToWorklist(op); - }); - } else { - // Add all nested operations to the worklist in preorder. - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { - worklist.push_back(op); - return WalkResult::advance(); - } - return WalkResult::skip(); - }); + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << logLineComment; + logger.startLine() << "Processing operation : '" << op->getName() << "'(" + << op << ") {\n"; + logger.indent(); + + // If the operation has no regions, just print it here. + if (op->getNumRegions() == 0) { + op->print( + logger.startLine(), + OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); + logger.getOStream() << "\n\n"; + } + }); - // Reverse the list so our pop-back loop processes them in-order. 
- std::reverse(worklist.begin(), worklist.end()); - // Remember the reverse index. - for (size_t i = 0, e = worklist.size(); i != e; ++i) - worklistMap[worklist[i]] = i; + // If the operation is trivially dead - remove it. + if (isOpTriviallyDead(op)) { + notifyOperationRemoved(op); + op->erase(); + changed = true; + + LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); + continue; } - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; + // Collects all the operands and result uses of the given `op` into work + // list. Also remove `op` and nested ops from worklist. + originalOperands.assign(op->operand_begin(), op->operand_end()); + auto preReplaceAction = [&](Operation *op) { + // Add the operands to the worklist for visitation. + addOperandsToWorklist(originalOperands); - changed = false; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - auto *op = popFromWorklist(); + // Add all the users of the result to the worklist so we make sure + // to revisit them. + for (auto result : op->getResults()) + for (auto *userOp : result.getUsers()) + addToWorklist(userOp); - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; + notifyOperationRemoved(op); + }; - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << logLineComment; - logger.startLine() << "Processing operation : '" << op->getName() - << "'(" << op << ") {\n"; - logger.indent(); - - // If the operation has no regions, just print it here. - if (op->getNumRegions() == 0) { - op->print( - logger.startLine(), - OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); - logger.getOStream() << "\n\n"; - } - }); + // Add the given operation to the worklist. 
+ auto collectOps = [this](Operation *op) { addToWorklist(op); }; - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - changed = true; + // Try to fold this op. + bool inPlaceUpdate; + if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, + &inPlaceUpdate)))) { + LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); + changed = true; + if (!inPlaceUpdate) continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (auto result : op->getResults()) - for (auto *userOp : result.getUsers()) - addToWorklist(userOp); - - notifyOperationRemoved(op); - }; - - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; - - // Try to fold this op. - bool inPlaceUpdate; - if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, - &inPlaceUpdate)))) { - LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - - changed = true; - if (!inPlaceUpdate) - continue; - } + } - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do + // here. 
#ifndef NDEBUG auto canApply = [&](const Pattern &pattern) { LLVM_DEBUG({ @@ -304,16 +258,9 @@ changed = true; ++numRewrites; } - } - - // After applying patterns, make sure that the CFG of each of the regions - // is kept up to date. - if (config.enableRegionSimplification) - changed |= succeeded(simplifyRegions(*this, region)); - } while (changed); + } - // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return !changed; + return changed; } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { @@ -321,12 +268,12 @@ SmallVector<Operation *, 8> ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (config.scope == region) { - // All gathered ops are in fact ancestors. - for (Operation *op : ancestors) - addSingleOpToWorklist(op); - break; - } + if (config.scope == region) { + // All gathered ops are in fact ancestors. + for (Operation *op : ancestors) + addSingleOpToWorklist(op); + break; + } op = region->getParentOp(); if (!op) break; @@ -434,12 +381,96 @@ return failure(); } -/// Rewrite the regions of the specified operation, which must be isolated from -/// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return success if no more patterns can be matched -/// in the result operation regions. Note: This does not apply patterns to the -/// top-level operation itself. -/// +//===----------------------------------------------------------------------===// +// RegionPatternRewriteDriver +//===----------------------------------------------------------------------===// + +namespace { +/// This driver simplifies all ops in a region. +class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { +public: + explicit RegionPatternRewriteDriver(MLIRContext *ctx, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + Region &region); + + /// Simplify ops inside `region` and simplify the region itself.
Return + /// success if the transformation converged. + LogicalResult simplify() &&; + +private: + /// The region that is simplified. + Region &region; +}; +} // namespace + +RegionPatternRewriteDriver::RegionPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, Region &region) + : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { + // Populate strict mode ops. + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } +} + +LogicalResult RegionPatternRewriteDriver::simplify() && { + auto insertKnownConstant = [&](Operation *op) { + // Check for existing constants when populating the worklist. This avoids + // accidentally reversing the constant order during processing. + Attribute constValue; + if (matchPattern(op, m_Constant(&constValue))) + if (!folder.insertKnownConstant(op, constValue)) + return true; + return false; + }; + + bool changed = false; + int64_t iteration = 0; + do { + // Check if the iteration limit was reached. + if (iteration++ >= config.maxIterations && + config.maxIterations != GreedyRewriteConfig::kNoLimit) + break; + + worklist.clear(); + worklistMap.clear(); + + if (!config.useTopDownTraversal) { + // Add operations to the worklist in postorder. + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) + addToWorklist(op); + }); + } else { + // Add all nested operations to the worklist in preorder. + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) { + worklist.push_back(op); + return WalkResult::advance(); + } + return WalkResult::skip(); + }); + + // Reverse the list so our pop-back loop processes them in-order. + std::reverse(worklist.begin(), worklist.end()); + // Remember the reverse index.
+ for (size_t i = 0, e = worklist.size(); i != e; ++i) + worklistMap[worklist[i]] = i; + } + + changed = processWorklist(); + + // After applying patterns, make sure that the CFG of each of the regions + // is kept up to date. + if (config.enableRegionSimplification) + changed |= succeeded(simplifyRegions(*this, region)); + } while (changed); + + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. + return success(!changed); +} + LogicalResult mlir::applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, @@ -455,13 +486,14 @@ config.scope = &region; // Start the pattern driver. - GreedyPatternRewriteDriver driver(region.getContext(), patterns, config); - bool converged = std::move(driver).simplify(region); - LLVM_DEBUG(if (!converged) { + RegionPatternRewriteDriver driver(region.getContext(), patterns, config, + region); + LogicalResult converged = std::move(driver).simplify(); + LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; }); - return success(converged); + return converged; } //===----------------------------------------------------------------------===// @@ -469,32 +501,16 @@ //===----------------------------------------------------------------------===// namespace { - -/// This is a specialized GreedyPatternRewriteDriver to apply patterns and -/// perform folding for a supplied set of ops. It repeatedly simplifies while -/// restricting the rewrites to only the provided set of ops or optionally -/// to those directly affected by it (result users or operand providers). Parent -/// ops are not considered. +/// This driver simplifies a list of ops.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, - llvm::SmallDenseSet *survivingOps = nullptr) - : GreedyPatternRewriteDriver(ctx, patterns, config), - survivingOps(survivingOps) {} - - /// Performs the specified rewrites on `ops` while also trying to fold these - /// ops. `strictMode` controls which other ops are simplified. Only ops - /// within the given scope region are added to the worklist. If no scope is - /// specified, it assumed to be closest common region of all `ops`. - /// - /// Note that ops in `ops` could be erased as a result of folding, becoming - /// dead, or via pattern rewrites. The return value indicates convergence. - /// - /// All erased ops are stored in `erased`. - LogicalResult simplifyLocally(ArrayRef op, - bool *changed = nullptr) &&; + const GreedyRewriteConfig &config, ArrayRef ops, + llvm::SmallDenseSet *survivingOps = nullptr); + + /// Simplify `ops`. Return `success` if the transformation converged. + LogicalResult simplify(ArrayRef ops, bool *changed = nullptr) &&; private: void notifyOperationRemoved(Operation *op) override { @@ -508,98 +524,33 @@ /// of ops. 
llvm::SmallDenseSet *const survivingOps = nullptr; }; - } // namespace -LogicalResult -MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, - bool *changed) && { +MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, ArrayRef ops, + llvm::SmallDenseSet *survivingOps) + : GreedyPatternRewriteDriver(ctx, patterns, config), + survivingOps(survivingOps) { + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + strictModeFilteredOps.insert(ops.begin(), ops.end()); + if (survivingOps) { survivingOps->clear(); survivingOps->insert(ops.begin(), ops.end()); } +} - if (config.strictMode != GreedyRewriteStrictness::AnyOp) { - strictModeFilteredOps.clear(); - strictModeFilteredOps.insert(ops.begin(), ops.end()); - } - - if (changed) - *changed = false; - worklist.clear(); - worklistMap.clear(); +LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef ops, + bool *changed) && { + // Populate the initial worklist. for (Operation *op : ops) addSingleOpToWorklist(op); - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - Operation *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; - - assert((config.strictMode == GreedyRewriteStrictness::AnyOp || - strictModeFilteredOps.contains(op)) && - "unexpected op was inserted under strict mode"); - - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - if (changed) - *changed = true; - continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. 
- originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (Value result : op->getResults()) { - for (Operation *userOp : result.getUsers()) - addToWorklist(userOp); - } - - notifyOperationRemoved(op); - }; - - // Add the given operation generated by the folder to the worklist. - auto processGeneratedConstants = [this](Operation *op) { - notifyOperationInserted(op); - }; - - // Try to fold this op. - bool inPlaceUpdate; - if (succeeded(folder.tryToFold(op, processGeneratedConstants, - preReplaceAction, &inPlaceUpdate))) { - if (changed) - *changed = true; - if (!inPlaceUpdate) { - // Op has been erased. - continue; - } - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. - if (succeeded(matcher.matchAndRewrite(op, *this))) { - if (changed) - *changed = true; - ++numRewrites; - } - } + // Process ops on the worklist. + bool result = processWorklist(); + if (changed) + *changed = result; return success(worklist.empty()); } @@ -658,8 +609,9 @@ // Start the pattern driver. llvm::SmallDenseSet surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - config, allErased ? &surviving : nullptr); - LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); + config, ops, + allErased ? &surviving : nullptr); + LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) {