[mlir][Transforms] GreedyPatternRewriteDriver: share one worklist loop

This patch moves the region-specific driver logic out of
GreedyPatternRewriteDriver and into a new RegionPatternRewriteDriver. That
logic covers initial worklist population (pre- or post-order walk),
strict-mode op collection, the maxIterations loop and simplifyRegions().
The common per-op loop becomes processWorklist(), which returns whether any
IR changed. MultiOpPatternRewriteDriver now reuses processWorklist()
instead of keeping its own copy of the loop. It receives its ops in the
constructor and is driven by simplify(bool *changed).

Review notes (assumptions to be confirmed by the author):
  * Folder-generated ops in the MultiOp path: before this patch,
    MultiOpPatternRewriteDriver passed them to notifyOperationInserted()
    (the removed `processGeneratedConstants` lambda). After it, they go
    through the shared `collectOps` lambda, i.e. addToWorklist(). Please
    confirm the two are equivalent, in particular under strict mode.
  * The MultiOp loop asserted that every popped op is in
    strictModeFilteredOps unless strictMode is AnyOp. This patch removes
    that assert and does not re-add it in processWorklist().
  * The MultiOp path now also executes the LLVM_DEBUG logging in
    processWorklist().
  * Comment nits in the patched code:
      "simplfies" appears twice and should be "simplifies".
      "inititally" should be "initially".
      The survivingOps comment still says the set is populated in
      `simplifyLocally`, but the constructor now populates it.
      The "All erased ops are stored in `erased`" comment refers to a
      parameter that simplify(bool *changed) does not have.
      RegionPatternRewriteDriver::simplify() is documented as working on
      "the specified regions", but the class holds a single Region.
      Its constructor declaration names that parameter `regions`, while
      the definition names it `region`.
  * This copy of the patch appears to be whitespace-collapsed and
    HTML-mangled. Newlines are lost, and template arguments (for example
    those of SmallVector, ArrayRef, SmallDenseSet and std::vector) have
    been stripped. In addition, the "&reg" of "&region" is rendered as "®".
    The patch will not apply or compile as-is; regenerate it from the
    source tree before landing.

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -34,15 +34,16 @@ namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly -/// applies the locally optimal patterns in a roughly "bottom up" way. +/// applies the locally optimal patterns. class GreedyPatternRewriteDriver : public PatternRewriter { -public: +protected: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config); - /// Simplify the ops within the given region. - bool simplify(Region ®ion) &&; + /// Process ops until the worklist is empty or `config.maxNumRewrites` is + /// reached. Return `true` if any IR was changed. + bool processWorklist(); /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); @@ -57,7 +58,6 @@ /// in-place. void finalizeRootUpdate(Operation *op) override; -protected: /// Add the given operation to the worklist. void addSingleOpToWorklist(Operation *op); @@ -98,7 +98,6 @@ /// Non-pattern based folder for operations. OperationFolder folder; -protected: /// Configuration information for how to simplify. const GreedyRewriteConfig config; @@ -127,7 +126,7 @@ matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(Region ®ion) && { +bool GreedyPatternRewriteDriver::processWorklist() { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -146,130 +145,80 @@ }; #endif - auto insertKnownConstant = [&](Operation *op) { - // Check for existing constants when populating the worklist. This avoids - // accidentally reversing the constant order during processing. 
- Attribute constValue; - if (matchPattern(op, m_Constant(&constValue))) - if (!folder.insertKnownConstant(op, constValue)) - return true; - return false; - }; - - // Populate strict mode ops. - if (config.strictMode != GreedyRewriteStrictness::AnyOp) { - strictModeFilteredOps.clear(); - region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); - } + // These are scratch vectors used in the folding loop below. + SmallVector originalOperands; bool changed = false; - int64_t iteration = 0; - do { - // Check if the iteration limit was reached. - if (iteration++ >= config.maxIterations && - config.maxIterations != GreedyRewriteConfig::kNoLimit) - break; + int64_t numRewrites = 0; + while (!worklist.empty() && + (numRewrites < config.maxNumRewrites || + config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { + auto *op = popFromWorklist(); - worklist.clear(); - worklistMap.clear(); + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; - if (!config.useTopDownTraversal) { - // Add operations to the worklist in postorder. - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) - addToWorklist(op); - }); - } else { - // Add all nested operations to the worklist in preorder. - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { - worklist.push_back(op); - return WalkResult::advance(); - } - return WalkResult::skip(); - }); + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << logLineComment; + logger.startLine() << "Processing operation : '" << op->getName() << "'(" + << op << ") {\n"; + logger.indent(); + + // If the operation has no regions, just print it here. + if (op->getNumRegions() == 0) { + op->print( + logger.startLine(), + OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); + logger.getOStream() << "\n\n"; + } + }); - // Reverse the list so our pop-back loop processes them in-order. 
- std::reverse(worklist.begin(), worklist.end()); - // Remember the reverse index. - for (size_t i = 0, e = worklist.size(); i != e; ++i) - worklistMap[worklist[i]] = i; + // If the operation is trivially dead - remove it. + if (isOpTriviallyDead(op)) { + notifyOperationRemoved(op); + op->erase(); + changed = true; + + LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); + continue; } - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; + // Collects all the operands and result uses of the given `op` into work + // list. Also remove `op` and nested ops from worklist. + originalOperands.assign(op->operand_begin(), op->operand_end()); + auto preReplaceAction = [&](Operation *op) { + // Add the operands to the worklist for visitation. + addOperandsToWorklist(originalOperands); - changed = false; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - auto *op = popFromWorklist(); + // Add all the users of the result to the worklist so we make sure + // to revisit them. + for (auto result : op->getResults()) + for (auto *userOp : result.getUsers()) + addToWorklist(userOp); - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; + notifyOperationRemoved(op); + }; - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << logLineComment; - logger.startLine() << "Processing operation : '" << op->getName() - << "'(" << op << ") {\n"; - logger.indent(); - - // If the operation has no regions, just print it here. - if (op->getNumRegions() == 0) { - op->print( - logger.startLine(), - OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); - logger.getOStream() << "\n\n"; - } - }); + // Add the given operation to the worklist. 
+ auto collectOps = [this](Operation *op) { addToWorklist(op); }; - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - changed = true; + // Try to fold this op. + bool inPlaceUpdate; + if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, + &inPlaceUpdate)))) { + LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); + changed = true; + if (!inPlaceUpdate) continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (auto result : op->getResults()) - for (auto *userOp : result.getUsers()) - addToWorklist(userOp); - - notifyOperationRemoved(op); - }; - - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; - - // Try to fold this op. - bool inPlaceUpdate; - if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, - &inPlaceUpdate)))) { - LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - - changed = true; - if (!inPlaceUpdate) - continue; - } + } - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do + // here. 
#ifndef NDEBUG auto canApply = [&](const Pattern &pattern) { LLVM_DEBUG({ @@ -304,16 +253,9 @@ changed = true; ++numRewrites; } - } - - // After applying patterns, make sure that the CFG of each of the regions - // is kept up to date. - if (config.enableRegionSimplification) - changed |= succeeded(simplifyRegions(*this, region)); - } while (changed); + } - // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return !changed; + return changed; } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { @@ -321,12 +263,12 @@ SmallVector ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (config.scope == region) { - // All gathered ops are in fact ancestors. - for (Operation *op : ancestors) - addSingleOpToWorklist(op); - break; - } + if (config.scope == region) { + // All gathered ops are in fact ancestors. + for (Operation *op : ancestors) + addSingleOpToWorklist(op); + break; + } op = region->getParentOp(); if (!op) break; @@ -434,6 +376,96 @@ return failure(); } +//===----------------------------------------------------------------------===// +// RegionPatternRewriteDriver +//===----------------------------------------------------------------------===// + +namespace { +/// This driver simplfies all ops in a region in a roughly "bottom up" way. +class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { +public: + explicit RegionPatternRewriteDriver(MLIRContext *ctx, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + Region ®ions); + + /// Simplify ops inside the specified regions and simplify the regions + /// themselves. Return success if the transformation converged. + LogicalResult simplify() &&; + +private: + /// The region that is simplified. 
+ Region ®ion; +}; +} // namespace + +RegionPatternRewriteDriver::RegionPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, Region ®ion) + : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { + // Populate strict mode ops. + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } +} + +LogicalResult RegionPatternRewriteDriver::simplify() && { + auto insertKnownConstant = [&](Operation *op) { + // Check for existing constants when populating the worklist. This avoids + // accidentally reversing the constant order during processing. + Attribute constValue; + if (matchPattern(op, m_Constant(&constValue))) + if (!folder.insertKnownConstant(op, constValue)) + return true; + return false; + }; + + bool changed = false; + int64_t iteration = 0; + do { + // Check if the iteration limit was reached. + if (iteration++ >= config.maxIterations && + config.maxIterations != GreedyRewriteConfig::kNoLimit) + break; + + worklist.clear(); + worklistMap.clear(); + + if (!config.useTopDownTraversal) { + // Add operations to the worklist in postorder. + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) + addToWorklist(op); + }); + } else { + // Add all nested operations to the worklist in preorder. + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) { + worklist.push_back(op); + return WalkResult::advance(); + } + return WalkResult::skip(); + }); + + // Reverse the list so our pop-back loop processes them in-order. + std::reverse(worklist.begin(), worklist.end()); + // Remember the reverse index. + for (size_t i = 0, e = worklist.size(); i != e; ++i) + worklistMap[worklist[i]] = i; + } + + changed = processWorklist(); + + // After applying patterns, make sure that the CFG of each of the regions + // is kept up to date. 
+ if (config.enableRegionSimplification) + changed |= succeeded(simplifyRegions(*this, region)); + } while (changed); + + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. + return success(!changed); +} + /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy /// work-list driven manner. Return success if no more patterns can be matched @@ -455,13 +487,14 @@ config.scope = ®ion; // Start the pattern driver. - GreedyPatternRewriteDriver driver(region.getContext(), patterns, config); - bool converged = std::move(driver).simplify(region); - LLVM_DEBUG(if (!converged) { + RegionPatternRewriteDriver driver(region.getContext(), patterns, config, + region); + LogicalResult converged = std::move(driver).simplify(); + LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; }); - return success(converged); + return converged; } //===----------------------------------------------------------------------===// @@ -469,20 +502,13 @@ //===----------------------------------------------------------------------===// namespace { - -/// This is a specialized GreedyPatternRewriteDriver to apply patterns and -/// perform folding for a supplied set of ops. It repeatedly simplifies while -/// restricting the rewrites to only the provided set of ops or optionally -/// to those directly affected by it (result users or operand providers). Parent -/// ops are not considered. +/// This driver simplfies a list of ops. 
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, - llvm::SmallDenseSet *survivingOps = nullptr) - : GreedyPatternRewriteDriver(ctx, patterns, config), - survivingOps(survivingOps) {} + const GreedyRewriteConfig &config, ArrayRef ops, + llvm::SmallDenseSet *survivingOps = nullptr); /// Performs the specified rewrites on `ops` while also trying to fold these /// ops. `strictMode` controls which other ops are simplified. Only ops @@ -493,8 +519,7 @@ /// dead, or via pattern rewrites. The return value indicates convergence. /// /// All erased ops are stored in `erased`. - LogicalResult simplifyLocally(ArrayRef op, - bool *changed = nullptr) &&; + LogicalResult simplify(bool *changed = nullptr) &&; private: void notifyOperationRemoved(Operation *op) override { @@ -503,103 +528,39 @@ survivingOps->erase(op); } + const std::vector ops; + /// An optional set of ops that were erased. This set is populated /// at the beginning of `simplifyLocally` with the inititally provided list /// of ops. 
llvm::SmallDenseSet *const survivingOps = nullptr; }; - } // namespace -LogicalResult -MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops, - bool *changed) && { +MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, ArrayRef ops, + llvm::SmallDenseSet *survivingOps) + : GreedyPatternRewriteDriver(ctx, patterns, config), + ops(ops.begin(), ops.end()), survivingOps(survivingOps) { + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + strictModeFilteredOps.insert(ops.begin(), ops.end()); + if (survivingOps) { survivingOps->clear(); survivingOps->insert(ops.begin(), ops.end()); } +} - if (config.strictMode != GreedyRewriteStrictness::AnyOp) { - strictModeFilteredOps.clear(); - strictModeFilteredOps.insert(ops.begin(), ops.end()); - } - - if (changed) - *changed = false; - worklist.clear(); - worklistMap.clear(); +LogicalResult MultiOpPatternRewriteDriver::simplify(bool *changed) && { + // Populate the initial worklist. for (Operation *op : ops) addSingleOpToWorklist(op); - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - Operation *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; - - assert((config.strictMode == GreedyRewriteStrictness::AnyOp || - strictModeFilteredOps.contains(op)) && - "unexpected op was inserted under strict mode"); - - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - if (changed) - *changed = true; - continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. 
- originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (Value result : op->getResults()) { - for (Operation *userOp : result.getUsers()) - addToWorklist(userOp); - } - - notifyOperationRemoved(op); - }; - - // Add the given operation generated by the folder to the worklist. - auto processGeneratedConstants = [this](Operation *op) { - notifyOperationInserted(op); - }; - - // Try to fold this op. - bool inPlaceUpdate; - if (succeeded(folder.tryToFold(op, processGeneratedConstants, - preReplaceAction, &inPlaceUpdate))) { - if (changed) - *changed = true; - if (!inPlaceUpdate) { - // Op has been erased. - continue; - } - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. - if (succeeded(matcher.matchAndRewrite(op, *this))) { - if (changed) - *changed = true; - ++numRewrites; - } - } + // Process ops on the worklist. + bool result = processWorklist(); + if (changed) + *changed = result; return success(worklist.empty()); } @@ -658,8 +619,9 @@ // Start the pattern driver. llvm::SmallDenseSet surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - config, allErased ? &surviving : nullptr); - LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); + config, ops, + allErased ? &surviving : nullptr); + LogicalResult converged = std::move(driver).simplify(changed); if (allErased) *allErased = surviving.empty(); LLVM_DEBUG(if (failed(converged)) {