diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -21,6 +21,15 @@ /// This class allows control over how the GreedyPatternRewriteDriver works. class GreedyRewriteConfig { public: + enum class Strictness { + /// No restrictions wrt. which ops are processed; i.e., no strict mode. + AnyOp, + /// Only pre-existing and newly created ops are processed. + ExistingAndNewOps, + /// Only pre-existing ops are processed. + ExistingOps + }; + /// This specifies the order of initial traversal that populates the rewriters /// worklist. When set to true, it walks the operations top-down, which is /// generally more efficient in compile time. When set to false, its initial @@ -42,6 +51,9 @@ int64_t maxNumRewrites = kNoLimit; static constexpr int64_t kNoLimit = -1; + + /// Strict mode controls which ops are added to the worklist. + Strictness strictMode = Strictness::AnyOp; }; //===----------------------------------------------------------------------===// @@ -52,51 +64,47 @@ /// above, by repeatedly applying the highest benefit patterns in a greedy /// work-list driven manner. /// -/// This variant may stop after a predefined number of iterations, see the -/// alternative below to provide a specific number of iterations before stopping -/// in absence of convergence. +/// This transformation may stop after a predefined number of iterations (if no +/// convergence was reached yet). This limit can be controlled with +/// GreedyRewriteConfig. /// /// Return success if the iterative process converged and no more patterns can -/// be matched in the result operation regions. +/// be matched in the specified regions. /// -/// Note: This does not apply patterns to the top-level operation itself. +/// Note: This does not apply patterns to the owners of the specified regions. 
/// These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -/// You may configure several aspects of this with GreedyRewriteConfig. LogicalResult applyPatternsAndFoldGreedily( MutableArrayRef regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig()); -/// Rewrite the given regions, which must be isolated from above. +/// Rewrite the regions of the given op, which must be isolated from above. inline LogicalResult applyPatternsAndFoldGreedily( Operation *op, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig()) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config); } -/// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns -/// success if no more patterns can be matched. `erased` is set to true if `op` -/// was folded away or erased as a result of becoming dead. Note: This does not -/// apply any patterns recursively to the regions of `op`. -LogicalResult applyOpPatternsAndFold(Operation *op, - const FrozenRewritePatternSet &patterns, - bool *erased = nullptr); - -/// Applies the specified rewrite patterns on `ops` while also trying to fold -/// these ops as well as any other ops that were in turn created due to such -/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops` -/// remain completely unmodified if `strict` is set to true. If `strict` is -/// false, other operations that use results of rewritten ops or supply operands -/// to such ops are in turn simplified; any other ops still remain unmodified -/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. If more far -/// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. 
Returns true if at all any IR was rewritten. -bool applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict); +/// Apply the specified rewrite patterns on `ops` while also trying to fold +/// these ops, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. +/// +/// This transformation may stop after a predefined number of iterations (if no +/// convergence was reached yet). This limit can be controlled with +/// GreedyRewriteConfig. +/// +/// Return success if the iterative process converged and no more patterns can +/// be matched. Which ops are processed is controlled by `config.strictMode`. +/// +/// Note that ops could be erased as a result of folding, becoming dead, or via +/// pattern rewrites. `allOpsErased` is set to true if all ops in `ops` were +/// erased. `changed` is set to true if the IR is modified at all. +LogicalResult +applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr, bool *allOpsErased = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -238,5 +238,7 @@ AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- 
a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -105,5 +105,7 @@ if (isa(op)) opsToSimplify.push_back(op); }); - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -321,8 +321,11 @@ // Simplify/canonicalize the affine.for. RewritePatternSet patterns(res.getContext()); AffineForOp::getCanonicalizationPatterns(patterns, res.getContext()); - bool erased; - (void)applyOpPatternsAndFold(res, std::move(patterns), &erased); + bool erased, changed; + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({res}, std::move(patterns), config, + &changed, &erased); if (!erased && !prologue) prologue = res; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -413,9 +413,12 @@ // in which case we return with `folded` being set. 
RewritePatternSet patterns(ifOp.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); - bool erased; + bool erased, changed; FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({ifOp}, frozenPatterns, config, &changed, + &erased); if (erased) { if (folded) *folded = true; diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -60,10 +60,13 @@ // matching in above iteration. Besides, erase op not-in-range may end up in // invalid module, so `applyOpPatternsAndFold` should come before that // transform. - for (Operation *op : opsInRange) + for (Operation *op : opsInRange) { // `applyOpPatternsAndFold` returns whether the op is convered. Omit it // because we don't have expectation this reduction will be success or not. - (void)applyOpPatternsAndFold(op, patterns); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({op}, patterns, config); + } if (eraseOpNotInRange) for (Operation *op : opsNotInRange) { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -32,45 +32,67 @@ namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly -/// applies the locally optimal patterns in a roughly "bottom up" way. 
-class GreedyPatternRewriteDriver : public PatternRewriter { -public: - explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config); +/// applies the locally optimal patterns. +class GreedyPatternRewriteDriver : protected PatternRewriter { +protected: + explicit GreedyPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + const std::optional> &scope); - /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef regions); + //===--------------------------------------------------------------------===// + // Worklist + //===--------------------------------------------------------------------===// - /// Add the given operation to the worklist. Parent ops may or may not be - /// added to the worklist, depending on the type of rewrite driver. By - /// default, parent ops are added. - virtual void addToWorklist(Operation *op); + // Look over the provided operands for any defining operations that should + // be re-added to the worklist. This function should be called when an + // operation is modified or removed, as it may trigger further + // simplifications. + void addOperandsToWorklist(ValueRange operands); + + /// Add the given operation to the worklist. In strict mode, certain ops are + /// excluded. + void addToWorklist(Operation *op); /// Pop the next operation from the worklist. Operation *popFromWorklist(); + /// Process ops until the worklist is empty or `config.maxNumRewrites` is + /// reached. Return `true` if any IR was changed. + bool processWorklist(); + /// If the specified operation is in the worklist, remove it. void removeFromWorklist(Operation *op); + /// The worklist for this transformation keeps track of the operations that + /// need to be revisited, plus their index in the worklist. 
This allows us to + /// efficiently remove operations from the worklist when they are erased, even + /// if they aren't the root of a pattern. + std::vector worklist; + DenseMap worklistMap; + + /// The list of ops we are restricting our rewrites to in strict mode. + /// These include the supplied set of ops as well as new ops created while + /// rewriting those ops. This set is not maintained when strict mode is off. + llvm::SmallDenseSet strictModeFilteredOps; + + //===--------------------------------------------------------------------===// + // Listener + //===--------------------------------------------------------------------===// + + /// Notifies the driver that the specified operation may have been modified /// in-place. void finalizeRootUpdate(Operation *op) override; -protected: - /// Add the given operation to the worklist. - void addSingleOpToWorklist(Operation *op); + /// PatternRewriter hook for notifying match failure reasons. + LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback) override; // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. void notifyOperationInserted(Operation *op) override; - // Look over the provided operands for any defining operations that should - // be re-added to the worklist. This function should be called when an - // operation is modified or removed, as it may trigger further - // simplifications. - void addOperandsToWorklist(ValueRange operands); - // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. void notifyOperationRemoved(Operation *op) override; @@ -80,35 +102,18 @@ // before the root is changed. void notifyRootReplaced(Operation *op, ValueRange replacement) override; - /// PatternRewriter hook for erasing a dead operation.
- void eraseOp(Operation *op) override; - - /// PatternRewriter hook for notifying match failure reasons. - LogicalResult - notifyMatchFailure(Location loc, - function_ref reasonCallback) override; - - /// The low-level pattern applicator. - PatternApplicator matcher; - - /// The worklist for this transformation keeps track of the operations that - /// need to be revisited, plus their index in the worklist. This allows us to - /// efficiently remove operations from the worklist when they are erased, even - /// if they aren't the root of a pattern. - std::vector worklist; - DenseMap worklistMap; - /// Non-pattern based folder for operations. OperationFolder folder; -protected: /// Configuration information for how to simplify. - GreedyRewriteConfig config; + const GreedyRewriteConfig config; private: - /// Only ops within this scope are simplified. This is set at the beginning - /// of `simplify()` to the current scope the rewriter operates on. - DenseSet scope; + /// The low-level pattern applicator. + PatternApplicator matcher; + + /// Only ops within this scope are simplified. + const std::optional> scope; #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -119,18 +124,18 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + const GreedyRewriteConfig &config, + const std::optional> &scope) + : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns), + scope(scope) { + worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. 
matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { - for (Region &r : regions) - scope.insert(&r); - +bool GreedyPatternRewriteDriver::processWorklist() { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -149,183 +154,154 @@ }; #endif - auto insertKnownConstant = [&](Operation *op) { - // Check for existing constants when populating the worklist. This avoids - // accidentally reversing the constant order during processing. - Attribute constValue; - if (matchPattern(op, m_Constant(&constValue))) - if (!folder.insertKnownConstant(op, constValue)) - return true; - return false; - }; + // These are scratch vectors used in the folding loop below. + SmallVector originalOperands, resultValues; bool changed = false; - int64_t iteration = 0; - do { - // Check if the iteration limit was reached. - if (iteration++ >= config.maxIterations && - config.maxIterations != GreedyRewriteConfig::kNoLimit) - break; + int64_t numRewrites = 0; + while (!worklist.empty() && + (numRewrites < config.maxNumRewrites || + config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { + auto *op = popFromWorklist(); - worklist.clear(); - worklistMap.clear(); + // Nulls get added to the worklist when operations are removed, ignore + // them. + if (op == nullptr) + continue; - if (!config.useTopDownTraversal) { - // Add operations to the worklist in postorder. - for (auto ®ion : regions) { - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) - addToWorklist(op); - }); - } - } else { - // Add all nested operations to the worklist in preorder. 
- for (auto ®ion : regions) { - region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { - worklist.push_back(op); - return WalkResult::advance(); - } - return WalkResult::skip(); - }); + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << logLineComment; + logger.startLine() << "Processing operation : '" << op->getName() << "'(" + << op << ") {\n"; + logger.indent(); + + // If the operation has no regions, just print it here. + if (op->getNumRegions() == 0) { + op->print( + logger.startLine(), + OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); + logger.getOStream() << "\n\n"; } + }); - // Reverse the list so our pop-back loop processes them in-order. - std::reverse(worklist.begin(), worklist.end()); - // Remember the reverse index. - for (size_t i = 0, e = worklist.size(); i != e; ++i) - worklistMap[worklist[i]] = i; + // If the operation is trivially dead - remove it. + if (isOpTriviallyDead(op)) { + notifyOperationRemoved(op); + op->erase(); + changed = true; + + LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); + continue; } - // These are scratch vectors used in the folding loop below. - SmallVector originalOperands, resultValues; + // Collects all the operands and result uses of the given `op` into work + // list. Also remove `op` and nested ops from worklist. + originalOperands.assign(op->operand_begin(), op->operand_end()); + auto preReplaceAction = [&](Operation *op) { + // Add the operands to the worklist for visitation. + addOperandsToWorklist(originalOperands); + + // Add all the users of the result to the worklist so we make sure + // to revisit them. 
+ for (auto result : op->getResults()) + for (auto *userOp : result.getUsers()) + addToWorklist(userOp); - changed = false; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - auto *op = popFromWorklist(); + notifyOperationRemoved(op); + }; + + // Add the given operation to the worklist. + auto collectOps = [this](Operation *op) { addToWorklist(op); }; - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) + // Try to fold this op. + bool inPlaceUpdate; + if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, + &inPlaceUpdate)))) { + LLVM_DEBUG(logResultWithLine("success", "operation was folded")); + + changed = true; + if (!inPlaceUpdate) continue; + } + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do + // here. +#ifndef NDEBUG + auto canApply = [&](const Pattern &pattern) { LLVM_DEBUG({ logger.getOStream() << "\n"; - logger.startLine() << logLineComment; - logger.startLine() << "Processing operation : '" << op->getName() - << "'(" << op << ") {\n"; + logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" + << op->getName() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); + logger.getOStream() << ")' {\n"; logger.indent(); - - // If the operation has no regions, just print it here. 
- if (op->getNumRegions() == 0) { - op->print( - logger.startLine(), - OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); - logger.getOStream() << "\n\n"; - } }); + return true; + }; + auto onFailure = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("failure", "pattern failed to match")); + }; + auto onSuccess = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("success", "pattern applied successfully")); + return success(); + }; - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - changed = true; - - LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); - continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation. - addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (auto result : op->getResults()) - for (auto *userOp : result.getUsers()) - addToWorklist(userOp); - - notifyOperationRemoved(op); - }; - - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; - - // Try to fold this op. - bool inPlaceUpdate; - if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, - &inPlaceUpdate)))) { - LLVM_DEBUG(logResultWithLine("success", "operation was folded")); - - changed = true; - if (!inPlaceUpdate) - continue; - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. 
-#ifndef NDEBUG - auto canApply = [&](const Pattern &pattern) { - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; - llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); - logger.getOStream() << ")' {\n"; - logger.indent(); - }); - return true; - }; - auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); - }; - auto onSuccess = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); - }; - - LogicalResult matchResult = - matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); - if (succeeded(matchResult)) - LLVM_DEBUG(logResultWithLine("success", "pattern matched")); - else - LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); + LogicalResult matchResult = + matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess); + if (succeeded(matchResult)) + LLVM_DEBUG(logResultWithLine("success", "pattern matched")); + else + LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match")); #else - LogicalResult matchResult = matcher.matchAndRewrite(op, *this); + LogicalResult matchResult = matcher.matchAndRewrite(op, *this); #endif - if (succeeded(matchResult)) { - changed = true; - ++numRewrites; - } + if (succeeded(matchResult)) { + changed = true; + ++numRewrites; } + } - // After applying patterns, make sure that the CFG of each of the regions - // is kept up to date. - if (config.enableRegionSimplification) - changed |= succeeded(simplifyRegions(*this, regions)); - } while (changed); + return changed; +} - // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 
- return !changed; +void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) { + for (Value operand : operands) { + // If the use count of this operand is now < 2, we re-add the defining + // operation to the worklist. + // TODO: This is based on the fact that zero use operations + // may be deleted, and that single use values often have more + // canonicalization opportunities. + if (!operand || (!operand.use_empty() && !operand.hasOneUse())) + continue; + if (auto *defOp = operand.getDefiningOp()) + addToWorklist(defOp); + } } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { + auto addOp = [&](Operation *op) { + // Check to see if the worklist already contains this op. + if (worklistMap.count(op)) + return; + + if (config.strictMode == GreedyRewriteConfig::Strictness::AnyOp || + strictModeFilteredOps.contains(op)) { + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } + }; + // Gather potential ancestors while looking for a "scope" parent region. SmallVector ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (scope.contains(region)) { + if (!scope.has_value() || scope->contains(region)) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) - addSingleOpToWorklist(op); + addOp(op); break; } @@ -337,15 +313,6 @@ } } -void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { - // Check to see if the worklist already contains this op. 
- if (worklistMap.count(op)) - return; - - worklistMap[op] = worklist.size(); - worklist.push_back(op); -} - Operation *GreedyPatternRewriteDriver::popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -370,6 +337,8 @@ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + if (config.strictMode == GreedyRewriteConfig::Strictness::ExistingAndNewOps) + strictModeFilteredOps.insert(op); addToWorklist(op); } @@ -381,21 +350,14 @@ addToWorklist(op); } -void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) { - for (Value operand : operands) { - // If the use count of this operand is now < 2, we re-add the defining - // operation to the worklist. - // TODO: This is based on the fact that zero use operations - // may be deleted, and that single use values often have more - // canonicalization opportunities. - if (!operand || (!operand.use_empty() && !operand.hasOneUse())) - continue; - if (auto *defOp = operand.getDefiningOp()) - addToWorklist(defOp); - } -} - void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Erase : '" << op->getName() << "'(" << op + << ")\n"; + }); + + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) + strictModeFilteredOps.erase(op); addOperandsToWorklist(op->getOperands()); op->walk([this](Operation *operation) { removeFromWorklist(operation); @@ -414,14 +376,6 @@ addToWorklist(user); } -void GreedyPatternRewriteDriver::eraseOp(Operation *op) { - LLVM_DEBUG({ - logger.startLine() << "** Erase : '" << op->getName() << "'(" << op - << ")\n"; - }); - PatternRewriter::eraseOp(op); -} - LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure( Location loc, function_ref reasonCallback) { LLVM_DEBUG({ @@ -432,135 +386,143 @@ return failure(); } -/// Rewrite the regions of the specified operation, which must be isolated from -/// above, by repeatedly applying the highest benefit patterns in a greedy -/// 
work-list driven manner. Return success if no more patterns can be matched -/// in the result operation regions. Note: This does not apply patterns to the -/// top-level operation itself. -/// -LogicalResult -mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config) { - if (regions.empty()) - return success(); - - // The top-level operation must be known to be isolated from above to - // prevent performing canonicalizations on operations defined at or above - // the region containing 'op'. - auto regionIsIsolated = [](Region ®ion) { - return region.getParentOp()->hasTrait(); - }; - (void)regionIsIsolated; - assert(llvm::all_of(regions, regionIsIsolated) && - "patterns can only be applied to operations IsolatedFromAbove"); - - // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); - bool converged = driver.simplify(regions); - LLVM_DEBUG(if (!converged) { - llvm::dbgs() << "The pattern rewrite did not converge after scanning " - << config.maxIterations << " times\n"; - }); - return success(converged); -} - //===----------------------------------------------------------------------===// -// OpPatternRewriteDriver +// RegionPatternRewriteDriver //===----------------------------------------------------------------------===// namespace { -/// This is a simple driver for the PatternMatcher to apply patterns and perform -/// folding on a single op. It repeatedly applies locally optimal patterns. -class OpPatternRewriteDriver : public PatternRewriter { +/// This driver simplfies one or multiple regions. +class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { public: - explicit OpPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns) - : PatternRewriter(ctx), matcher(patterns), folder(ctx) { - // Apply a simple cost model based solely on pattern benefit. 
- matcher.applyDefaultCostModel(); - } - - LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites, - bool &erased); - - // These are hooks implemented for PatternRewriter. -protected: - /// If an operation is about to be removed, mark it so that we can let clients - /// know. - void notifyOperationRemoved(Operation *op) override { - opErasedViaPatternRewrites = true; - } + explicit RegionPatternRewriteDriver(MLIRContext *ctx, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + MutableArrayRef regions); - // When a root is going to be replaced, its removal will be notified as well. - // So there is nothing to do here. - void notifyRootReplaced(Operation *op, ValueRange replacement) override {} + /// Simplify ops inside the specified regions and simplify the regions + /// themselves. Return success if the transformation converged. + LogicalResult simplify() &&; private: - /// The low-level pattern applicator. - PatternApplicator matcher; + /// The regions that are simplified. + const MutableArrayRef regions; +}; +} // namespace - /// Non-pattern based folder for operations. - OperationFolder folder; +template +static DenseSet makePointerSet(MutableArrayRef elems) { + DenseSet result; + for (auto &e : elems) + result.insert(&e); + return result; +} - /// Set to true if the operation has been erased via pattern rewrites. 
- bool opErasedViaPatternRewrites = false; -}; +RegionPatternRewriteDriver::RegionPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, MutableArrayRef regions) + : GreedyPatternRewriteDriver(ctx, patterns, config, + makePointerSet(regions)), + regions(regions) { + for (auto ®ion : regions) { + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) { + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } + } +} -} // namespace +LogicalResult RegionPatternRewriteDriver::simplify() && { + auto insertKnownConstant = [&](Operation *op) { + // Check for existing constants when populating the worklist. This avoids + // accidentally reversing the constant order during processing. + Attribute constValue; + if (matchPattern(op, m_Constant(&constValue))) + if (!folder.insertKnownConstant(op, constValue)) + return true; + return false; + }; -/// Performs the rewrites and folding only on `op`. The simplification -/// converges if the op is erased as a result of being folded, replaced, or -/// becoming dead, or no more changes happen in an iteration. Returns success if -/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op` -/// gets erased. -LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, - int64_t maxNumRewrites, - bool &erased) { bool changed = false; - erased = false; - opErasedViaPatternRewrites = false; - int64_t numRewrites = 0; - // Iterate until convergence or until maxNumRewrites. Deletion of the op as - // a result of being dead or folded is convergence. + int64_t iteration = 0; do { - if (numRewrites >= maxNumRewrites && - maxNumRewrites != GreedyRewriteConfig::kNoLimit) + // Check if the iteration limit was reached. + if (iteration++ >= config.maxIterations && + config.maxIterations != GreedyRewriteConfig::kNoLimit) break; - changed = false; - - // If the operation is trivially dead - remove it. 
- if (isOpTriviallyDead(op)) { - op->erase(); - erased = true; - return success(); - } + worklist.clear(); + worklistMap.clear(); - // Try to fold this op. - bool inPlaceUpdate; - if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr, - /*preReplaceAction=*/nullptr, - &inPlaceUpdate))) { - changed = true; - if (!inPlaceUpdate) { - erased = true; - return success(); + if (!config.useTopDownTraversal) { + // Add operations to the worklist in postorder. + for (Region &region : regions) { + region.walk([&](Operation *op) { + if (!insertKnownConstant(op)) + addToWorklist(op); + }); + } + } else { + // Add nested operations (subject to strict mode) in preorder. + for (Region &region : regions) { + region.walk<WalkOrder::PreOrder>([&](Operation *op) { + if (!insertKnownConstant(op)) { + addToWorklist(op); + return WalkResult::advance(); + } + return WalkResult::skip(); + }); } - } - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do here. - if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; - ++numRewrites; + // Reverse the list so our pop-back loop processes them in-order. + std::reverse(worklist.begin(), worklist.end()); + // Remember the reverse index. + for (size_t i = 0, e = worklist.size(); i != e; ++i) + worklistMap[worklist[i]] = i; } - if ((erased = opErasedViaPatternRewrites)) - return success(); + + changed = processWorklist(); + + // After applying patterns, make sure that the CFG of each of the regions + // is kept up to date. + if (config.enableRegionSimplification) + changed |= succeeded(simplifyRegions(*this, regions)); } while (changed); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return failure(changed); + return success(!changed); +} + +/// Rewrite the specified regions, which must be isolated from above, by +/// repeatedly applying the highest benefit patterns in a greedy work-list +/// driven manner.
Return success if no more patterns can be matched in the +/// result operation regions. Note: This does not apply patterns to the owner +/// operations of the regions. +/// +LogicalResult +mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config) { + if (regions.empty()) + return success(); + + // The top-level operation must be known to be isolated from above to + // prevent performing canonicalizations on operations defined at or above + // the region containing 'op'. + auto regionIsIsolated = [](Region &region) { + return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>(); + }; + (void)regionIsIsolated; + assert(llvm::all_of(regions, regionIsIsolated) && + "patterns can only be applied to operations IsolatedFromAbove"); + + // Start the pattern driver. + RegionPatternRewriteDriver driver(regions[0].getContext(), patterns, config, + regions); + LogicalResult converged = std::move(driver).simplify(); + LLVM_DEBUG(if (failed(converged)) { + llvm::dbgs() << "The pattern rewrite did not converge after scanning " + << config.maxIterations << " times\n"; + }); + return converged; } //===----------------------------------------------------------------------===// @@ -568,180 +530,96 @@ //===----------------------------------------------------------------------===// namespace { - -/// This is a specialized GreedyPatternRewriteDriver to apply patterns and -/// perform folding for a supplied set of ops. It repeatedly simplifies while -/// restricting the rewrites to only the provided set of ops or optionally -/// to those directly affected by it (result users or operand providers). Parent -/// ops are not considered. +/// This driver simplifies one or multiple ops.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - bool strict) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), - strictMode(strict) {} + const GreedyRewriteConfig &config, + ArrayRef<Operation *> ops); - bool simplifyLocally(ArrayRef<Operation *> op); + /// Simplify the specified ops. Return success if the transformation + /// converged. Optionally, `*changed` indicates whether the IR was modified. + /// `*erased` is a subset of the specified ops that were erased during the + /// transformation. + LogicalResult simplify(bool *changed = nullptr, + DenseSet<Operation *> *erased = nullptr) &&; - void addToWorklist(Operation *op) override { - if (!strictMode || strictModeFilteredOps.contains(op)) - GreedyPatternRewriteDriver::addSingleOpToWorklist(op); - } +protected: + void notifyOperationRemoved(Operation *op) override; private: - void notifyOperationInserted(Operation *op) override { - if (strictMode) - strictModeFilteredOps.insert(op); - GreedyPatternRewriteDriver::notifyOperationInserted(op); - } - - void notifyOperationRemoved(Operation *op) override { - GreedyPatternRewriteDriver::notifyOperationRemoved(op); - if (strictMode) - strictModeFilteredOps.erase(op); - } - - /// If `strictMode` is true, any pre-existing ops outside of - /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. - /// If `strictMode` is false, operations that use results of (or supply - /// operands to) any rewritten ops stemming from the simplification of the - /// provided ops are in turn simplified; any other ops still remain untouched - /// (i.e., regardless of `strictMode`). - bool strictMode = false; + /// Ops to be processed. + const SetVector<Operation *> ops; - /// The list of ops we are restricting our rewrites to if `strictMode` is on. - /// These include the supplied set of ops as well as new ops created while - /// rewriting those ops.
This set is not maintained when strictMode is off. - llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; + /// A subset of `ops`: ops that were erased during the rewrite. + DenseSet<Operation *> erasedOps; }; - } // namespace -/// Performs the specified rewrites on `ops` while also trying to fold these ops -/// as well as any other ops that were in turn created due to these rewrite -/// patterns. Any pre-existing ops outside of `ops` remain completely -/// unmodified if `strictMode` is true. If `strictMode` is false, other -/// operations that use results of rewritten ops or supply operands to such ops -/// are in turn simplified; any other ops still remain unmodified (i.e., -/// regardless of `strictMode`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. Returns true if -/// at all any changes happened. -// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op -// or GreedyPatternRewriteDriver::simplify, this method just iterates until -// the worklist is empty. As our objective is to keep simplification "local", -// there is no strong rationale to re-add all operations into the worklist and -// rerun until an iteration changes nothing. If more widereaching simplification -// is desired, GreedyPatternRewriteDriver should be used.
-bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) { - if (strictMode) { - strictModeFilteredOps.clear(); +MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, ArrayRef<Operation *> ops) + : GreedyPatternRewriteDriver(ctx, patterns, config, + /*scope=*/{}), + ops(ops.begin(), ops.end()) { + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) strictModeFilteredOps.insert(ops.begin(), ops.end()); - } +} - bool changed = false; - worklist.clear(); - worklistMap.clear(); +void MultiOpPatternRewriteDriver::notifyOperationRemoved(Operation *op) { + // Chain to the base driver first so that its bookkeeping (e.g., the + // worklist) stops referring to `op` before it is erased. + GreedyPatternRewriteDriver::notifyOperationRemoved(op); + if (ops.contains(op)) + erasedOps.insert(op); +} + +LogicalResult +MultiOpPatternRewriteDriver::simplify(bool *changed, + DenseSet<Operation *> *erased) && { + // Populate the initial worklist. for (Operation *op : ops) addToWorklist(op); - // These are scratch vectors used in the folding loop below. - SmallVector<Value, 8> originalOperands, resultValues; - int64_t numRewrites = 0; - while (!worklist.empty() && - (numRewrites < config.maxNumRewrites || - config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { - Operation *op = popFromWorklist(); - - // Nulls get added to the worklist when operations are removed, ignore - // them. - if (op == nullptr) - continue; - - assert((!strictMode || strictModeFilteredOps.contains(op)) && - "unexpected op was inserted under strict mode"); - - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - notifyOperationRemoved(op); - op->erase(); - changed = true; - continue; - } - - // Collects all the operands and result uses of the given `op` into work - // list. Also remove `op` and nested ops from worklist. - originalOperands.assign(op->operand_begin(), op->operand_end()); - auto preReplaceAction = [&](Operation *op) { - // Add the operands to the worklist for visitation.
- addOperandsToWorklist(originalOperands); - - // Add all the users of the result to the worklist so we make sure - // to revisit them. - for (Value result : op->getResults()) { - for (Operation *userOp : result.getUsers()) - addToWorklist(userOp); - } - - notifyOperationRemoved(op); - }; - - // Add the given operation generated by the folder to the worklist. - auto processGeneratedConstants = [this](Operation *op) { - notifyOperationInserted(op); - }; - - // Try to fold this op. - bool inPlaceUpdate; - if (succeeded(folder.tryToFold(op, processGeneratedConstants, - preReplaceAction, &inPlaceUpdate))) { - changed = true; - if (!inPlaceUpdate) { - // Op has been erased. - continue; - } - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do - // here. - if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; - ++numRewrites; - } - } + // Process ops on the worklist. + bool result = processWorklist(); + if (changed) + *changed = result; + if (erased) + *erased = std::move(erasedOps); - return changed; + // The rewrite converged if the worklist is empty. + return success(worklist.empty()); } -/// Rewrites only `op` using the supplied canonicalization patterns and -/// folding. `erased` is set to true if the op is erased as a result of being -/// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( - Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { + ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config, bool *changed, bool *allOpsErased) { + if (ops.empty()) { + // Nothing to simplify, but still define the optional out-parameters. + if (changed) + *changed = false; + if (allOpsErased) + *allOpsErased = true; + return success(); + } + + // Ops that are part of `ops` and were erased during simplification. + DenseSet<Operation *> erased; + // Start the pattern driver.
- GreedyRewriteConfig config; - OpPatternRewriteDriver driver(op->getContext(), patterns); - bool opErased; - LogicalResult converged = - driver.simplifyLocally(op, config.maxNumRewrites, opErased); - if (erased) - *erased = opErased; + MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, + config, ops); + LogicalResult converged = std::move(driver).simplify(changed, &erased); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after " << config.maxNumRewrites << " rewrites"; }); - return converged; -} -bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops, - const FrozenRewritePatternSet &patterns, - bool strict) { - if (ops.empty()) - return false; - - // Start the pattern driver. - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strict); - return driver.simplifyLocally(ops); + // `ops` may contain duplicates, so check membership rather than sizes. + if (allOpsErased) + *allOpsErased = + llvm::all_of(ops, [&](Operation *op) { return erased.contains(op); }); + return converged; } diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -132,7 +132,9 @@ AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } } - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -266,8 +266,9 @@ // Check if these transformations introduce visiting of operations that // are not in the `ops` set (The new created ops are valid). An invalid // operation will trigger the assertion while processing.
- (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), - /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config); } private: