diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -21,6 +21,8 @@ /// This class allows control over how the GreedyPatternRewriteDriver works. class GreedyRewriteConfig { public: + enum class Strictness { AnyOp, ExistingAndNewOps, ExistingOps }; + /// This specifies the order of initial traversal that populates the rewriters /// worklist. When set to true, it walks the operations top-down, which is /// generally more efficient in compile time. When set to false, its initial @@ -42,6 +44,8 @@ int64_t maxNumRewrites = kNoLimit; static constexpr int64_t kNoLimit = -1; + + Strictness strictMode = Strictness::AnyOp; }; //===----------------------------------------------------------------------===// @@ -75,28 +79,28 @@ return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config); } -/// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns -/// success if no more patterns can be matched. `erased` is set to true if `op` -/// was folded away or erased as a result of becoming dead. Note: This does not -/// apply any patterns recursively to the regions of `op`. -LogicalResult applyOpPatternsAndFold(Operation *op, - const FrozenRewritePatternSet &patterns, - bool *erased = nullptr); - /// Applies the specified rewrite patterns on `ops` while also trying to fold -/// these ops as well as any other ops that were in turn created due to such -/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops` -/// remain completely unmodified if `strict` is set to true. If `strict` is -/// false, other operations that use results of rewritten ops or supply operands -/// to such ops are in turn simplified; any other ops still remain unmodified -/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a -/// result of folding, becoming dead, or via pattern rewrites. If more far -/// reaching simplification is desired, applyPatternsAndFoldGreedily should be -/// used. Returns true if at all any IR was rewritten. -bool applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict); +/// these ops. +/// +/// `strictMode` controls whether other ops should be modified or not. +/// - ExistingOps: Only ops in `ops` are modified. Pre-existing ops remain +/// entirely unmodified. +/// - ExistingAndNewOps: Only ops in `ops` and newly created ops are modified. +/// Pre-existing ops remain entirely unmodified. +/// - AnyOp: In addition to ops in `ops` and newly created ops, other operations +/// that use results of rewritten ops or supply operands to such ops are in +/// turn simplified. +/// +/// Note that ops in `ops` could be erased as a result of folding, becoming +/// dead, or via pattern rewrites. `allOpsErased` is set to true if all ops in +/// `ops` were erased. `changed` is set to true if the IR is modified at all. +/// If more far reaching simplification is desired, applyPatternsAndFoldGreedily +/// should be used. Returns true if at all any IR was rewritten. +LogicalResult +applyOpPatternsAndFold(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr, bool *allOpsErased = nullptr); } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -238,5 +238,7 @@ AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -105,5 +105,7 @@ if (isa(op)) opsToSimplify.push_back(op); }); - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -321,8 +321,11 @@ // Simplify/canonicalize the affine.for. RewritePatternSet patterns(res.getContext()); AffineForOp::getCanonicalizationPatterns(patterns, res.getContext()); - bool erased; - (void)applyOpPatternsAndFold(res, std::move(patterns), &erased); + bool erased, changed; + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({res}, std::move(patterns), config, + &changed, &erased); if (!erased && !prologue) prologue = res; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -413,9 +413,12 @@ // in which case we return with `folded` being set. RewritePatternSet patterns(ifOp.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); - bool erased; + bool erased, changed; FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({ifOp}, frozenPatterns, config, &changed, + &erased); if (erased) { if (folded) *folded = true; diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -60,10 +60,13 @@ // matching in above iteration. Besides, erase op not-in-range may end up in // invalid module, so `applyOpPatternsAndFold` should come before that // transform. - for (Operation *op : opsInRange) + for (Operation *op : opsInRange) { // `applyOpPatternsAndFold` returns whether the op is convered. Omit it // because we don't have expectation this reduction will be success or not. - (void)applyOpPatternsAndFold(op, patterns); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingOps; + (void)applyOpPatternsAndFold({op}, patterns, config); + } if (eraseOpNotInRange) for (Operation *op : opsNotInRange) { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,15 +37,14 @@ public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config); + const GreedyRewriteConfig &config, + MutableArrayRef regions); /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef regions); + bool simplify() &&; - /// Add the given operation to the worklist. Parent ops may or may not be - /// added to the worklist, depending on the type of rewrite driver. By - /// default, parent ops are added. - virtual void addToWorklist(Operation *op); + /// Add the given operation to the worklist. + void addToWorklist(Operation *op); /// Pop the next operation from the worklist. Operation *popFromWorklist(); @@ -58,8 +57,9 @@ void finalizeRootUpdate(Operation *op) override; protected: - /// Add the given operation to the worklist. - void addSingleOpToWorklist(Operation *op); + explicit GreedyPatternRewriteDriver(MLIRContext *ctx, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config); // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. @@ -101,14 +101,17 @@ /// Non-pattern based folder for operations. OperationFolder folder; -protected: /// Configuration information for how to simplify. - GreedyRewriteConfig config; + const GreedyRewriteConfig config; + + /// The list of ops we are restricting our rewrites to if `strictMode` is on. + /// These include the supplied set of ops as well as new ops created while + /// rewriting those ops. This set is not maintained when strictMode is off. + llvm::SmallDenseSet strictModeFilteredOps; private: - /// Only ops within this scope are simplified. This is set at the beginning - /// of `simplify()` to the current scope the rewriter operates on. - DenseSet scope; + /// Only ops within this scope are simplified. + const MutableArrayRef regions; #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -119,18 +122,28 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + const GreedyRewriteConfig &config, MutableArrayRef regions) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), + regions(regions) { + for (auto ®ion : regions) { + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) { + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } + } + worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions) { - for (Region &r : regions) - scope.insert(&r); +GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( + MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config) + : GreedyPatternRewriteDriver(ctx, patterns, config, + MutableArrayRef()) {} +bool GreedyPatternRewriteDriver::simplify() && { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -172,7 +185,7 @@ if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. - for (auto ®ion : regions) { + for (Region ®ion : regions) { region.walk([&](Operation *op) { if (!insertKnownConstant(op)) addToWorklist(op); @@ -180,7 +193,7 @@ } } else { // Add all nested operations to the worklist in preorder. - for (auto ®ion : regions) { + for (Region ®ion : regions) { region.walk([&](Operation *op) { if (!insertKnownConstant(op)) { worklist.push_back(op); @@ -318,14 +331,26 @@ } void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { + auto addOp = [&](Operation *op) { + // Check to see if the worklist already contains this op. + if (worklistMap.count(op)) + return; + + if (config.strictMode == GreedyRewriteConfig::Strictness::AnyOp || + strictModeFilteredOps.contains(op)) { + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } + }; + // Gather potential ancestors while looking for a "scope" parent region. SmallVector ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (scope.contains(region)) { + if (llvm::any_of(regions, [&](Region &r) { return &r == region; })) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) - addSingleOpToWorklist(op); + addOp(op); break; } @@ -337,15 +362,6 @@ } } -void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { - // Check to see if the worklist already contains this op. - if (worklistMap.count(op)) - return; - - worklistMap[op] = worklist.size(); - worklist.push_back(op); -} - Operation *GreedyPatternRewriteDriver::popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); @@ -370,6 +386,8 @@ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + if (config.strictMode == GreedyRewriteConfig::Strictness::ExistingAndNewOps) + strictModeFilteredOps.insert(op); addToWorklist(op); } @@ -396,6 +414,8 @@ } void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) { + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) + strictModeFilteredOps.erase(op); addOperandsToWorklist(op->getOperands()); op->walk([this](Operation *operation) { removeFromWorklist(operation); @@ -456,8 +476,9 @@ "patterns can only be applied to operations IsolatedFromAbove"); // Start the pattern driver. - GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config); - bool converged = driver.simplify(regions); + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config, + regions); + bool converged = std::move(driver).simplify(); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; @@ -465,104 +486,6 @@ return success(converged); } -//===----------------------------------------------------------------------===// -// OpPatternRewriteDriver -//===----------------------------------------------------------------------===// - -namespace { -/// This is a simple driver for the PatternMatcher to apply patterns and perform -/// folding on a single op. It repeatedly applies locally optimal patterns. -class OpPatternRewriteDriver : public PatternRewriter { -public: - explicit OpPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns) - : PatternRewriter(ctx), matcher(patterns), folder(ctx) { - // Apply a simple cost model based solely on pattern benefit. - matcher.applyDefaultCostModel(); - } - - LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites, - bool &erased); - - // These are hooks implemented for PatternRewriter. -protected: - /// If an operation is about to be removed, mark it so that we can let clients - /// know. - void notifyOperationRemoved(Operation *op) override { - opErasedViaPatternRewrites = true; - } - - // When a root is going to be replaced, its removal will be notified as well. - // So there is nothing to do here. - void notifyRootReplaced(Operation *op, ValueRange replacement) override {} - -private: - /// The low-level pattern applicator. - PatternApplicator matcher; - - /// Non-pattern based folder for operations. - OperationFolder folder; - - /// Set to true if the operation has been erased via pattern rewrites. - bool opErasedViaPatternRewrites = false; -}; - -} // namespace - -/// Performs the rewrites and folding only on `op`. The simplification -/// converges if the op is erased as a result of being folded, replaced, or -/// becoming dead, or no more changes happen in an iteration. Returns success if -/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op` -/// gets erased. -LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, - int64_t maxNumRewrites, - bool &erased) { - bool changed = false; - erased = false; - opErasedViaPatternRewrites = false; - int64_t numRewrites = 0; - // Iterate until convergence or until maxNumRewrites. Deletion of the op as - // a result of being dead or folded is convergence. - do { - if (numRewrites >= maxNumRewrites && - maxNumRewrites != GreedyRewriteConfig::kNoLimit) - break; - - changed = false; - - // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { - op->erase(); - erased = true; - return success(); - } - - // Try to fold this op. - bool inPlaceUpdate; - if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr, - /*preReplaceAction=*/nullptr, - &inPlaceUpdate))) { - changed = true; - if (!inPlaceUpdate) { - erased = true; - return success(); - } - } - - // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do here. - if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; - ++numRewrites; - } - if ((erased = opErasedViaPatternRewrites)) - return success(); - } while (changed); - - // Whether the rewrite converges, i.e. wasn't changed in the last iteration. - return failure(changed); -} - //===----------------------------------------------------------------------===// // MultiOpPatternRewriteDriver //===----------------------------------------------------------------------===// @@ -578,42 +501,18 @@ public: explicit MultiOpPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - bool strict) - : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()), - strictMode(strict) {} - - bool simplifyLocally(ArrayRef op); - - void addToWorklist(Operation *op) override { - if (!strictMode || strictModeFilteredOps.contains(op)) - GreedyPatternRewriteDriver::addSingleOpToWorklist(op); + const GreedyRewriteConfig &config, + ArrayRef ops = {}) + : GreedyPatternRewriteDriver(ctx, patterns, config), ops(ops) { + if (config.strictMode != GreedyRewriteConfig::Strictness::AnyOp) + strictModeFilteredOps.insert(ops.begin(), ops.end()); } -private: - void notifyOperationInserted(Operation *op) override { - if (strictMode) - strictModeFilteredOps.insert(op); - GreedyPatternRewriteDriver::notifyOperationInserted(op); - } - - void notifyOperationRemoved(Operation *op) override { - GreedyPatternRewriteDriver::notifyOperationRemoved(op); - if (strictMode) - strictModeFilteredOps.erase(op); - } + LogicalResult simplifyLocally(bool *changed = nullptr, + DenseSet *erased = nullptr) &&; - /// If `strictMode` is true, any pre-existing ops outside of - /// `strictModeFilteredOps` remain completely untouched by the rewrite driver. - /// If `strictMode` is false, operations that use results of (or supply - /// operands to) any rewritten ops stemming from the simplification of the - /// provided ops are in turn simplified; any other ops still remain untouched - /// (i.e., regardless of `strictMode`). - bool strictMode = false; - - /// The list of ops we are restricting our rewrites to if `strictMode` is on. - /// These include the supplied set of ops as well as new ops created while - /// rewriting those ops. This set is not maintained when strictMode is off. - llvm::SmallDenseSet strictModeFilteredOps; +private: + const ArrayRef ops; }; } // namespace @@ -633,13 +532,11 @@ // there is no strong rationale to re-add all operations into the worklist and // rerun until an iteration changes nothing. If more widereaching simplification // is desired, GreedyPatternRewriteDriver should be used. -bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef ops) { - if (strictMode) { - strictModeFilteredOps.clear(); - strictModeFilteredOps.insert(ops.begin(), ops.end()); - } - - bool changed = false; +LogicalResult +MultiOpPatternRewriteDriver::simplifyLocally(bool *changed, + DenseSet *erased) && { + if (changed) + *changed = false; worklist.clear(); worklistMap.clear(); for (Operation *op : ops) @@ -658,14 +555,16 @@ if (op == nullptr) continue; - assert((!strictMode || strictModeFilteredOps.contains(op)) && + assert((config.strictMode == GreedyRewriteConfig::Strictness::AnyOp || + strictModeFilteredOps.contains(op)) && "unexpected op was inserted under strict mode"); // If the operation is trivially dead - remove it. if (isOpTriviallyDead(op)) { notifyOperationRemoved(op); op->erase(); - changed = true; + if (changed) + *changed = true; continue; } @@ -695,7 +594,8 @@ bool inPlaceUpdate; if (succeeded(folder.tryToFold(op, processGeneratedConstants, preReplaceAction, &inPlaceUpdate))) { - changed = true; + if (changed) + *changed = true; if (!inPlaceUpdate) { // Op has been erased. continue; @@ -706,42 +606,33 @@ // notified of any necessary changes, so there is nothing else to do // here. if (succeeded(matcher.matchAndRewrite(op, *this))) { - changed = true; + if (changed) + *changed = true; ++numRewrites; } } - return changed; + return success(worklist.empty()); } -/// Rewrites only `op` using the supplied canonicalization patterns and -/// folding. `erased` is set to true if the op is erased as a result of being -/// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( - Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { + ArrayRef ops, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config, bool *changed, bool *allOpsErased) { + if (ops.empty()) + return success(); + + DenseSet erased; // Start the pattern driver. - GreedyRewriteConfig config; - OpPatternRewriteDriver driver(op->getContext(), patterns); - bool opErased; - LogicalResult converged = - driver.simplifyLocally(op, config.maxNumRewrites, opErased); - if (erased) - *erased = opErased; + MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, + config, ops); + LogicalResult converged = std::move(driver).simplifyLocally(changed, &erased); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after " << config.maxNumRewrites << " rewrites"; }); - return converged; -} -bool mlir::applyOpPatternsAndFold(ArrayRef ops, - const FrozenRewritePatternSet &patterns, - bool strict) { - if (ops.empty()) - return false; - - // Start the pattern driver. - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strict); - return driver.simplifyLocally(ops); + if (allOpsErased) + *allOpsErased = + llvm::all_of(ops, [&](Operation *op) { return erased.contains(op); }); + return converged; } diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -132,7 +132,9 @@ AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } } - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -266,8 +266,9 @@ // Check if these transformations introduce visiting of operations that // are not in the `ops` set (The new created ops are valid). An invalid // operation will trigger the assertion while processing. - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), - /*strict=*/true); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteConfig::Strictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config); } private: