diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -33,6 +33,11 @@
 public:
   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
 
+  /// Scan the specified region for constants that can be used in folding,
+  /// moving them to the entry block and adding them to our known-constants
+  /// table.
+  void processExistingConstants(Region &region);
+
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -35,26 +35,26 @@
 /// before attempting to match any of the provided patterns.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns);
+                             const FrozenRewritePatternSet &patterns,
+                             bool useTopDownTraversal = false);
 
 /// Rewrite the regions of the specified operation, with a user-provided limit
 /// on iterations to attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = false);
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternSet &patterns);
+                             const FrozenRewritePatternSet &patterns,
+                             bool useTopDownTraversal = false);
 
 /// Rewrite the given regions, with a user-provided limit on iterations to
 /// attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternSet &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = false);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,6 +84,83 @@
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
+/// Scan the specified region for constants that can be used in folding,
+/// moving them to the entry block and adding them to our known-constants
+/// table.
+void OperationFolder::processExistingConstants(Region &region) {
+  if (region.empty())
+    return;
+
+  // March the constant insertion point forward, moving all constants to the
+  // top of the block, but keeping them in their order of discovery.
+  Region *insertRegion = getInsertionRegion(interfaces, &region.front());
+  auto &uniquedConstants = foldScopes[insertRegion];
+
+  Block &insertBlock = insertRegion->front();
+  Block::iterator constantIterator = insertBlock.begin();
+
+  // Process each constant that we discover in this region.
+  auto processConstant = [&](Operation *op, Attribute value) {
+    // Check to see if we already have an instance of this constant.
+    Operation *&constOp = uniquedConstants[std::make_tuple(
+        op->getDialect(), value, op->getResult(0).getType())];
+
+    // If we already have an instance of this constant, CSE/delete this one as
+    // we go.
+    if (constOp) {
+      if (constantIterator == Block::iterator(op))
+        ++constantIterator; // Don't invalidate our iterator when scanning.
+      op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
+      op->erase();
+      return;
+    }
+
+    // Otherwise, remember that we have this constant.
+    constOp = op;
+    referencedDialects[op].push_back(op->getDialect());
+
+    // If the constant isn't already at the insertion point then move it up.
+    if (constantIterator == insertBlock.end() || &*constantIterator != op)
+      op->moveBefore(&insertBlock, constantIterator);
+    else
+      ++constantIterator; // It was pointing at the constant.
+  };
+
+  SmallVector<Operation *> isolatedOps;
+  // Walk in pre-order: WalkResult::skip() only prunes nested regions when the
+  // callback runs before an operation's children are visited.
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // If this is a constant, process it.
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      processConstant(op, value);
+      // We may have deleted the operation, don't check it for regions.
+      return WalkResult::skip();
+    }
+
+    // If the operation has regions and is isolated, don't recurse into it.
+    if (op->getNumRegions() != 0) {
+      auto hasDifferentInsertRegion = [&](Region &region) {
+        return !region.empty() &&
+               getInsertionRegion(interfaces, &region.front()) != insertRegion;
+      };
+      if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
+        isolatedOps.push_back(op);
+        return WalkResult::skip();
+      }
+    }
+
+    // Otherwise keep going.
+    return WalkResult::advance();
+  });
+
+  // Process regions in any isolated ops separately.
+  for (Operation *isolated : isolatedOps) {
+    for (Region &region : isolated->getRegions())
+      processExistingConstants(region);
+  }
+}
+
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
@@ -262,19 +339,19 @@
     Attribute value, Type type, Location loc) {
   // Check if an existing mapping already exists.
   auto constKey = std::make_tuple(dialect, value, type);
-  auto *&constInst = uniquedConstants[constKey];
-  if (constInst)
-    return constInst;
+  auto *&constOp = uniquedConstants[constKey];
+  if (constOp)
+    return constOp;
 
   // If one doesn't exist, try to materialize one.
-  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
     return nullptr;
 
   // Check to see if the generated constant is in the expected dialect.
-  auto *newDialect = constInst->getDialect();
+  auto *newDialect = constOp->getDialect();
   if (newDialect == dialect) {
-    referencedDialects[constInst].push_back(dialect);
-    return constInst;
+    referencedDialects[constOp].push_back(dialect);
+    return constOp;
   }
 
   // If it isn't, then we also need to make sure that the mapping for the new
@@ -284,13 +361,13 @@
   // If an existing operation in the new dialect already exists, delete the
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
-    constInst->erase();
+    constOp->erase();
     referencedDialects[existingOp].push_back(dialect);
-    return constInst = existingOp;
+    return constOp = existingOp;
   }
 
   // Otherwise, update the new dialect to the materialized operation.
-  referencedDialects[constInst].assign({dialect, newDialect});
-  auto newIt = uniquedConstants.insert({newKey, constInst});
+  referencedDialects[constOp].assign({dialect, newDialect});
+  auto newIt = uniquedConstants.insert({newKey, constOp});
   return newIt.first->second;
 }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -37,8 +37,10 @@
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+                                      const FrozenRewritePatternSet &patterns,
+                                      bool useTopDownTraversal)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx),
+        useTopDownTraversal(useTopDownTraversal) {
     worklist.reserve(64);
 
     // Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,9 @@
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
+
+  // Whether to use top-down or bottom-up traversal order.
+  bool useTopDownTraversal;
 };
 } // end anonymous namespace
 
@@ -141,15 +146,36 @@
 /// if the rewrite converges in `maxIterations`.
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
                                           int maxIterations) {
-  // Add the given operation to the worklist.
-  auto collectOps = [this](Operation *op) { addToWorklist(op); };
+  // For maximum compatibility with existing passes, do not process existing
+  // constants unless we're performing a top-down traversal.
+  // TODO: This is just for compatibility with older MLIR, remove this.
+  if (useTopDownTraversal) {
+    // Perform a prepass over the IR to discover constants.
+    for (auto &region : regions)
+      folder.processExistingConstants(region);
+  }
 
   bool changed = false;
-  int i = 0;
+  int iteration = 0;
   do {
-    // Add all nested operations to the worklist.
+    worklist.clear();
+    worklistMap.clear();
+
+    // Add all nested operations to the worklist in preorder.
     for (auto &region : regions)
-      region.walk(collectOps);
+      if (useTopDownTraversal)
+        region.walk<WalkOrder::PreOrder>(
+            [this](Operation *op) { worklist.push_back(op); });
+      else
+        region.walk([this](Operation *op) { addToWorklist(op); });
+
+    if (useTopDownTraversal) {
+      // Reverse the list so our pop-back loop processes them in-order.
+      std::reverse(worklist.begin(), worklist.end());
+      // Remember the reverse index.
+      for (unsigned i = 0, e = worklist.size(); i != e; ++i)
+        worklistMap[worklist[i]] = i;
+    }
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value, 8> originalOperands, resultValues;
@@ -187,6 +213,9 @@
         notifyOperationRemoved(op);
       };
 
+      // Add the given operation to the worklist.
+      auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
       // Try to fold this op.
       bool inPlaceUpdate;
       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
@@ -197,47 +226,50 @@
       }
 
       // Try to match one of the patterns. The rewriter is automatically
-      // notified of any necessary changes, so there is nothing else to do here.
+      // notified of any necessary changes, so there is nothing else to do
+      // here.
       changed |= succeeded(matcher.matchAndRewrite(op, *this));
     }
 
-    // After applying patterns, make sure that the CFG of each of the regions is
-    // kept up to date.
+    // After applying patterns, make sure that the CFG of each of the regions
+    // is kept up to date.
     changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && ++i < maxIterations);
+  } while (changed && ++iteration < maxIterations);
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
 }
 
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions. Note: This does not apply patterns to the
-/// top-level operation itself.
+/// Rewrite the regions of the specified operation, which must be isolated
+/// from above, by repeatedly applying the highest benefit patterns in a
+/// greedy work-list driven manner. Return success if no more patterns can be
+/// matched in the result operation regions. Note: This does not apply
+/// patterns to the top-level operation itself.
 ///
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternSet &patterns) {
-  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
-}
-LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
                                    const FrozenRewritePatternSet &patterns,
-                                   unsigned maxIterations) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
-                                      maxIterations);
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
+                                      useTopDownTraversal);
 }
 
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternSet &patterns) {
-  return applyPatternsAndFoldGreedily(regions, patterns,
-                                      maxPatternMatchIterations);
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
+                                      useTopDownTraversal);
 }
+/// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternSet &patterns,
-                                   unsigned maxIterations) {
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(
+      regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
+}
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
   if (regions.empty())
     return success();
 
@@ -252,7 +284,8 @@
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
+                                    useTopDownTraversal);
   bool converged = driver.simplify(regions, maxIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
@@ -266,8 +299,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This is a simple driver for the PatternMatcher to apply patterns and perform
-/// folding on a single op. It repeatedly applies locally optimal patterns.
+/// This is a simple driver for the PatternMatcher to apply patterns and
+/// perform folding on a single op. It repeatedly applies locally optimal
+/// patterns.
 class OpPatternRewriteDriver : public PatternRewriter {
 public:
   explicit OpPatternRewriteDriver(MLIRContext *ctx,
@@ -280,20 +314,20 @@
   /// Performs the rewrites and folding only on `op`. The simplification
   /// converges if the op is erased as a result of being folded, replaced, or
   /// dead, or no more changes happen in an iteration. Returns success if the
-  /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets
-  /// erased.
+  /// rewrite converges in `maxIterations`. `erased` is set to true if `op`
+  /// gets erased.
   LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
 
   // These are hooks implemented for PatternRewriter.
 protected:
-  /// If an operation is about to be removed, mark it so that we can let clients
-  /// know.
+  /// If an operation is about to be removed, mark it so that we can let
+  /// clients know.
   void notifyOperationRemoved(Operation *op) override {
     opErasedViaPatternRewrites = true;
   }
 
-  // When a root is going to be replaced, its removal will be notified as well.
-  // So there is nothing to do here.
+  // When a root is going to be replaced, its removal will be notified as
+  // well. So there is nothing to do here.
   void notifyRootReplaced(Operation *op) override {}
 
 private:
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -336,7 +336,7 @@
   return
 }
 // CHECK-LABEL: func @aligned_promote_fill
-// CHECK: %[[cf:.*]] = constant {{.*}} : f32
+// CHECK: %[[cf:.*]] = constant 1.0{{.*}} : f32
 // CHECK: %[[s0:.*]] = memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref
 // CHECK: %[[a0:.*]] = memref.alloc({{%.*}}) {alignment = 32 : i64} : memref
 // CHECK: %[[v0:.*]] = memref.view %[[a0]][{{.*}}][{{%.*}}, {{%.*}}] : memref to memref
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -662,7 +662,7 @@
 //
 // CHECK-LABEL: func @lowered_affine_ceildiv
 func @lowered_affine_ceildiv() -> (index, index) {
-// CHECK-NEXT: %c-1 = constant -1 : index
+// CHECK-DAG: %c-1 = constant -1 : index
   %c-43 = constant -43 : index
   %c42 = constant 42 : index
   %c0 = constant 0 : index
@@ -675,7 +675,7 @@
   %5 = subi %c0, %4 : index
   %6 = addi %4, %c1 : index
   %7 = select %0, %5, %6 : index
-// CHECK-NEXT: %c2 = constant 2 : index
+// CHECK-DAG: %c2 = constant 2 : index
   %c43 = constant 43 : index
   %c42_0 = constant 42 : index
   %c0_1 = constant 0 : index