diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -555,9 +555,18 @@
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
   /// in-place operation modification is about to happen.
-  void
-  replaceUsesWithIf(Value from, Value to,
-                    llvm::unique_function<bool(OpOperand &)> functor);
+  void replaceUsesWithIf(Value from, Value to,
+                         function_ref<bool(OpOperand &)> functor);
+
+  /// Range variant of the above: replaces the uses of each value in `from`
+  /// with the value at the same position in `to`, subject to `functor`. Both
+  /// ranges must have the same size.
+  void replaceUsesWithIf(ValueRange from, ValueRange to,
+                         function_ref<bool(OpOperand &)> functor) {
+    assert(from.size() == to.size() && "incorrect number of replacements");
+    for (auto it : llvm::zip(from, to))
+      replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+  }
 
   /// Find uses of `from` and replace them with `to` except if the user is
   /// `exceptedUser`. It also marks every modified uses and notifies the
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -0,0 +1,39 @@
+//===- CSE.h - Common Subexpression Elimination -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares methods for eliminating common subexpressions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_CSE_H_
+#define MLIR_TRANSFORMS_CSE_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+/// This class allows control over how common subexpression elimination works.
+class CSERewriteConfig {
+public:
+  /// An optional listener that should be notified about IR modifications.
+  RewriterBase::Listener *listener = nullptr;
+};
+
+/// Eliminate common subexpressions within the given operation. This transform
+/// looks for and deduplicates equivalent operations.
+///
+/// `config` controls how the IR is rewritten (e.g., an optional listener that
+/// is notified about IR modifications). If `changed` is non-null, it is set to
+/// indicate whether the IR was modified.
+void eliminateCommonSubExpressions(Operation *op,
+                                   CSERewriteConfig config = CSERewriteConfig(),
+                                   bool *changed = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_CSE_H_
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -326,9 +326,8 @@
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.
-void RewriterBase::replaceUsesWithIf(
-    Value from, Value to,
-    llvm::unique_function<bool(OpOperand &)> functor) {
+void RewriterBase::replaceUsesWithIf(Value from, Value to,
+                                     function_ref<bool(OpOperand &)> functor) {
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
     if (functor(operand))
       updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -11,11 +11,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/CSE.h"
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/ScopedHashTable.h"
@@ -56,7 +57,18 @@
 
 namespace {
 /// Simple common sub-expression elimination.
-struct CSE : public impl::CSEBase<CSE> {
+class CSEDriver {
+public:
+  CSEDriver(MLIRContext *ctx, DominanceInfo &domInfo,
+            CSERewriteConfig config = CSERewriteConfig());
+
+  /// Simplify all operations within the given op.
+  void simplify(Operation *op, bool *changed = nullptr);
+
+  int64_t getNumCSE() const { return numCSE; }
+  int64_t getNumDCE() const { return numDCE; }
+
+private:
   /// Shared implementation of operation elimination and scoped map definitions.
   using AllocatorTy = llvm::RecyclingAllocator<
       llvm::BumpPtrAllocator,
@@ -94,9 +106,6 @@
   void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
   void simplifyRegion(ScopedMapTy &knownValues, Region &region);
 
-  void runOnOperation() override;
-
-private:
   void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                             Operation *existing, bool hasSSADominance);
 
@@ -104,33 +113,61 @@
   /// between the two operations.
   bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
 
+  /// A rewriter for modifying the IR.
+  IRRewriter rewriter;
+
+  /// Configuration information for how to simplify.
+  CSERewriteConfig config;
+
   /// Operations marked as dead and to be erased.
   std::vector<Operation *> opsToErase;
-  DominanceInfo *domInfo = nullptr;
+  DominanceInfo &domInfo;
   MemEffectsCache memEffectsCache;
+
+  // Various statistics.
+  int64_t numCSE = 0;
+  int64_t numDCE = 0;
 };
 } // namespace
 
-void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
-                               Operation *existing, bool hasSSADominance) {
+CSEDriver::CSEDriver(MLIRContext *ctx, DominanceInfo &domInfo,
+                     CSERewriteConfig config)
+    : rewriter(ctx), config(config), domInfo(domInfo) {
+  rewriter.setListener(config.listener);
+}
+
+void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                                     Operation *existing,
+                                     bool hasSSADominance) {
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
   if (hasSSADominance) {
     // If the region has SSA dominance, then we are guaranteed to have not
     // visited any use of the current operation.
-    op->replaceAllUsesWith(existing);
+    // Notify the listener first: `notifyOperationReplaced` must be called
+    // before the uses of `op` are changed (as RewriterBase::replaceOp does).
+    if (config.listener)
+      config.listener->notifyOperationReplaced(op, existing->getResults());
+    rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
     opsToErase.push_back(op);
   } else {
     // When the region does not have SSA dominance, we need to check if we
     // have visited a use before replacing any use.
-    op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
-      return !knownValues.count(operand.getOwner());
-    });
+    auto isNotVisited = [&](OpOperand &operand) {
+      return !knownValues.count(operand.getOwner());
+    };
+
+    // `op` is only erased if all of its uses get replaced. Decide that before
+    // touching the IR so that the listener is notified ahead of the change.
+    if (config.listener && llvm::all_of(op->getUses(), isNotVisited))
+      config.listener->notifyOperationReplaced(op, existing->getResults());
+    rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
+                               isNotVisited);
 
     // There may be some remaining uses of the operation.
     if (op->use_empty())
       opsToErase.push_back(op);
   }
 
   // If the existing operation has an unknown location and the current
@@ -142,7 +179,8 @@
   ++numCSE;
 }
 
-bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
+                                                 Operation *toOp) {
   assert(fromOp->getBlock() == toOp->getBlock());
   assert(
       isa<MemoryEffectOpInterface>(fromOp) &&
@@ -183,8 +221,9 @@
 }
 
 /// Attempt to eliminate a redundant operation.
-LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
-                                     bool hasSSADominance) {
+LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+                                           Operation *op,
+                                           bool hasSSADominance) {
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -240,8 +279,8 @@
   return failure();
 }
 
-void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
-                        bool hasSSADominance) {
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+                              bool hasSSADominance) {
   for (auto &op : *bb) {
     // Most operations don't have regions, so fast path that case.
     if (op.getNumRegions() != 0) {
@@ -267,12 +306,12 @@
   memEffectsCache.clear();
 }
 
-void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
   // If the region is empty there is nothing to do.
   if (region.empty())
     return;
 
-  bool hasSSADominance = domInfo->hasSSADominance(&region);
+  bool hasSSADominance = domInfo.hasSSADominance(&region);
 
   // If the region only contains one block, then simplify it directly.
   if (region.hasOneBlock()) {
@@ -297,7 +336,7 @@
 
   // Process the nodes of the dom tree for this region.
   stack.emplace_back(std::make_unique<CFGStackNode>(
-      knownValues, domInfo->getRootNode(&region)));
+      knownValues, domInfo.getRootNode(&region)));
 
   while (!stack.empty()) {
     auto &currentNode = stack.back();
@@ -322,29 +361,58 @@
   }
 }
 
-void CSE::runOnOperation() {
-  /// A scoped hash table of defining operations within a region.
+void CSEDriver::simplify(Operation *op, bool *changed) {
+  /// Simplify all regions.
   ScopedMapTy knownValues;
-
-  domInfo = &getAnalysis<DominanceInfo>();
-  Operation *rootOp = getOperation();
-
-  for (auto &region : rootOp->getRegions())
+  for (auto &region : op->getRegions())
     simplifyRegion(knownValues, region);
 
-  // If no operations were erased, then we mark all analyses as preserved.
-  if (opsToErase.empty())
-    return markAllAnalysesPreserved();
-
   /// Erase any operations that were marked as dead during simplification.
-  for (auto *op : opsToErase)
-    op->erase();
-  opsToErase.clear();
+  for (auto *deadOp : opsToErase)
+    rewriter.eraseOp(deadOp);
+  if (changed)
+    *changed = !opsToErase.empty();
+
+  /// Invalidate dominance info if the IR was changed.
+  if (!opsToErase.empty())
+    domInfo.invalidate();
+
+  // Reset the list so that a later call to `simplify` on this driver does not
+  // try to erase the (already freed) operations again.
+  opsToErase.clear();
+}
+
+namespace {
+/// CSE pass.
+struct CSE : public impl::CSEBase<CSE> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void CSE::runOnOperation() {
+  // Simplify the IR.
+  CSEDriver driver(&getContext(), getAnalysis<DominanceInfo>());
+  bool changed = false;
+  driver.simplify(getOperation(), &changed);
+
+  // Set statistics.
+  numCSE = driver.getNumCSE();
+  numDCE = driver.getNumDCE();
+
+  // If there was no change to the IR, we mark all analyses as preserved.
+  if (!changed)
+    return markAllAnalysesPreserved();
 
   // We currently don't remove region operations, so mark dominance as
   // preserved.
   markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
-  domInfo = nullptr;
+}
+
+void mlir::eliminateCommonSubExpressions(Operation *op, CSERewriteConfig config,
+                                         bool *changed) {
+  DominanceInfo domInfo;
+  CSEDriver driver(op->getContext(), domInfo, config);
+  driver.simplify(op, changed);
 }
 
 std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }