diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -627,6 +627,14 @@
   /// in-place operation modification is about to happen.
   void replaceUsesWithIf(Value from, Value to,
                          function_ref<bool(OpOperand &)> functor);
+  /// Range variant of the above: uses of `from[i]` for which `functor` returns
+  /// true are replaced with `to[i]`. `from` and `to` must have the same size.
+  void replaceUsesWithIf(ValueRange from, ValueRange to,
+                         function_ref<bool(OpOperand &)> functor) {
+    assert(from.size() == to.size() && "incorrect number of replacements");
+    for (auto it : llvm::zip(from, to))
+      replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+  }
 
   /// Find uses of `from` and replace them with `to` except if the user is
   /// `exceptedUser`. It also marks every modified uses and notifies the
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Transforms/CSE.h
@@ -0,0 +1,38 @@
+//===- CSE.h - Common Subexpression Elimination -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares methods for eliminating common subexpressions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_CSE_H_
+#define MLIR_TRANSFORMS_CSE_H_
+
+namespace mlir {
+
+class DominanceInfo;
+class Operation;
+class RewriterBase;
+
+/// Eliminate common subexpressions within the given operation. This transform
+/// looks for and deduplicates equivalent operations. Uses are replaced and
+/// operations are erased through `rewriter`, so that a listener attached to
+/// it is notified.
+///
+/// `domInfo` does not have to be invalidated afterwards: CSE currently does
+/// not erase operations with regions.
+///
+/// If `changed` is non-null, it is set to indicate whether the IR was modified
+/// or not.
+void eliminateCommonSubExpressions(RewriterBase &rewriter,
+                                   DominanceInfo &domInfo, Operation *op,
+                                   bool *changed = nullptr);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_CSE_H_
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -11,11 +11,13 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/CSE.h"
 
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/ScopedHashTable.h"
@@ -56,7 +58,20 @@
 
 namespace {
 /// Simple common sub-expression elimination.
-struct CSE : public impl::CSEBase<CSE> {
+class CSEDriver {
+public:
+  CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
+      : rewriter(rewriter), domInfo(domInfo) {}
+
+  /// Simplify all operations within the given op. If `changed` is non-null,
+  /// it is set to whether the IR was modified.
+  void simplify(Operation *op, bool *changed = nullptr);
+
+  /// Statistics accumulated over all `simplify` calls.
+  int64_t getNumCSE() const { return numCSE; }
+  int64_t getNumDCE() const { return numDCE; }
+
+private:
   /// Shared implementation of operation elimination and scoped map definitions.
   using AllocatorTy = llvm::RecyclingAllocator<
       llvm::BumpPtrAllocator,
@@ -94,9 +109,6 @@
   void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
   void simplifyRegion(ScopedMapTy &knownValues, Region &region);
 
-  void runOnOperation() override;
-
-private:
   void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
                             Operation *existing, bool hasSSADominance);
 
@@ -104,29 +116,53 @@
   /// between the two operations.
   bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
 
+  /// A rewriter for modifying the IR.
+  RewriterBase &rewriter;
+
   /// Operations marked as dead and to be erased.
   std::vector<Operation *> opsToErase;
   DominanceInfo *domInfo = nullptr;
   MemEffectsCache memEffectsCache;
+
+  // Various statistics.
+  int64_t numCSE = 0;
+  int64_t numDCE = 0;
 };
 } // namespace
 
-void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
-                               Operation *existing, bool hasSSADominance) {
+void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                                     Operation *existing,
+                                     bool hasSSADominance) {
   // If we find one then replace all uses of the current operation with the
   // existing one and mark it for deletion. We can only replace an operand in
   // an operation if it has not been visited yet.
   if (hasSSADominance) {
     // If the region has SSA dominance, then we are guaranteed to have not
     // visited any use of the current operation.
-    op->replaceAllUsesWith(existing);
+    if (auto *rewriteListener =
+            dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+      rewriteListener->notifyOperationReplaced(op, existing);
+    // Replace all uses, but do not remove the operation yet. This does not
+    // notify the listener because the original op is not erased.
+    rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
     opsToErase.push_back(op);
   } else {
     // When the region does not have SSA dominance, we need to check if we
     // have visited a use before replacing any use.
-    op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
+    auto shouldReplace = [&](OpOperand &operand) {
       return !knownValues.count(operand.getOwner());
-    });
+    };
+    // Notify the listener (exactly once) only if every use of every result
+    // is replaced; otherwise `op` keeps some uses and is not erased.
+    if (auto *rewriteListener =
+            dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
+      if (llvm::all_of(op->getUses(), shouldReplace))
+        rewriteListener->notifyOperationReplaced(op, existing);
+
+    // Replace the uses, but do not remove the operation yet. This does not
+    // notify the listener because the original op is not erased here.
+    rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
+                               shouldReplace);
 
     // There may be some remaining uses of the operation.
     if (op->use_empty())
@@ -142,7 +178,8 @@
   ++numCSE;
 }
 
-bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
+                                                 Operation *toOp) {
   assert(fromOp->getBlock() == toOp->getBlock());
   assert(
       isa<MemoryEffectOpInterface>(fromOp) &&
@@ -183,8 +220,9 @@
 }
 
 /// Attempt to eliminate a redundant operation.
-LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
-                                     bool hasSSADominance) {
+LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
+                                           Operation *op,
+                                           bool hasSSADominance) {
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -240,8 +278,8 @@
   return failure();
 }
 
-void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
-                        bool hasSSADominance) {
+void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
+                              bool hasSSADominance) {
   for (auto &op : *bb) {
     // Most operations don't have regions, so fast path that case.
     if (op.getNumRegions() != 0) {
@@ -267,7 +305,7 @@
   memEffectsCache.clear();
 }
 
-void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
+void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
   // If the region is empty there is nothing to do.
   if (region.empty())
     return;
@@ -322,29 +360,64 @@
   }
 }
 
-void CSE::runOnOperation() {
-  /// A scoped hash table of defining operations within a region.
+void CSEDriver::simplify(Operation *op, bool *changed) {
+  // In regions without SSA dominance, uses may be replaced without the
+  // replaced op being erased, so `opsToErase` alone does not tell whether
+  // the IR was modified. Every replacement bumps `numCSE`, so snapshot it.
+  int64_t numCSEBefore = numCSE;
+
+  /// Simplify all regions.
   ScopedMapTy knownValues;
-
-  domInfo = &getAnalysis<DominanceInfo>();
-  Operation *rootOp = getOperation();
-
-  for (auto &region : rootOp->getRegions())
+  for (auto &region : op->getRegions())
     simplifyRegion(knownValues, region);
 
-  // If no operations were erased, then we mark all analyses as preserved.
-  if (opsToErase.empty())
-    return markAllAnalysesPreserved();
-
   /// Erase any operations that were marked as dead during simplification.
-  for (auto *op : opsToErase)
-    op->erase();
+  bool erasedOps = !opsToErase.empty();
+  for (Operation *deadOp : opsToErase)
+    rewriter.eraseOp(deadOp);
+  // Reset the list so that a later `simplify` call on this driver does not
+  // try to erase these (now freed) operations again.
   opsToErase.clear();
+
+  if (changed)
+    *changed = erasedOps || numCSE != numCSEBefore;
+
+  // Note: CSE does currently not remove ops with regions, so DominanceInfo
+  // does not have to be invalidated.
+}
+
+void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
+                                         DominanceInfo &domInfo, Operation *op,
+                                         bool *changed) {
+  CSEDriver driver(rewriter, &domInfo);
+  driver.simplify(op, changed);
+}
+
+namespace {
+/// CSE pass.
+struct CSE : public impl::CSEBase<CSE> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void CSE::runOnOperation() {
+  // Simplify the IR.
+  IRRewriter rewriter(&getContext());
+  CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
+  bool changed = false;
+  driver.simplify(getOperation(), &changed);
+
+  // Set statistics.
+  numCSE = driver.getNumCSE();
+  numDCE = driver.getNumDCE();
+
+  // If there was no change to the IR, we mark all analyses as preserved.
+  if (!changed)
+    return markAllAnalysesPreserved();
 
   // We currently don't remove region operations, so mark dominance as
   // preserved.
   markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
-  domInfo = nullptr;
 }
 
 std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }