diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -59,9 +59,9 @@ // applyPatternsGreedily //===----------------------------------------------------------------------===// -/// Rewrite the regions of the specified operation, which must be isolated from -/// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. +/// Rewrite ops in the given region, whose parent op must be isolated from +/// above, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. /// /// This variant may stop after a predefined number of iterations, see the /// alternative below to provide a specific number of iterations before stopping @@ -76,14 +76,19 @@ /// /// You may configure several aspects of this with GreedyRewriteConfig. LogicalResult applyPatternsAndFoldGreedily( - MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns, + Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig()); -/// Rewrite the given regions, which must be isolated from above. +/// Rewrite ops in all regions of the given op, which must be isolated from +/// above. Each region is rewritten independently; failure is returned if the +/// rewrite of any region fails.
inline LogicalResult applyPatternsAndFoldGreedily( Operation *op, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig()) { - return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config); + bool failed = false; + for (Region &region : op->getRegions()) + failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed(); + return failure(failed); } /// Applies the specified rewrite patterns on `ops` while also trying to fold diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1867,8 +1867,7 @@ // Use TopDownTraversal for compile time reasons GreedyRewriteConfig grc; grc.useTopDownTraversal = true; - (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), - grc); + (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -781,8 +781,7 @@ void ExpandStridedMetadataPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateExpandStridedMetadataPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), - std::move(patterns)); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -605,8 +605,7 @@ void
FoldMemRefAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateFoldMemRefAliasOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), - std::move(patterns)); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -149,8 +149,7 @@ void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), - std::move(patterns)))) + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } @@ -158,8 +157,7 @@ RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(), - std::move(patterns)))) + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -120,8 +120,7 @@ RewritePatternSet loweringPatterns(context); populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); - if (failed(applyPatternsAndFoldGreedily(op->getRegions(), - std::move(loweringPatterns)))) + if (failed(applyPatternsAndFoldGreedily(op,
std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -40,10 +40,10 @@ explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, - const DenseSet<Region *> &scope); + const Region &scope); - /// Simplify the operations within the given regions. - bool simplify(MutableArrayRef<Region> regions) &&; + /// Simplify the ops within the given region. + bool simplify(Region &region) &&; /// Add the given operation and its ancestors to the worklist. void addToWorklist(Operation *op); @@ -104,7 +104,7 @@ const GreedyRewriteConfig config; /// Only ops within this scope are simplified. - const DenseSet<Region *> scope; + const Region &scope; private: #ifndef NDEBUG @@ -116,7 +116,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config, const DenseSet<Region *> &scope) + const GreedyRewriteConfig &config, const Region &scope) : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config), scope(scope) { worklist.reserve(64); @@ -125,7 +125,7 @@ matcher.applyDefaultCostModel(); } -bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && { +bool GreedyPatternRewriteDriver::simplify(Region &region) && { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -167,15 +167,12 @@ if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. - for (auto &region : regions) { region.walk([&](Operation *op) { if (!insertKnownConstant(op)) addToWorklist(op); }); - } } else { // Add all nested operations to the worklist in preorder.
- for (auto &region : regions) { region.walk([&](Operation *op) { if (!insertKnownConstant(op)) { worklist.push_back(op); @@ -183,7 +180,6 @@ } return WalkResult::skip(); }); - } // Reverse the list so our pop-back loop processes them in-order. std::reverse(worklist.begin(), worklist.end()); @@ -305,7 +301,7 @@ - // After applying patterns, make sure that the CFG of each of the regions - // is kept up to date. + // After applying patterns, make sure that the CFG of the region is kept + // up to date. if (config.enableRegionSimplification) - changed |= succeeded(simplifyRegions(*this, regions)); + changed |= succeeded(simplifyRegions(*this, region)); } while (changed); // Whether the rewrite converges, i.e. wasn't changed in the last iteration. @@ -317,7 +313,7 @@ SmallVector ancestors; ancestors.push_back(op); while (Region *region = op->getParentRegion()) { - if (scope.contains(region)) { + if (&scope == region) { // All gathered ops are in fact ancestors. for (Operation *op : ancestors) addSingleOpToWorklist(op); @@ -429,31 +425,19 @@ /// top-level operation itself. /// LogicalResult -mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, +mlir::applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config) { - if (regions.empty()) - return success(); - // The top-level operation must be known to be isolated from above to // prevent performing canonicalizations on operations defined at or above // the region containing 'op'. - auto regionIsIsolated = [](Region &region) { - return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>(); - }; - (void)regionIsIsolated; - assert(llvm::all_of(regions, regionIsIsolated) && + assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && "patterns can only be applied to operations IsolatedFromAbove"); - // Limit ops on the worklist to this scope. - DenseSet<Region *> scope; - for (Region &r : regions) - scope.insert(&r); - // Start the pattern driver.
- GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config, - scope); - bool converged = std::move(driver).simplify(regions); + GreedyPatternRewriteDriver driver(region.getContext(), patterns, config, + region); + bool converged = std::move(driver).simplify(region); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " << config.maxIterations << " times\n"; @@ -476,7 +460,7 @@ public: explicit MultiOpPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode, + const Region &scope, GreedyRewriteStrictness strictMode, llvm::SmallDenseSet *survivingOps = nullptr) : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope), strictMode(strictMode), survivingOps(survivingOps) {} @@ -680,10 +664,8 @@ // Start the pattern driver. llvm::SmallDenseSet surviving; - DenseSet<Region *> scopeSet; - scopeSet.insert(scope); MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - scopeSet, strictMode, + *scope, strictMode, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplifyLocally(ops, changed); if (allErased) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1633,8 +1633,7 @@ MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.add(context); - (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), - std::move(patterns)); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; } // namespace