diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -63,6 +63,18 @@ /// Only ops within the scope are added to the worklist. If no scope is /// specified, the closest enclosing region is used as a scope. Region *scope = nullptr; + + /// Strict mode can restrict the ops that are added to the worklist during + /// the rewrite. + /// + /// * GreedyRewriteStrictness::AnyOp: No ops are excluded. + /// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing ops (that + /// were on the worklist at the very beginning) and newly created ops are + /// enqueued. All other ops are excluded. + /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that + /// were on the worklist at the very beginning) are enqueued. All other ops are + /// excluded. + GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; }; //===----------------------------------------------------------------------===// @@ -105,14 +117,8 @@ /// /// Newly created ops and other pre-existing ops that use results of rewritten /// ops or supply operands to such ops are simplified, unless such ops are -/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless -/// of `strictMode`). -/// -/// * GreedyRewriteStrictness::AnyOp: No ops are excluded. -/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly -/// created ops are simplified. All other ops are excluded. -/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are -/// simplified. All other ops are excluded. +/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e., +/// regardless of `config.strictMode`). /// /// In addition to strictness, a region scope can be specified. Only ops within /// the scope are simplified.
This is similar to `applyPatternsAndFoldGreedily`, @@ -130,23 +136,17 @@ LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, GreedyRewriteConfig config = GreedyRewriteConfig(), bool *changed = nullptr, bool *allErased = nullptr); -/// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns -/// success if no more patterns can be matched. `erased` is set to true if `op` -/// was folded away or erased as a result of becoming dead. -/// -/// Returns success if the iterative process converged and no more patterns can -/// be matched. +/// Applies the specified patterns on `op` while also trying to fold it. +/// Shortcut for the ArrayRef overload; note `config.strictMode` defaults to +/// AnyOp. If non-null, `erased` is set to whether `op` was erased. inline LogicalResult applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig(), bool *erased = nullptr) { - return applyOpPatternsAndFold(ArrayRef(op), patterns, - GreedyRewriteStrictness::ExistingOps, config, + return applyOpPatternsAndFold(ArrayRef(op), patterns, config, /*changed=*/nullptr, erased); } diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -130,10 +130,10 @@ patterns.insert, SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint.
- if (failed( - applyOpPatternsAndFold(targets, frozenPatterns, - GreedyRewriteStrictness::ExistingAndNewOps))) { + if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -239,6 +239,7 @@ AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, - GreedyRewriteStrictness::ExistingAndNewOps); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -105,6 +105,7 @@ if (isa(op)) opsToSimplify.push_back(op); }); - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, - GreedyRewriteStrictness::ExistingAndNewOps); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -321,9 +321,10 @@ // Simplify/canonicalize the affine.for. 
RewritePatternSet patterns(res.getContext()); AffineForOp::getCanonicalizationPatterns(patterns, res.getContext()); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; - (void)applyOpPatternsAndFold(res, std::move(patterns), - GreedyRewriteConfig(), &erased); + (void)applyOpPatternsAndFold(res, std::move(patterns), config, &erased); if (!erased && !prologue) prologue = res; if (!erased) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -413,10 +413,11 @@ // in which case we return with `folded` being set. RewritePatternSet patterns(ifOp.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); - bool erased; FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - (void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(), - &erased); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + bool erased; + (void)applyOpPatternsAndFold(ifOp, frozenPatterns, config, &erased); if (erased) { if (folded) *folded = true; diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -60,10 +60,13 @@ // matching in above iteration. Besides, erase op not-in-range may end up in // invalid module, so `applyOpPatternsAndFold` should come before that // transform. - for (Operation *op : opsInRange) + for (Operation *op : opsInRange) { // `applyOpPatternsAndFold` returns whether the op is convered. Omit it // because we don't have expectation this reduction will be success or not. 
- (void)applyOpPatternsAndFold(op, patterns); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + (void)applyOpPatternsAndFold(op, patterns, config); + } if (eraseOpNotInRange) for (Operation *op : opsNotInRange) { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -59,7 +59,7 @@ protected: /// Add the given operation to the worklist. - virtual void addSingleOpToWorklist(Operation *op); + void addSingleOpToWorklist(Operation *op); // Implement the hook for inserting operations, and make sure that newly // inserted ops are added to the worklist for processing. @@ -102,6 +102,12 @@ /// Configuration information for how to simplify. const GreedyRewriteConfig config; + /// The list of ops we are restricting our rewrites to. These include the + /// supplied set of ops as well as new ops created while rewriting those ops + /// depending on `strictMode`. This set is not maintained when + /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. + llvm::SmallDenseSet strictModeFilteredOps; + private: #ifndef NDEBUG /// A logger used to emit information during the application process. @@ -150,6 +156,12 @@ return false; }; + // Populate strict mode ops. + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { + strictModeFilteredOps.clear(); + region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); + } + bool changed = false; int64_t iteration = 0; do { @@ -323,12 +335,15 @@ } void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { - // Check to see if the worklist already contains this op. 
- if (worklistMap.count(op)) - return; - - worklistMap[op] = worklist.size(); - worklist.push_back(op); + if (config.strictMode == GreedyRewriteStrictness::AnyOp || + strictModeFilteredOps.contains(op)) { + // Check to see if the worklist already contains this op. + if (worklistMap.count(op)) + return; + + worklistMap[op] = worklist.size(); + worklist.push_back(op); + } } Operation *GreedyPatternRewriteDriver::popFromWorklist() { @@ -355,6 +370,8 @@ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) + strictModeFilteredOps.insert(op); addToWorklist(op); } @@ -391,6 +408,9 @@ removeFromWorklist(operation); folder.notifyRemoval(operation); }); + // Nested ops are erased with `op`: drop them all to avoid stale pointers. + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + op->walk([&](Operation *nested) { strictModeFilteredOps.erase(nested); }); } void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op, @@ -459,10 +479,10 @@ public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config, + const GreedyRewriteConfig &config, llvm::SmallDenseSet *survivingOps = nullptr) : GreedyPatternRewriteDriver(ctx, patterns, config), - strictMode(strictMode), survivingOps(survivingOps) {} /// Performs the specified rewrites on `ops` while also trying to fold these /// ops. `strictMode` controls which other ops are simplified.
Only ops @@ -476,38 +496,13 @@ LogicalResult simplifyLocally(ArrayRef op, bool *changed = nullptr) &&; -protected: - void addSingleOpToWorklist(Operation *op) override { - if (strictMode == GreedyRewriteStrictness::AnyOp || - strictModeFilteredOps.contains(op)) - GreedyPatternRewriteDriver::addSingleOpToWorklist(op); - } - private: - void notifyOperationInserted(Operation *op) override { - if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps) - strictModeFilteredOps.insert(op); - GreedyPatternRewriteDriver::notifyOperationInserted(op); - } - void notifyOperationRemoved(Operation *op) override { GreedyPatternRewriteDriver::notifyOperationRemoved(op); if (survivingOps) survivingOps->erase(op); - if (strictMode != GreedyRewriteStrictness::AnyOp) - strictModeFilteredOps.erase(op); } - /// `strictMode` control which ops are added to the worklist during - /// simplification. - const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp; - - /// The list of ops we are restricting our rewrites to. These include the - /// supplied set of ops as well as new ops created while rewriting those ops - /// depending on `strictMode`. This set is not maintained when `strictMode` - /// is GreedyRewriteStrictness::AnyOp. - llvm::SmallDenseSet strictModeFilteredOps; - /// An optional set of ops that survived the rewrite. This set is populated /// at the beginning of `simplifyLocally` with the inititally provided list /// of ops. 
@@ -524,7 +519,7 @@ survivingOps->insert(ops.begin(), ops.end()); } - if (strictMode != GreedyRewriteStrictness::AnyOp) { + if (config.strictMode != GreedyRewriteStrictness::AnyOp) { strictModeFilteredOps.clear(); strictModeFilteredOps.insert(ops.begin(), ops.end()); } @@ -549,7 +544,7 @@ if (op == nullptr) continue; - assert((strictMode == GreedyRewriteStrictness::AnyOp || + assert((config.strictMode == GreedyRewriteStrictness::AnyOp || strictModeFilteredOps.contains(op)) && "unexpected op was inserted under strict mode"); @@ -637,8 +632,7 @@ LogicalResult mlir::applyOpPatternsAndFold( ArrayRef ops, const FrozenRewritePatternSet &patterns, - GreedyRewriteStrictness strictMode, GreedyRewriteConfig config, - bool *changed, bool *allErased) { + GreedyRewriteConfig config, bool *changed, bool *allErased) { if (ops.empty()) { if (changed) *changed = false; @@ -664,8 +658,7 @@ // Start the pattern driver. llvm::SmallDenseSet surviving; MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - strictMode, config, - allErased ? &surviving : nullptr); + config, allErased ? 
&surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) *allErased = surviving.empty(); diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -132,8 +132,9 @@ AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); } } - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), - GreedyRewriteStrictness::ExistingAndNewOps); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -266,13 +266,13 @@ } }); - GreedyRewriteStrictness mode; + GreedyRewriteConfig config; if (strictMode == "AnyOp") { - mode = GreedyRewriteStrictness::AnyOp; + config.strictMode = GreedyRewriteStrictness::AnyOp; } else if (strictMode == "ExistingAndNewOps") { - mode = GreedyRewriteStrictness::ExistingAndNewOps; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; } else if (strictMode == "ExistingOps") { - mode = GreedyRewriteStrictness::ExistingOps; + config.strictMode = GreedyRewriteStrictness::ExistingOps; } else { llvm_unreachable("invalid strictness option"); } @@ -282,8 +282,8 @@ // operation will trigger the assertion while processing. 
bool changed = false; bool allErased = false; - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode, - GreedyRewriteConfig(), &changed, &allErased); + (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, + &changed, &allErased); Builder b(ctx); getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); getOperation()->setAttr("pattern_driver_all_erased",