diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -242,8 +242,8 @@ ### Greedy Pattern Rewrite Driver -This driver performs a post order traversal over the provided operations and -greedily applies the patterns that locally have the most benefit. The benefit of +This driver walks the provided operations and greedily applies the patterns that +locally have the most benefit. The benefit of a pattern is decided solely by the benefit specified on the pattern, and the relative order of the pattern within the pattern list (when two patterns have the same local benefit). Patterns are iteratively applied to operations until a @@ -252,5 +252,9 @@ `applyOpPatternsAndFold`. The latter of which only applies patterns to the provided operation, and will not traverse the IR. +The driver is configurable: you can choose to do a top-down initial traversal +(generally more efficient in compile time) or you can choose to do a post-order +"bottom-up" traversal (may match larger patterns with ambiguous pattern sets). + Note: This driver is the one used by the [canonicalization](Canonicalization.md) [pass](Passes.md#-canonicalize-canonicalize-operations) in MLIR. diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -33,6 +33,11 @@ public: OperationFolder(MLIRContext *ctx) : interfaces(ctx) {} + /// Scan the specified region for constants that can be used in folding, + /// moving them to the entry block (or any custom insertion location specified + /// by shouldMaterializeInto), and add them to our known-constants table. + void processExistingConstants(Region &region); + /// Tries to perform folding on the given `op`, including unifying /// deduplicated constants. If successful, replaces `op`'s uses with /// folded results, and returns success. 
`preReplaceAction` is invoked on `op` diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -25,36 +25,45 @@ /// Rewrite the regions of the specified operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy /// work-list driven manner. +/// /// This variant may stop after a predefined number of iterations, see the /// alternative below to provide a specific number of iterations before stopping /// in absence of convergence. +/// /// Return success if the iterative process converged and no more patterns can /// be matched in the result operation regions. +/// /// Note: This does not apply patterns to the top-level operation itself. /// These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. +/// +/// You may choose the order of initial traversal with the `useTopDownTraversal` +/// boolean. When set to true, it walks the operations top-down, which is +/// generally more efficient in compile time. When set to false, its initial +/// traversal is post-order, which may match larger patterns when given an +/// ambiguous pattern set. LogicalResult applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + bool useTopDownTraversal = false); /// Rewrite the regions of the specified operation, with a user-provided limit /// on iterations to attempt before reaching convergence. 
-LogicalResult -applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns, - unsigned maxIterations); +LogicalResult applyPatternsAndFoldGreedily( + Operation *op, const FrozenRewritePatternSet &patterns, + unsigned maxIterations, bool useTopDownTraversal = false); /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + bool useTopDownTraversal = false); /// Rewrite the given regions, with a user-provided limit on iterations to /// attempt before reaching convergence. -LogicalResult -applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, - const FrozenRewritePatternSet &patterns, - unsigned maxIterations); +LogicalResult applyPatternsAndFoldGreedily( + MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns, + unsigned maxIterations, bool useTopDownTraversal = false); /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. 
Returns diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -363,6 +363,11 @@ }]; let constructor = "mlir::createCanonicalizerPass()"; let dependentDialects = ["memref::MemRefDialect"]; + let options = [ + Option<"topDownProcessingEnabled", "enableTopDown", "bool", + /*default=*/"false", + "Enable top-down processing"> + ]; } def CSE : Pass<"cse"> { diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -32,7 +32,10 @@ return success(); } void runOnOperation() override { - (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); + (void)applyPatternsAndFoldGreedily( + getOperation()->getRegions(), patterns, + /*maxIterations=*/10, /*useTopDownTraversal=*/ + topDownProcessingEnabled); } FrozenRewritePatternSet patterns; diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -84,6 +84,81 @@ // OperationFolder //===----------------------------------------------------------------------===// +/// Scan the specified region for constants that can be used in folding, +/// moving them to the entry block (or any custom insertion location specified +/// by shouldMaterializeInto), and add them to our known-constants table. +void OperationFolder::processExistingConstants(Region &region) { + if (region.empty()) + return; + + // March the constant insertion point forward, moving all constants to the + // top of the block, but keeping them in their order of discovery. 
+ Region *insertRegion = getInsertionRegion(interfaces, &region.front()); + auto &uniquedConstants = foldScopes[insertRegion]; + + Block &insertBlock = insertRegion->front(); + Block::iterator constantIterator = insertBlock.begin(); + + // Process each constant that we discover in this region. + auto processConstant = [&](Operation *op, Attribute value) { + // Check to see if we already have an instance of this constant. + Operation *&constOp = uniquedConstants[std::make_tuple( + op->getDialect(), value, op->getResult(0).getType())]; + + // If we already have an instance of this constant, CSE/delete this one as + // we go. + if (constOp) { + if (constantIterator == Block::iterator(op)) + ++constantIterator; // Don't invalidate our iterator when scanning. + op->getResult(0).replaceAllUsesWith(constOp->getResult(0)); + op->erase(); + return; + } + + // Otherwise, remember that we have this constant. + constOp = op; + referencedDialects[op].push_back(op->getDialect()); + + // If the constant isn't already at the insertion point then move it up. + if (constantIterator == insertBlock.end() || &*constantIterator != op) + op->moveBefore(&insertBlock, constantIterator); + else + ++constantIterator; // It was pointing at the constant. + }; + + SmallVector<Operation *> isolatedOps; + region.walk<WalkOrder::PreOrder>([&](Operation *op) { + // If this is a constant, process it. + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + processConstant(op, value); + // We may have deleted the operation, don't check it for regions. + return WalkResult::skip(); + } + + // If the operation has regions and is isolated, don't recurse into it. + if (op->getNumRegions() != 0) { + auto hasDifferentInsertRegion = [&](Region &region) { + return !region.empty() && + getInsertionRegion(interfaces, &region.front()) != insertRegion; + }; + if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) { + isolatedOps.push_back(op); + return WalkResult::skip(); + } + } + + // Otherwise keep going. 
+ return WalkResult::advance(); + }); + + // Process regions in any isolated ops separately. + for (Operation *isolated : isolatedOps) { + for (Region &region : isolated->getRegions()) + processExistingConstants(region); + } +} + LogicalResult OperationFolder::tryToFold( Operation *op, function_ref<void(Operation *)> processGeneratedConstants, function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) { @@ -262,19 +337,19 @@ Attribute value, Type type, Location loc) { // Check if an existing mapping already exists. auto constKey = std::make_tuple(dialect, value, type); - auto *&constInst = uniquedConstants[constKey]; - if (constInst) - return constInst; + auto *&constOp = uniquedConstants[constKey]; + if (constOp) + return constOp; // If one doesn't exist, try to materialize one. - if (!(constInst = materializeConstant(dialect, builder, value, type, loc))) + if (!(constOp = materializeConstant(dialect, builder, value, type, loc))) return nullptr; // Check to see if the generated constant is in the expected dialect. - auto *newDialect = constInst->getDialect(); + auto *newDialect = constOp->getDialect(); if (newDialect == dialect) { - referencedDialects[constInst].push_back(dialect); - return constInst; + referencedDialects[constOp].push_back(dialect); + return constOp; } // If it isn't, then we also need to make sure that the mapping for the new @@ -284,13 +359,13 @@ // If an existing operation in the new dialect already exists, delete the // materialized operation in favor of the existing one. if (auto *existingOp = uniquedConstants.lookup(newKey)) { - constInst->erase(); + constOp->erase(); referencedDialects[existingOp].push_back(dialect); - return constInst = existingOp; + return constOp = existingOp; } // Otherwise, update the new dialect to the materialized operation. 
- referencedDialects[constInst].assign({dialect, newDialect}); - auto newIt = uniquedConstants.insert({newKey, constInst}); + referencedDialects[constOp].assign({dialect, newDialect}); + auto newIt = uniquedConstants.insert({newKey, constOp}); return newIt.first->second; } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,8 +37,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - const FrozenRewritePatternSet &patterns) - : PatternRewriter(ctx), matcher(patterns), folder(ctx) { + const FrozenRewritePatternSet &patterns, + bool useTopDownTraversal) + : PatternRewriter(ctx), matcher(patterns), folder(ctx), + useTopDownTraversal(useTopDownTraversal) { worklist.reserve(64); // Apply a simple cost model based solely on pattern benefit. @@ -134,6 +136,9 @@ /// Non-pattern based folder for operations. OperationFolder folder; + + // Whether to use top-down or bottom-up traversal order. + bool useTopDownTraversal; }; } // end anonymous namespace @@ -141,15 +146,36 @@ /// if the rewrite converges in `maxIterations`. bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions, int maxIterations) { - // Add the given operation to the worklist. - auto collectOps = [this](Operation *op) { addToWorklist(op); }; + // For maximum compatibility with existing passes, do not process existing + // constants unless we're performing a top-down traversal. + // TODO: This is just for compatibility with older MLIR, remove this. + if (useTopDownTraversal) { + // Perform a prepass over the IR to discover constants. + for (auto &region : regions) + folder.processExistingConstants(region); + } bool changed = false; - int i = 0; + int iteration = 0; do { - // Add all nested operations to the worklist. 
+ worklist.clear(); + worklistMap.clear(); + + // Add all nested operations to the worklist in preorder. for (auto &region : regions) - region.walk(collectOps); + if (useTopDownTraversal) + region.walk<WalkOrder::PreOrder>( + [this](Operation *op) { worklist.push_back(op); }); + else + region.walk([this](Operation *op) { addToWorklist(op); }); + + if (useTopDownTraversal) { + // Reverse the list so our pop-back loop processes them in-order. + std::reverse(worklist.begin(), worklist.end()); + // Remember the reverse index. + for (size_t i = 0, e = worklist.size(); i != e; ++i) + worklistMap[worklist[i]] = i; + } // These are scratch vectors used in the folding loop below. SmallVector<Value, 8> originalOperands, resultValues; @@ -187,6 +213,9 @@ notifyOperationRemoved(op); }; + // Add the given operation to the worklist. + auto collectOps = [this](Operation *op) { addToWorklist(op); }; + // Try to fold this op. bool inPlaceUpdate; if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, @@ -197,14 +226,16 @@ } // Try to match one of the patterns. The rewriter is automatically - // notified of any necessary changes, so there is nothing else to do here. + // notified of any necessary changes, so there is nothing else to do + // here. changed |= succeeded(matcher.matchAndRewrite(op, *this)); } - // After applying patterns, make sure that the CFG of each of the regions is - // kept up to date. + // After applying patterns, make sure that the CFG of each of the regions + // is kept up to date. changed |= succeeded(simplifyRegions(*this, regions)); - } while (changed && ++i < maxIterations); + } while (changed && ++iteration < maxIterations); + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. return !changed; } @@ -216,28 +247,29 @@ /// top-level operation itself. 
/// LogicalResult -mlir::applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns) { - return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations); -} -LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, const FrozenRewritePatternSet &patterns, - unsigned maxIterations) { - return applyPatternsAndFoldGreedily(op->getRegions(), patterns, - maxIterations); + bool useTopDownTraversal) { + return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations, + useTopDownTraversal); } -/// Rewrite the given regions, which must be isolated from above. -LogicalResult -mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, - const FrozenRewritePatternSet &patterns) { - return applyPatternsAndFoldGreedily(regions, patterns, - maxPatternMatchIterations); +LogicalResult mlir::applyPatternsAndFoldGreedily( + Operation *op, const FrozenRewritePatternSet &patterns, + unsigned maxIterations, bool useTopDownTraversal) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations, + useTopDownTraversal); } +/// Rewrite the given regions, which must be isolated from above. LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns, - unsigned maxIterations) { + bool useTopDownTraversal) { + return applyPatternsAndFoldGreedily( + regions, patterns, maxPatternMatchIterations, useTopDownTraversal); +} +LogicalResult mlir::applyPatternsAndFoldGreedily( + MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns, + unsigned maxIterations, bool useTopDownTraversal) { if (regions.empty()) return success(); @@ -252,7 +284,8 @@ "patterns can only be applied to operations IsolatedFromAbove"); // Start the pattern driver. 
- GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns); + GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, + useTopDownTraversal); bool converged = driver.simplify(regions, maxIterations); LLVM_DEBUG(if (!converged) { llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " diff --git a/mlir/test/Transforms/canonicalize-td.mlir b/mlir/test/Transforms/canonicalize-td.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/canonicalize-td.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize{enableTopDown=true})' | FileCheck %s --check-prefix=TD +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' | FileCheck %s --check-prefix=BU + + +// CHECK-LABEL: func @default_insertion_position +func @default_insertion_position(%cond: i1) { + // Constant should be folded into the entry block. + + // BU: constant 2 + // BU-NEXT: scf.if + + // TD: constant 2 + // TD-NEXT: scf.if + scf.if %cond { + %0 = constant 1 : i32 + %2 = addi %0, %0 : i32 + "foo.yield"(%2) : (i32) -> () + } + return +} + +// This shows that we don't pull the constant out of the region because it +// wants to be the insertion point for the constant. +// CHECK-LABEL: func @custom_insertion_position +func @custom_insertion_position() { + // BU: test.one_region_op + // BU-NEXT: constant 2 + + // TD: test.one_region_op + // TD-NEXT: constant 2 + "test.one_region_op"() ({ + + %0 = constant 1 : i32 + %2 = addi %0, %0 : i32 + "foo.yield"(%2) : (i32) -> () + }) : () -> () + return +} +