diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -23,7 +23,9 @@
 namespace mlir {
 
 class AffineForOp;
+class DominanceInfo;
 class GreedyRewriteConfig;
+struct RewriteListener;
 
 /// Fusion mode to attempt. The default mode `Greedy` does both
 /// producer-consumer and sibling fusion.
@@ -74,9 +76,25 @@
                         ArrayRef<std::string> disabledPatterns = llvm::None,
                         ArrayRef<std::string> enabledPatterns = llvm::None);
 
+/// Canonicalize all operations nested under the provided operation, using the
+/// specified config, disabled and enabled patterns, and notify the provided
+/// listener of rewrite events.
+LogicalResult
+canonicalizeOperations(Operation *op, const GreedyRewriteConfig &config,
+                       ArrayRef<std::string> disabledPatterns = llvm::None,
+                       ArrayRef<std::string> enabledPatterns = llvm::None,
+                       RewriteListener *listener = nullptr);
+
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
+/// Perform common subexpression elimination on all operations nested within the
+/// provided operation. Optionally provide existing dominance info or a listener
+/// to be notified when operations are replaced or erased.
+LogicalResult
+eliminateCommonSubexpressions(Operation *op, DominanceInfo *domInfo = nullptr,
+                              RewriteListener *listener = nullptr);
+
 /// Creates a loop fusion pass which fuses loops according to type of fusion
 /// specified in `fusionMode`. Buffers of size less than or equal to
 /// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`.
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -53,7 +53,7 @@
 
 namespace {
 /// Simple common sub-expression elimination.
-struct CSE : public CSEBase<CSE> {
+struct CSE {
   /// Shared implementation of operation elimination and scoped map definitions.
   using AllocatorTy = llvm::RecyclingAllocator<
       llvm::BumpPtrAllocator,
@@ -61,6 +61,9 @@
   using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
                                             SimpleOperationInfo, AllocatorTy>;
 
+  CSE(DominanceInfo *domInfo, RewriteListener *listener)
+      : domInfo(domInfo), listener(listener) {}
+
   /// Represents a single entry in the depth first traversal of a CFG.
   struct CFGStackNode {
     CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@@ -83,13 +86,31 @@
                                   bool hasSSADominance);
   void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
   void simplifyRegion(ScopedMapTy &knownValues, Region &region);
+  /// Return the number of erased operations.
+  unsigned simplify(Operation *rootOp);
 
-  void runOnOperation() override;
+  /// Get the number of dead operations removed.
+  unsigned getNumDCE() { return numDCE; }
+  /// Get the number of redundant operations removed.
+  unsigned getNumCSE() { return numCSE; }
 
 private:
   /// Operations marked as dead and to be erased.
   std::vector<Operation *> opsToErase;
-  DominanceInfo *domInfo = nullptr;
+  /// The number of trivially dead operations found and removed.
+  unsigned numDCE = 0;
+  /// The number of redundant operations removed.
+  unsigned numCSE = 0;
+
+  /// The dominance info to use.
+  DominanceInfo *domInfo;
+  /// An optional listener to notify of replaced or erased operations.
+  RewriteListener *listener;
+};
+
+/// Common subexpression elimination pass.
+struct CSEPass : CSEBase<CSEPass> {
+  void runOnOperation() override;
 };
 } // namespace
 
@@ -127,6 +148,8 @@
     if (hasSSADominance) {
       // If the region has SSA dominance, then we are guaranteed to have not
       // visited any use of the current operation.
+      if (listener)
+        listener->notifyRootReplaced(op);
       op->replaceAllUsesWith(existing);
       opsToErase.push_back(op);
     } else {
@@ -243,29 +266,57 @@
   }
 }
 
-void CSE::runOnOperation() {
+unsigned CSE::simplify(Operation *rootOp) {
   /// A scoped hash table of defining operations within a region.
   ScopedMapTy knownValues;
 
-  domInfo = &getAnalysis<DominanceInfo>();
-  Operation *rootOp = getOperation();
-
   for (auto &region : rootOp->getRegions())
     simplifyRegion(knownValues, region);
 
-  // If no operations were erased, then we mark all analyses as preserved.
-  if (opsToErase.empty())
-    return markAllAnalysesPreserved();
-
   /// Erase any operations that were marked as dead during simplification.
-  for (auto *op : opsToErase)
+  for (auto *op : opsToErase) {
+    if (listener)
+      listener->notifyOperationRemoved(op);
     op->erase();
-  opsToErase.clear();
+  }
+
+  // Reset the worklist so that a later call to `simplify` on this instance
+  // does not try to erase operations that were already destroyed above.
+  unsigned numErased = opsToErase.size();
+  opsToErase.clear();
+  return numErased;
+}
+
+void CSEPass::runOnOperation() {
+  CSE cse(&getAnalysis<DominanceInfo>(), /*listener=*/nullptr);
+  unsigned numOpsRemoved = cse.simplify(getOperation());
+  numDCE += cse.getNumDCE();
+  numCSE += cse.getNumCSE();
+
+  // If no operations were erased, then we mark all analyses as preserved.
+  if (numOpsRemoved == 0)
+    return markAllAnalysesPreserved();
 
   // We currently don't remove region operations, so mark dominance as
   // preserved.
   markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
-  domInfo = nullptr;
 }
 
-std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }
+std::unique_ptr<Pass> mlir::createCSEPass() {
+  return std::make_unique<CSEPass>();
+}
+
+/// Run CSE on the provided operation.
+LogicalResult mlir::eliminateCommonSubexpressions(Operation *op,
+                                                  DominanceInfo *domInfo,
+                                                  RewriteListener *listener) {
+  Optional<DominanceInfo> defaultDomInfo;
+  if (domInfo == nullptr) {
+    defaultDomInfo.emplace(op);
+    domInfo = &*defaultDomInfo;
+  }
+
+  CSE cse(domInfo, listener);
+  cse.simplify(op);
+  return success();
+}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -12,12 +12,29 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 
+/// Initialize the patterns for a canonicalization pass. Collect
+/// canonicalization patterns from all currently loaded dialects and registered
+/// operations.
+static FrozenRewritePatternSet
+initializeCanonicalizer(MLIRContext *context,
+                        ArrayRef<std::string> disabledPatterns,
+                        ArrayRef<std::string> enabledPatterns) {
+  RewritePatternSet owningPatterns(context);
+  for (auto *dialect : context->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(owningPatterns);
+  for (RegisteredOperationName op : context->getRegisteredOperations())
+    op.getCanonicalizationPatterns(owningPatterns, context);
+
+  return {std::move(owningPatterns), disabledPatterns, enabledPatterns};
+}
+
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
@@ -40,14 +57,8 @@
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
-    RewritePatternSet owningPatterns(context);
-    for (auto *dialect : context->getLoadedDialects())
-      dialect->getCanonicalizationPatterns(owningPatterns);
-    for (RegisteredOperationName op : context->getRegisteredOperations())
-      op.getCanonicalizationPatterns(owningPatterns, context);
-
-    patterns = FrozenRewritePatternSet(std::move(owningPatterns),
-                                       disabledPatterns, enabledPatterns);
+    patterns =
+        initializeCanonicalizer(context, disabledPatterns, enabledPatterns);
     return success();
   }
   void runOnOperation() override {
@@ -55,7 +66,9 @@
                                        config);
   }
 
+  /// The greedy rewrite config to use when applying patterns.
   GreedyRewriteConfig config;
+  /// The canonicalization patterns.
   FrozenRewritePatternSet patterns;
 };
 } // namespace
@@ -73,3 +86,16 @@
   return std::make_unique<Canonicalizer>(config, disabledPatterns,
                                          enabledPatterns);
 }
+
+/// Run canonicalization on the provided operation. The canonicalization
+/// pattern set is collected and frozen anew on every call, so repeated calls
+/// repeat that work (the pass instead builds it once in `initialize`).
+LogicalResult
+mlir::canonicalizeOperations(Operation *op, const GreedyRewriteConfig &config,
+                             ArrayRef<std::string> disabledPatterns,
+                             ArrayRef<std::string> enabledPatterns,
+                             RewriteListener *listener) {
+  FrozenRewritePatternSet patterns = initializeCanonicalizer(
+      op->getContext(), disabledPatterns, enabledPatterns);
+  return applyPatternsAndFoldGreedily(op, patterns, config, listener);
+}