diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -51,7 +51,7 @@
 /// Creates a pass to replace affine memref accesses by scalars using store to
 /// load forwarding and redundant load elimination; consequently also eliminate
 /// dead allocs.
-std::unique_ptr<OperationPass<FuncOp>> createAffineScalarReplacementPass();
+std::unique_ptr<Pass> createAffineScalarReplacementPass();
 
 /// Creates a pass to perform tiling on loop nests.
 std::unique_ptr<OperationPass<FuncOp>>
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -14,19 +14,17 @@
 #define MLIR_DIALECT_AFFINE_UTILS_H
 
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 
 class AffineForOp;
 class AffineIfOp;
 class AffineParallelOp;
-struct LogicalResult;
-struct LoopReduction;
+class DominanceInfo;
 class Operation;
+class PostDominanceInfo;
+
+struct LogicalResult;
 
 using ReductionLoopMap = DenseMap<Value, SmallVector<AffineForOp, 2>>;
@@ -90,6 +88,12 @@
   ReductionLoopMap reductionLoops;
 };
 
+/// Replace affine store and load accesses by scalars by forwarding stores to
+/// loads and eliminate invariant affine loads; consequently, eliminate dead
+/// allocs.
+void affineScalarReplace(FuncOp f, DominanceInfo &domInfo,
+                         PostDominanceInfo &postDomInfo);
+
 /// Vectorizes affine loops in 'loops' using the n-D vectorization factors in
 /// 'vectorSizes'. By default, each vectorization factor is applied
 /// inner-to-outer to the loops of each loop nest. 'fastestVaryingPattern' can
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
@@ -14,477 +14,31 @@
 // SSA scalars live out of 'affine.for'/'affine.if' statements is available.
 //===----------------------------------------------------------------------===//
 
-#include "PassDetail.h"
-#include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include <algorithm>
 
-#define DEBUG_TYPE "memref-dataflow-opt"
+#define DEBUG_TYPE "affine-scalrep"
 
 using namespace mlir;
 
 namespace {
-// The store to load forwarding and load CSE rely on three conditions:
-//
-// 1) store/load providing a replacement value and load being replaced need to
-// have mathematically equivalent affine access functions (checked after full
-// composition of load/store operands); this implies that they access the same
-// single memref element for all iterations of the common surrounding loop,
-//
-// 2) the store/load op should dominate the load op,
-//
-// 3) no operation that may write to memory read by the load being replaced can
-// occur after executing the instruction (load or store) providing the
-// replacement value and before the load being replaced (thus potentially
-// allowing overwriting the memory read by the load).
-// -// The above conditions are simple to check, sufficient, and powerful for most -// cases in practice - they are sufficient, but not necessary --- since they -// don't reason about loops that are guaranteed to execute at least once or -// multiple sources to forward from. -// -// TODO: more forwarding can be done when support for -// loop/conditional live-out SSA values is available. -// TODO: do general dead store elimination for memref's. This pass -// currently only eliminates the stores only if no other loads/uses (other -// than dealloc) remain. -// struct AffineScalarReplacement : public AffineScalarReplacementBase { void runOnFunction() override; - - LogicalResult forwardStoreToLoad(AffineReadOpInterface loadOp, - SmallVectorImpl &loadOpsToErase, - SmallPtrSetImpl &memrefsToErase, - DominanceInfo &domInfo); - - void loadCSE(AffineReadOpInterface loadOp, - SmallVectorImpl &loadOpsToErase, - DominanceInfo &domInfo); - - void findUnusedStore(AffineWriteOpInterface storeOp, - SmallVectorImpl &storeOpsToErase, - SmallPtrSetImpl &memrefsToErase, - PostDominanceInfo &postDominanceInfo); }; } // namespace -/// Creates a pass to perform optimizations relying on memref dataflow such as -/// store to load forwarding, elimination of dead stores, and dead allocs. -std::unique_ptr> -mlir::createAffineScalarReplacementPass() { +std::unique_ptr mlir::createAffineScalarReplacementPass() { return std::make_unique(); } -/// Ensure that all operations that could be executed after `start` -/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path -/// between the operations) do not have the potential memory effect -/// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to -/// a memref. For example, if `EffectType` is MemoryEffects::Write, this method -/// will check if there is no write to the memory between `start` and `memOp` -/// that would change the read within `memOp`. -template -bool hasNoInterveningEffect(Operation *start, T memOp) { - - Value memref = memOp.getMemRef(); - bool isOriginalAllocation = memref.getDefiningOp() || - memref.getDefiningOp(); - - // A boolean representing whether an intervening operation could have impacted - // memOp. - bool hasSideEffect = false; - - // Check whether the effect on memOp can be caused by a given operation op. - std::function checkOperation = [&](Operation *op) { - // If the effect has alreay been found, early exit, - if (hasSideEffect) - return; - - if (auto memEffect = dyn_cast(op)) { - SmallVector effects; - memEffect.getEffects(effects); - - bool opMayHaveEffect = false; - for (auto effect : effects) { - // If op causes EffectType on a potentially aliasing location for - // memOp, mark as having the effect. - if (isa(effect.getEffect())) { - if (isOriginalAllocation && effect.getValue() && - (effect.getValue().getDefiningOp() || - effect.getValue().getDefiningOp())) { - if (effect.getValue() != memref) - continue; - } - opMayHaveEffect = true; - break; - } - } - - if (!opMayHaveEffect) - return; - - // If the side effect comes from an affine read or write, try to - // prove the side effecting `op` cannot reach `memOp`. - if (isa(op)) { - MemRefAccess srcAccess(op); - MemRefAccess destAccess(memOp); - // Dependence analysis is only correct if both ops operate on the same - // memref. - if (srcAccess.memref == destAccess.memref) { - FlatAffineValueConstraints dependenceConstraints; - - // Number of loops containing the start op and the ending operation. 
- unsigned minSurroundingLoops = - getNumCommonSurroundingLoops(*start, *memOp); - - // Number of loops containing the operation `op` which has the - // potential memory side effect and can occur on a path between - // `start` and `memOp`. - unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp); - - // For ease, let's consider the case that `op` is a store and we're - // looking for other potential stores (e.g `op`) that overwrite memory - // after `start`, and before being read in `memOp`. In this case, we - // only need to consider other potential stores with depth > - // minSurrounding loops since `start` would overwrite any store with a - // smaller number of surrounding loops before. - unsigned d; - for (d = nsLoops + 1; d > minSurroundingLoops; d--) { - DependenceResult result = checkMemrefAccessDependence( - srcAccess, destAccess, d, &dependenceConstraints, - /*dependenceComponents=*/nullptr); - if (hasDependence(result)) { - hasSideEffect = true; - return; - } - } - - // No side effect was seen, simply return. - return; - } - } - hasSideEffect = true; - return; - } - - if (op->hasTrait()) { - // Recurse into the regions for this op and check whether the internal - // operations may have the side effect `EffectType` on memOp. - for (Region ®ion : op->getRegions()) - for (Block &block : region) - for (Operation &op : block) - checkOperation(&op); - return; - } - - // Otherwise, conservatively assume generic operations have the effect - // on the operation - hasSideEffect = true; - return; - }; - - // Check all paths from ancestor op `parent` to the operation `to` for the - // effect. It is known that `to` must be contained within `parent`. - auto until = [&](Operation *parent, Operation *to) { - // TODO check only the paths from `parent` to `to`. - // Currently we fallback and check the entire parent op, rather than - // just the paths from the parent path, stopping after reaching `to`. - // This is conservatively correct, but could be made more aggressive. - assert(parent->isAncestor(to)); - checkOperation(parent); - }; - - // Check for all paths from operation `from` to operation `untilOp` for the - // given memory effect. - std::function recur = - [&](Operation *from, Operation *untilOp) { - assert( - from->getParentRegion()->isAncestor(untilOp->getParentRegion()) && - "Checking for side effect between two operations without a common " - "ancestor"); - - // If the operations are in different regions, recursively consider all - // path from `from` to the parent of `to` and all paths from the parent - // of `to` to `to`. - if (from->getParentRegion() != untilOp->getParentRegion()) { - recur(from, untilOp->getParentOp()); - until(untilOp->getParentOp(), untilOp); - return; - } - - // Now, assuming that `from` and `to` exist in the same region, perform - // a CFG traversal to check all the relevant operations. - - // Additional blocks to consider. - SmallVector todoBlocks; - { - // First consider the parent block of `from` an check all operations - // after `from`. - for (auto iter = ++from->getIterator(), end = from->getBlock()->end(); - iter != end && &*iter != untilOp; ++iter) { - checkOperation(&*iter); - } - - // If the parent of `from` doesn't contain `to`, add the successors - // to the list of blocks to check. - if (untilOp->getBlock() != from->getBlock()) - for (Block *succ : from->getBlock()->getSuccessors()) - todoBlocks.push_back(succ); - } - - SmallPtrSet done; - // Traverse the CFG until hitting `to`. 
- while (todoBlocks.size()) { - Block *blk = todoBlocks.pop_back_val(); - if (done.count(blk)) - continue; - done.insert(blk); - for (auto &op : *blk) { - if (&op == untilOp) - break; - checkOperation(&op); - if (&op == blk->getTerminator()) - for (Block *succ : blk->getSuccessors()) - todoBlocks.push_back(succ); - } - } - }; - recur(start, memOp); - return !hasSideEffect; -} - -// This attempts to find stores which have no impact on the final result. -// A writing op writeA will be eliminated if there exists an op writeB if -// 1) writeA and writeB have mathematically equivalent affine access functions. -// 2) writeB postdominates writeA. -// 3) There is no potential read between writeA and writeB. -void AffineScalarReplacement::findUnusedStore( - AffineWriteOpInterface writeA, SmallVectorImpl &opsToErase, - SmallPtrSetImpl &memrefsToErase, - PostDominanceInfo &postDominanceInfo) { - - for (Operation *user : writeA.getMemRef().getUsers()) { - // Only consider writing operations. - auto writeB = dyn_cast(user); - if (!writeB) - continue; - - // The operations must be distinct. - if (writeB == writeA) - continue; - - // Both operations must lie in the same region. - if (writeB->getParentRegion() != writeA->getParentRegion()) - continue; - - // Both operations must write to the same memory. - MemRefAccess srcAccess(writeB); - MemRefAccess destAccess(writeA); - - if (srcAccess != destAccess) - continue; - - // writeB must postdominate writeA. - if (!postDominanceInfo.postDominates(writeB, writeA)) - continue; - - // There cannot be an operation which reads from memory between - // the two writes. - if (!hasNoInterveningEffect(writeA, writeB)) - continue; - - opsToErase.push_back(writeA); - break; - } -} - -/// Attempt to eliminate loadOp by replacing it with a value stored into memory -/// which the load is guaranteed to retrieve. This check involves three -/// components: 1) The store and load must be on the same location 2) The store -/// must dominate (and therefore must always occur prior to) the load 3) No -/// other operations will overwrite the memory loaded between the given load -/// and store. If such a value exists, the replaced `loadOp` will be added to -/// `loadOpsToErase` and its memref will be added to `memrefsToErase`. -LogicalResult AffineScalarReplacement::forwardStoreToLoad( - AffineReadOpInterface loadOp, SmallVectorImpl &loadOpsToErase, - SmallPtrSetImpl &memrefsToErase, DominanceInfo &domInfo) { - - // The store op candidate for forwarding that satisfies all conditions - // to replace the load, if any. - Operation *lastWriteStoreOp = nullptr; - - for (auto *user : loadOp.getMemRef().getUsers()) { - auto storeOp = dyn_cast(user); - if (!storeOp) - continue; - MemRefAccess srcAccess(storeOp); - MemRefAccess destAccess(loadOp); - - // 1. Check if the store and the load have mathematically equivalent - // affine access functions; this implies that they statically refer to the - // same single memref element. As an example this filters out cases like: - // store %A[%i0 + 1] - // load %A[%i0] - // store %A[%M] - // load %A[%N] - // Use the AffineValueMap difference based memref access equality checking. - if (srcAccess != destAccess) - continue; - - // 2. The store has to dominate the load op to be candidate. - if (!domInfo.dominates(storeOp, loadOp)) - continue; - - // 3. Ensure there is no intermediate operation which could replace the - // value in memory. - if (!hasNoInterveningEffect(storeOp, loadOp)) - continue; - - // We now have a candidate for forwarding. 
- assert(lastWriteStoreOp == nullptr && - "multiple simulataneous replacement stores"); - lastWriteStoreOp = storeOp; - } - - if (!lastWriteStoreOp) - return failure(); - - // Perform the actual store to load forwarding. - Value storeVal = - cast(lastWriteStoreOp).getValueToStore(); - // Check if 2 values have the same shape. This is needed for affine vector - // loads and stores. - if (storeVal.getType() != loadOp.getValue().getType()) - return failure(); - loadOp.getValue().replaceAllUsesWith(storeVal); - // Record the memref for a later sweep to optimize away. - memrefsToErase.insert(loadOp.getMemRef()); - // Record this to erase later. - loadOpsToErase.push_back(loadOp); - return success(); -} - -// The load to load forwarding / redundant load elimination is similar to the -// store to load forwarding. -// loadA will be be replaced with loadB if: -// 1) loadA and loadB have mathematically equivalent affine access functions. -// 2) loadB dominates loadA. -// 3) There is no write between loadA and loadB. -void AffineScalarReplacement::loadCSE( - AffineReadOpInterface loadA, SmallVectorImpl &loadOpsToErase, - DominanceInfo &domInfo) { - SmallVector loadCandidates; - for (auto *user : loadA.getMemRef().getUsers()) { - auto loadB = dyn_cast(user); - if (!loadB || loadB == loadA) - continue; - - MemRefAccess srcAccess(loadB); - MemRefAccess destAccess(loadA); - - // 1. The accesses have to be to the same location. - if (srcAccess != destAccess) { - continue; - } - - // 2. The store has to dominate the load op to be candidate. - if (!domInfo.dominates(loadB, loadA)) - continue; - - // 3. There is no write between loadA and loadB. - if (!hasNoInterveningEffect(loadB.getOperation(), - loadA)) - continue; - - // Check if two values have the same shape. This is needed for affine vector - // loads. - if (loadB.getValue().getType() != loadA.getValue().getType()) - continue; - - loadCandidates.push_back(loadB); - } - - // Of the legal load candidates, use the one that dominates all others - // to minimize the subsequent need to loadCSE - Value loadB; - for (AffineReadOpInterface option : loadCandidates) { - if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) { - return depStore == option || - domInfo.dominates(option.getOperation(), - depStore.getOperation()); - })) { - loadB = option.getValue(); - break; - } - } - - if (loadB) { - loadA.getValue().replaceAllUsesWith(loadB); - // Record this to erase later. - loadOpsToErase.push_back(loadA); - } -} - void AffineScalarReplacement::runOnFunction() { - // Only supports single block functions at the moment. - FuncOp f = getFunction(); - - // Load op's whose results were replaced by those forwarded from stores. - SmallVector opsToErase; - - // A list of memref's that are potentially dead / could be eliminated. - SmallPtrSet memrefsToErase; - - auto &domInfo = getAnalysis(); - auto &postDomInfo = getAnalysis(); - - // Walk all load's and perform store to load forwarding. - f.walk([&](AffineReadOpInterface loadOp) { - if (failed( - forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) { - loadCSE(loadOp, opsToErase, domInfo); - } - }); - - // Erase all load op's whose results were replaced with store fwd'ed ones. 
- for (auto *op : opsToErase) - op->erase(); - opsToErase.clear(); - - // Walk all store's and perform unused store elimination - f.walk([&](AffineWriteOpInterface storeOp) { - findUnusedStore(storeOp, opsToErase, memrefsToErase, postDomInfo); - }); - // Erase all store op's which don't impact the program - for (auto *op : opsToErase) - op->erase(); - - // Check if the store fwd'ed memrefs are now left with only stores and can - // thus be completely deleted. Note: the canonicalize pass should be able - // to do this as well, but we'll do it here since we collected these anyway. - for (auto memref : memrefsToErase) { - // If the memref hasn't been alloc'ed in this function, skip. - Operation *defOp = memref.getDefiningOp(); - if (!defOp || !isa(defOp)) - // TODO: if the memref was returned by a 'call' operation, we - // could still erase it if the call had no side-effects. - continue; - if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) { - return !isa(ownerOp); - })) - continue; - - // Erase all stores, the dealloc, and the alloc on the memref. - for (auto *user : llvm::make_early_inc_range(memref.getUsers())) - user->erase(); - defOp->erase(); - } + affineScalarReplace(getFunction(), getAnalysis(), + getAnalysis()); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -12,10 +12,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/Utils.h" + #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IntegerSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" @@ -430,3 +433,427 @@ Operation *newIV = opBuilder.create(loc, ivMap, lbOperands); op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); } + +/// Ensure that all operations that could be executed after `start` +/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path +/// between the operations) do not have the potential memory effect +/// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to +/// a memref. For example, if `EffectType` is MemoryEffects::Write, this method +/// will check if there is no write to the memory between `start` and `memOp` +/// that would change the read within `memOp`. +template +static bool hasNoInterveningEffect(Operation *start, T memOp) { + Value memref = memOp.getMemRef(); + bool isOriginalAllocation = memref.getDefiningOp() || + memref.getDefiningOp(); + + // A boolean representing whether an intervening operation could have impacted + // memOp. + bool hasSideEffect = false; + + // Check whether the effect on memOp can be caused by a given operation op. + std::function checkOperation = [&](Operation *op) { + // If the effect has alreay been found, early exit, + if (hasSideEffect) + return; + + if (auto memEffect = dyn_cast(op)) { + SmallVector effects; + memEffect.getEffects(effects); + + bool opMayHaveEffect = false; + for (auto effect : effects) { + // If op causes EffectType on a potentially aliasing location for + // memOp, mark as having the effect. 
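+        // (Note: two values produced by distinct alloc/alloca ops cannot
+        // alias each other, so an effect whose value is a different original
+        // allocation than `memref` is skipped below.)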
+ if (isa(effect.getEffect())) { + if (isOriginalAllocation && effect.getValue() && + (effect.getValue().getDefiningOp() || + effect.getValue().getDefiningOp())) { + if (effect.getValue() != memref) + continue; + } + opMayHaveEffect = true; + break; + } + } + + if (!opMayHaveEffect) + return; + + // If the side effect comes from an affine read or write, try to + // prove the side effecting `op` cannot reach `memOp`. + if (isa(op)) { + MemRefAccess srcAccess(op); + MemRefAccess destAccess(memOp); + // Dependence analysis is only correct if both ops operate on the same + // memref. + if (srcAccess.memref == destAccess.memref) { + FlatAffineValueConstraints dependenceConstraints; + + // Number of loops containing the start op and the ending operation. + unsigned minSurroundingLoops = + getNumCommonSurroundingLoops(*start, *memOp); + + // Number of loops containing the operation `op` which has the + // potential memory side effect and can occur on a path between + // `start` and `memOp`. + unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp); + + // For ease, let's consider the case that `op` is a store and we're + // looking for other potential stores (e.g `op`) that overwrite memory + // after `start`, and before being read in `memOp`. In this case, we + // only need to consider other potential stores with depth > + // minSurrounding loops since `start` would overwrite any store with a + // smaller number of surrounding loops before. + unsigned d; + for (d = nsLoops + 1; d > minSurroundingLoops; d--) { + DependenceResult result = checkMemrefAccessDependence( + srcAccess, destAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (hasDependence(result)) { + hasSideEffect = true; + return; + } + } + + // No side effect was seen, simply return. + return; + } + } + hasSideEffect = true; + return; + } + + if (op->hasTrait()) { + // Recurse into the regions for this op and check whether the internal + // operations may have the side effect `EffectType` on memOp. + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &op : block) + checkOperation(&op); + return; + } + + // Otherwise, conservatively assume generic operations have the effect + // on the operation + hasSideEffect = true; + return; + }; + + // Check all paths from ancestor op `parent` to the operation `to` for the + // effect. It is known that `to` must be contained within `parent`. + auto until = [&](Operation *parent, Operation *to) { + // TODO check only the paths from `parent` to `to`. + // Currently we fallback and check the entire parent op, rather than + // just the paths from the parent path, stopping after reaching `to`. + // This is conservatively correct, but could be made more aggressive. + assert(parent->isAncestor(to)); + checkOperation(parent); + }; + + // Check for all paths from operation `from` to operation `untilOp` for the + // given memory effect. + std::function recur = + [&](Operation *from, Operation *untilOp) { + assert( + from->getParentRegion()->isAncestor(untilOp->getParentRegion()) && + "Checking for side effect between two operations without a common " + "ancestor"); + + // If the operations are in different regions, recursively consider all + // path from `from` to the parent of `to` and all paths from the parent + // of `to` to `to`. 
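+        // (The recursion terminates: each step replaces `untilOp` with its
+        // parent op, which lies in an enclosing region, so the two operations
+        // eventually end up in the same region.)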
+ if (from->getParentRegion() != untilOp->getParentRegion()) { + recur(from, untilOp->getParentOp()); + until(untilOp->getParentOp(), untilOp); + return; + } + + // Now, assuming that `from` and `to` exist in the same region, perform + // a CFG traversal to check all the relevant operations. + + // Additional blocks to consider. + SmallVector todoBlocks; + { + // First consider the parent block of `from` an check all operations + // after `from`. + for (auto iter = ++from->getIterator(), end = from->getBlock()->end(); + iter != end && &*iter != untilOp; ++iter) { + checkOperation(&*iter); + } + + // If the parent of `from` doesn't contain `to`, add the successors + // to the list of blocks to check. + if (untilOp->getBlock() != from->getBlock()) + for (Block *succ : from->getBlock()->getSuccessors()) + todoBlocks.push_back(succ); + } + + SmallPtrSet done; + // Traverse the CFG until hitting `to`. + while (todoBlocks.size()) { + Block *blk = todoBlocks.pop_back_val(); + if (done.count(blk)) + continue; + done.insert(blk); + for (auto &op : *blk) { + if (&op == untilOp) + break; + checkOperation(&op); + if (&op == blk->getTerminator()) + for (Block *succ : blk->getSuccessors()) + todoBlocks.push_back(succ); + } + } + }; + recur(start, memOp); + return !hasSideEffect; +} + +/// Attempt to eliminate loadOp by replacing it with a value stored into memory +/// which the load is guaranteed to retrieve. This check involves three +/// components: 1) The store and load must be on the same location 2) The store +/// must dominate (and therefore must always occur prior to) the load 3) No +/// other operations will overwrite the memory loaded between the given load +/// and store. If such a value exists, the replaced `loadOp` will be added to +/// `loadOpsToErase` and its memref will be added to `memrefsToErase`. +static LogicalResult forwardStoreToLoad( + AffineReadOpInterface loadOp, SmallVectorImpl &loadOpsToErase, + SmallPtrSetImpl &memrefsToErase, DominanceInfo &domInfo) { + + // The store op candidate for forwarding that satisfies all conditions + // to replace the load, if any. + Operation *lastWriteStoreOp = nullptr; + + for (auto *user : loadOp.getMemRef().getUsers()) { + auto storeOp = dyn_cast(user); + if (!storeOp) + continue; + MemRefAccess srcAccess(storeOp); + MemRefAccess destAccess(loadOp); + + // 1. Check if the store and the load have mathematically equivalent + // affine access functions; this implies that they statically refer to the + // same single memref element. As an example this filters out cases like: + // store %A[%i0 + 1] + // load %A[%i0] + // store %A[%M] + // load %A[%N] + // Use the AffineValueMap difference based memref access equality checking. + if (srcAccess != destAccess) + continue; + + // 2. The store has to dominate the load op to be candidate. + if (!domInfo.dominates(storeOp, loadOp)) + continue; + + // 3. Ensure there is no intermediate operation which could replace the + // value in memory. + if (!hasNoInterveningEffect(storeOp, loadOp)) + continue; + + // We now have a candidate for forwarding. + assert(lastWriteStoreOp == nullptr && + "multiple simulataneous replacement stores"); + lastWriteStoreOp = storeOp; + } + + if (!lastWriteStoreOp) + return failure(); + + // Perform the actual store to load forwarding. + Value storeVal = + cast(lastWriteStoreOp).getValueToStore(); + // Check if 2 values have the same shape. This is needed for affine vector + // loads and stores. 
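+  // (For example, an affine.vector_store of vector<8xf32> cannot replace an
+  // affine.load of f32 even when the two accesses are equivalent.)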
+ if (storeVal.getType() != loadOp.getValue().getType()) + return failure(); + loadOp.getValue().replaceAllUsesWith(storeVal); + // Record the memref for a later sweep to optimize away. + memrefsToErase.insert(loadOp.getMemRef()); + // Record this to erase later. + loadOpsToErase.push_back(loadOp); + return success(); +} + +// This attempts to find stores which have no impact on the final result. +// A writing op writeA will be eliminated if there exists an op writeB if +// 1) writeA and writeB have mathematically equivalent affine access functions. +// 2) writeB postdominates writeA. +// 3) There is no potential read between writeA and writeB. +static void findUnusedStore(AffineWriteOpInterface writeA, + SmallVectorImpl &opsToErase, + SmallPtrSetImpl &memrefsToErase, + PostDominanceInfo &postDominanceInfo) { + + for (Operation *user : writeA.getMemRef().getUsers()) { + // Only consider writing operations. + auto writeB = dyn_cast(user); + if (!writeB) + continue; + + // The operations must be distinct. + if (writeB == writeA) + continue; + + // Both operations must lie in the same region. + if (writeB->getParentRegion() != writeA->getParentRegion()) + continue; + + // Both operations must write to the same memory. + MemRefAccess srcAccess(writeB); + MemRefAccess destAccess(writeA); + + if (srcAccess != destAccess) + continue; + + // writeB must postdominate writeA. + if (!postDominanceInfo.postDominates(writeB, writeA)) + continue; + + // There cannot be an operation which reads from memory between + // the two writes. + if (!hasNoInterveningEffect(writeA, writeB)) + continue; + + opsToErase.push_back(writeA); + break; + } +} + +// The load to load forwarding / redundant load elimination is similar to the +// store to load forwarding. +// loadA will be be replaced with loadB if: +// 1) loadA and loadB have mathematically equivalent affine access functions. +// 2) loadB dominates loadA. +// 3) There is no write between loadA and loadB. +static void loadCSE(AffineReadOpInterface loadA, + SmallVectorImpl &loadOpsToErase, + DominanceInfo &domInfo) { + SmallVector loadCandidates; + for (auto *user : loadA.getMemRef().getUsers()) { + auto loadB = dyn_cast(user); + if (!loadB || loadB == loadA) + continue; + + MemRefAccess srcAccess(loadB); + MemRefAccess destAccess(loadA); + + // 1. The accesses have to be to the same location. + if (srcAccess != destAccess) { + continue; + } + + // 2. The store has to dominate the load op to be candidate. + if (!domInfo.dominates(loadB, loadA)) + continue; + + // 3. There is no write between loadA and loadB. + if (!hasNoInterveningEffect(loadB.getOperation(), + loadA)) + continue; + + // Check if two values have the same shape. This is needed for affine vector + // loads. + if (loadB.getValue().getType() != loadA.getValue().getType()) + continue; + + loadCandidates.push_back(loadB); + } + + // Of the legal load candidates, use the one that dominates all others + // to minimize the subsequent need to loadCSE + Value loadB; + for (AffineReadOpInterface option : loadCandidates) { + if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) { + return depStore == option || + domInfo.dominates(option.getOperation(), + depStore.getOperation()); + })) { + loadB = option.getValue(); + break; + } + } + + if (loadB) { + loadA.getValue().replaceAllUsesWith(loadB); + // Record this to erase later. 
+    loadOpsToErase.push_back(loadA);
+  }
+}
+
+// The store to load forwarding and load CSE rely on three conditions:
+//
+// 1) store/load providing a replacement value and load being replaced need to
+// have mathematically equivalent affine access functions (checked after full
+// composition of load/store operands); this implies that they access the same
+// single memref element for all iterations of the common surrounding loop,
+//
+// 2) the store/load op should dominate the load op,
+//
+// 3) no operation that may write to memory read by the load being replaced can
+// occur after executing the instruction (load or store) providing the
+// replacement value and before the load being replaced (thus potentially
+// allowing overwriting the memory read by the load).
+//
+// The above conditions are simple to check, sufficient, and powerful for most
+// cases in practice - they are sufficient, but not necessary --- since they
+// don't reason about loops that are guaranteed to execute at least once or
+// multiple sources to forward from.
+//
+// TODO: more forwarding can be done when support for
+// loop/conditional live-out SSA values is available.
+// TODO: do general dead store elimination for memref's. This pass
+// currently eliminates stores only if no other loads/uses (other
+// than dealloc) remain.
+//
+void mlir::affineScalarReplace(FuncOp f, DominanceInfo &domInfo,
+                               PostDominanceInfo &postDomInfo) {
+  // Load op's whose results were replaced by those forwarded from stores.
+  SmallVector<Operation *, 8> opsToErase;
+
+  // A list of memref's that are potentially dead / could be eliminated.
+  SmallPtrSet<Value, 4> memrefsToErase;
+
+  // Walk all load's and perform store to load forwarding.
+  f.walk([&](AffineReadOpInterface loadOp) {
+    if (failed(
+            forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
+      loadCSE(loadOp, opsToErase, domInfo);
+    }
+  });
+
+  // Erase all load op's whose results were replaced with store fwd'ed ones.
+  for (auto *op : opsToErase)
+    op->erase();
+  opsToErase.clear();
+
+  // Walk all store's and perform unused store elimination.
+  f.walk([&](AffineWriteOpInterface storeOp) {
+    findUnusedStore(storeOp, opsToErase, memrefsToErase, postDomInfo);
+  });
+  // Erase all store op's which don't impact the program.
+  for (auto *op : opsToErase)
+    op->erase();
+
+  // Check if the store fwd'ed memrefs are now left with only stores and can
+  // thus be completely deleted. Note: the canonicalize pass should be able
+  // to do this as well, but we'll do it here since we collected these anyway.
+  for (auto memref : memrefsToErase) {
+    // If the memref hasn't been alloc'ed in this function, skip.
+    Operation *defOp = memref.getDefiningOp();
+    if (!defOp || !isa<memref::AllocOp>(defOp))
+      // TODO: if the memref was returned by a 'call' operation, we
+      // could still erase it if the call had no side-effects.
+      continue;
+    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
+          return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
+        }))
+      continue;
+
+    // Erase all stores, the dealloc, and the alloc on the memref.
+    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
+      user->erase();
+    defOp->erase();
+  }
+}
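
Illustrative example (not part of the patch): a minimal before/after sketch of the rewrite that affineScalarReplace performs. The function name @forward and the IR below are illustrative, not taken from the patch or its tests.

  // Input: each load of %buf reads the value just stored to the same element.
  func @forward(%out: memref<4xf32>, %v: f32) {
    %buf = memref.alloc() : memref<4xf32>
    affine.for %i = 0 to 4 {
      affine.store %v, %buf[%i] : memref<4xf32>
      %w = affine.load %buf[%i] : memref<4xf32>
      affine.store %w, %out[%i] : memref<4xf32>
    }
    return
  }

  // After affineScalarReplace: the load is forwarded to %v; %buf is then only
  // written, so its stores and the alloc are erased in the final sweep.
  func @forward(%out: memref<4xf32>, %v: f32) {
    affine.for %i = 0 to 4 {
      affine.store %v, %out[%i] : memref<4xf32>
    }
    return
  }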