diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -35,26 +35,31 @@
 /// before attempting to match any of the provided patterns.
+///
+/// `useTopDownTraversal` selects the order in which operations are seeded
+/// into the worklist: top-down (program order) when true, which is the
+/// default, and bottom-up when false. All overloads below accept it with the
+/// same meaning.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             bool useTopDownTraversal = true);
 
 /// Rewrite the regions of the specified operation, with a user-provided limit
 /// on iterations to attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternList &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns);
+                             const FrozenRewritePatternList &patterns,
+                             bool useTopDownTraversal = true);
 
 /// Rewrite the given regions, with a user-provided limit on iterations to
 /// attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternList &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = true);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1115,7 +1115,8 @@
     Operation *op = getOperation();
     OwningRewritePatternList patterns(op->getContext());
     populateLinalgTensorOpsFusionPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
+                                       /*useTopDownTraversal=*/false);
   }
 };
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -37,8 +37,10 @@
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternList &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+                                      const FrozenRewritePatternList &patterns,
+                                      bool useTopDownTraversal)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx),
+        useTopDownTraversal(useTopDownTraversal) {
     worklist.reserve(64);
 
     // Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,9 @@
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
+
+  /// Whether to use top-down or bottom-up traversal order.
+  bool useTopDownTraversal;
 };
 } // end anonymous namespace
 
@@ -153,12 +158,18 @@
 
-    // Add all nested operations to the worklist in preorder.
-    for (auto &region : regions)
-      region.walk<WalkOrder::PreOrder>(
-          [this](Operation *op) { worklist.push_back(op); });
-
-    // Reverse the list so our pop-back loop processes them in-order.
-    std::reverse(worklist.begin(), worklist.end());
-    // Remember the reverse index.
-    for (unsigned i = 0, e = worklist.size(); i != e; ++i)
-      worklistMap[worklist[i]] = i;
+    if (useTopDownTraversal) {
+      // Add all nested operations to the worklist in preorder.
+      for (auto &region : regions)
+        region.walk<WalkOrder::PreOrder>(
+            [this](Operation *op) { worklist.push_back(op); });
+
+      // Reverse the list so our pop-back loop processes them in-order.
+      std::reverse(worklist.begin(), worklist.end());
+      // Remember the reverse index.
+      for (unsigned i = 0, e = worklist.size(); i != e; ++i)
+        worklistMap[worklist[i]] = i;
+    } else {
+      // Add all nested operations to the worklist in postorder.
+      for (auto &region : regions)
+        region.walk([this](Operation *op) { addToWorklist(op); });
+    }
 
@@ -231,28 +242,29 @@
 /// top-level operation itself.
 ///
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
-}
-LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
                                    const FrozenRewritePatternList &patterns,
-                                   unsigned maxIterations) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
-                                      maxIterations);
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
+                                      useTopDownTraversal);
 }
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternList &patterns) {
-  return applyPatternsAndFoldGreedily(regions, patterns,
-                                      maxPatternMatchIterations);
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
+                                      useTopDownTraversal);
 }
+/// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternList &patterns,
-                                   unsigned maxIterations) {
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(
+      regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
+}
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
   if (regions.empty())
     return success();
 
@@ -267,7 +279,8 @@
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
+                                    useTopDownTraversal);
   bool converged = driver.simplify(regions, maxIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "