diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -234,6 +234,21 @@ // TODO: add support for non-unit strides. LogicalResult addAffineForOpDomain(AffineForOp forOp); + /// Adds constraints (lower and upper bounds) for each loop in the loop nest + /// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice. + /// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in + /// the nest, sorted outer-to-inner. 'operands' contains the bound operands + /// for a single bound map. All the bound maps will use the same bound + /// operands. Note that some loops described by a computation slice might not + /// exist yet in the IR so the Value attached to those dimension identifiers + /// might be empty. For that reason, this method doesn't perform Value + /// look-ups to retrieve the dimension identifier positions. Instead, it + /// assumes the position of the dim identifiers in the constraint system is + /// the same as the position of the loop in the loop nest. + LogicalResult addDomainFromSliceMaps(ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands); + /// Adds constraints imposed by the `affine.if` operation. These constraints /// are collected from the IntegerSet attached to the given `affine.if` /// instance argument (`ifOp`). It is asserted that: diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -83,10 +83,25 @@ // Clears all bounds and operands in slice state. void clearBounds(); - /// Return true if the computation slice is empty. + /// Returns true if the computation slice is empty. bool isEmpty() const { return ivs.empty(); } + /// Returns true if the computation slice encloses all the iterations of the + /// sliced loop nest. 
Returns false if it does not. Returns llvm::None if it + /// cannot determine if the slice is maximal or not. + // TODO: Cache 'isMaximal' so that we don't recompute it when the slice + // information hasn't changed. + Optional isMaximal() const; + void dump() const; + +private: + /// Fast check to determine if the computation slice is maximal. Returns true + /// if each slice dimension maps to an existing dst dimension and both the src + /// and the dst loops for those dimensions have the same bounds. Returns false + /// if both the src and the dst loops don't have the same bounds. Returns + /// llvm::None if none of the above can be proven. + Optional isSliceMaximalFastCheck() const; }; /// Computes the computation slice loop bounds for one loop nest as affine maps diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -50,7 +50,8 @@ // TODO: Generalize utilities so that producer-consumer and sibling fusion // strategies can be used without the assumptions made in the AffineLoopFusion // pass. -struct FusionStrategy { +class FusionStrategy { +public: enum StrategyEnum { // Generic loop fusion: Arbitrary loops are considered for fusion. No // assumptions about a specific fusion strategy from AffineLoopFusion pass @@ -69,13 +70,34 @@ // implementation in AffineLoopFusion pass are made. See pass for specific // details. Sibling - } strategy; + }; - // Target memref for this fusion transformation. - Value memref; + /// Construct a generic or producer-consumer fusion strategy. + FusionStrategy(StrategyEnum strategy) : strategy(strategy) { + assert(strategy != Sibling && + "Sibling fusion strategy requires a specific memref"); + } + + /// Construct a sibling fusion strategy targeting 'memref'. This construct + /// should only be used for sibling fusion. 
+ FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {} + + /// Returns the fusion strategy. + StrategyEnum getStrategy() const { return strategy; }; - FusionStrategy(StrategyEnum strategy, Value memref) - : strategy(strategy), memref(memref) {} + /// Returns the memref attached to this sibling fusion strategy. + Value getSiblingFusionMemRef() const { + assert(strategy == Sibling && "Memref is only valid for sibling fusion"); + return memref; + } + +private: + /// Fusion strategy. + StrategyEnum strategy; + + /// Target memref for this fusion transformation. Only used for sibling + /// fusion. + Value memref; }; /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the @@ -86,11 +108,10 @@ /// NOTE: This function is not feature complete and should only be used in /// testing. /// TODO: Update comments when this function is fully implemented. -FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, - unsigned dstLoopDepth, - ComputationSliceState *srcSlice, - FusionStrategy fusionStrategy = { - FusionStrategy::Generic, Value()}); +FusionResult +canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, + ComputationSliceState *srcSlice, + FusionStrategy fusionStrategy = FusionStrategy::Generic); /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point /// and source slice loop bounds specified in 'srcSlice'. @@ -134,6 +155,12 @@ const ComputationSliceState &slice, int64_t *computeCost); +/// Returns in 'producerConsumerMemrefs' the memrefs involved in a +/// producer-consumer dependence between write ops in 'srcOps' and read ops in +/// 'dstOps'. 
+void gatherProducerConsumerMemrefs(ArrayRef srcOps, + ArrayRef dstOps, + DenseSet &producerConsumerMemrefs); } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -17,6 +17,111 @@ def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> { let summary = "Fuse affine loop nests"; + let description = [{ + This pass performs fusion of loop nests using a slicing-based approach. It + combines two fusion strategies: producer-consumer fusion and sibling fusion. + Producer-consumer fusion is aimed at fusing pairs of loops where the first + one writes to a memref that the second reads. Sibling fusion targets pairs + of loops that share no dependences between them but that load from the same + memref. The fused loop nests, when possible, are rewritten to access + significantly smaller local buffers instead of the original memref's, and + the latter are often either completely optimized away or contracted. This + transformation leads to enhanced locality and lower memory footprint through + the elimination or contraction of temporaries/intermediate memref's. These + benefits are sometimes achieved at the expense of redundant computation + through a cost model that evaluates available choices such as the depth at + which a source slice should be materialized in the designation slice. + + Example 1: Producer-consumer fusion. 
+ Input: + ```mlir + func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) { + %0 = alloc() : memref<10xf32> + %1 = alloc() : memref<10xf32> + %cst = constant 0.000000e+00 : f32 + affine.for %arg2 = 0 to 10 { + affine.store %cst, %0[%arg2] : memref<10xf32> + affine.store %cst, %1[%arg2] : memref<10xf32> + } + affine.for %arg2 = 0 to 10 { + %2 = affine.load %0[%arg2] : memref<10xf32> + %3 = addf %2, %2 : f32 + affine.store %3, %arg0[%arg2] : memref<10xf32> + } + affine.for %arg2 = 0 to 10 { + %2 = affine.load %1[%arg2] : memref<10xf32> + %3 = mulf %2, %2 : f32 + affine.store %3, %arg1[%arg2] : memref<10xf32> + } + return + } + ``` + Output: + ```mlir + func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) { + %0 = alloc() : memref<1xf32> + %1 = alloc() : memref<1xf32> + %cst = constant 0.000000e+00 : f32 + affine.for %arg2 = 0 to 10 { + affine.store %cst, %0[0] : memref<1xf32> + affine.store %cst, %1[0] : memref<1xf32> + %2 = affine.load %1[0] : memref<1xf32> + %3 = mulf %2, %2 : f32 + affine.store %3, %arg1[%arg2] : memref<10xf32> + %4 = affine.load %0[0] : memref<1xf32> + %5 = addf %4, %4 : f32 + affine.store %5, %arg0[%arg2] : memref<10xf32> + } + return + } + ``` + + Example 2: Sibling fusion. 
+ Input: + ```mlir + func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>, + %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>, + %arg4: memref<10x10xf32>) { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> + %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> + %2 = mulf %0, %1 : f32 + affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> + } + } + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> + %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> + %2 = addf %0, %1 : f32 + affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32> + } + } + return + } + ``` + Output: + ```mlir + func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>, + %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>, + %arg4: memref<10x10xf32>) { + affine.for %arg5 = 0 to 3 { + affine.for %arg6 = 0 to 3 { + %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> + %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32> + %2 = mulf %0, %1 : f32 + affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32> + %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32> + %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32> + %5 = addf %3, %4 : f32 + affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32> + } + } + return + } + ``` + }]; let constructor = "mlir::createLoopFusionPass()"; let options = [ Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double", diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -708,6 +708,70 @@ /*eq=*/false, /*lower=*/false); } +/// Adds constraints (lower and upper bounds) for each loop in the loop nest +/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice. 
+/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in +/// the nest, sorted outer-to-inner. 'operands' contains the bound operands +/// for a single bound map. All the bound maps will use the same bound +/// operands. Note that some loops described by a computation slice might not +/// exist yet in the IR so the Value attached to those dimension identifiers +/// might be empty. For that reason, this method doesn't perform Value +/// look-ups to retrieve the dimension identifier positions. Instead, it +/// assumes the position of the dim identifiers in the constraint system is +/// the same as the position of the loop in the loop nest. +LogicalResult +FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef lbMaps, + ArrayRef ubMaps, + ArrayRef operands) { + assert(lbMaps.size() == ubMaps.size()); + assert(lbMaps.size() <= getNumDimIds()); + + for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { + AffineMap lbMap = lbMaps[i]; + AffineMap ubMap = ubMaps[i]; + assert(!lbMap || lbMap.getNumInputs() == operands.size()); + assert(!ubMap || ubMap.getNumInputs() == operands.size()); + + // Check if this slice is just an equality along this dimension. If so, + // retrieve the existing loop it equates to and add it to the system. + if (lbMap && ubMap && lbMap.getNumResults() == 1 && + ubMap.getNumResults() == 1 && + lbMap.getResult(0) + 1 == ubMap.getResult(0) && + // The condition above will be true for maps describing a single + // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). + // Make sure we skip those cases by checking that the lb result is not + // just a constant. + !lbMap.getResult(0).isa()) { + // Limited support: we expect the lb result to be just a loop dimension. + // Not supported otherwise for now. 
+ AffineDimExpr result = lbMap.getResult(0).dyn_cast(); + if (!result) + return failure(); + + AffineForOp loop = + getForInductionVarOwner(operands[result.getPosition()]); + if (!loop) + return failure(); + + if (failed(addAffineForOpDomain(loop))) + return failure(); + continue; + } + + // This slice refers to a loop that doesn't exist in the IR yet. Add its + // bounds to the system assuming its dimension identifier position is the + // same as the position of the loop in the loop nest. + if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false, + /*lower=*/true))) + return failure(); + + if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false, + /*lower=*/false))) + return failure(); + } + return success(); +} + void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) { // Create the base constraints from the integer set attached to ifOp. FlatAffineConstraints cst(ifOp.getIntegerSet()); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Utils.h" - #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/PresburgerSet.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -127,6 +127,128 @@ } } +/// Fast check to determine if the computation slice is maximal. Returns true if +/// each slice dimension maps to an existing dst dimension and both the src +/// and the dst loops for those dimensions have the same bounds. Returns false +/// if both the src and the dst loops don't have the same bounds. Returns +/// llvm::None if none of the above can be proven. 
+Optional ComputationSliceState::isSliceMaximalFastCheck() const { + assert(lbs.size() == ubs.size() && lbs.size() && ivs.size() && + "Unexpected number of lbs, ubs and ivs in slice"); + + for (unsigned i = 0, end = lbs.size(); i < end; ++i) { + AffineMap lbMap = lbs[i]; + AffineMap ubMap = ubs[i]; + + // Check if this slice is just an equality along this dimension. + if (!lbMap || !ubMap || lbMap.getNumResults() != 1 || + ubMap.getNumResults() != 1 || + lbMap.getResult(0) + 1 != ubMap.getResult(0) || + // The condition above will be true for maps describing a single + // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1). + // Make sure we skip those cases by checking that the lb result is not + // just a constant. + lbMap.getResult(0).isa()) + return llvm::None; + + // Limited support: we expect the lb result to be just a loop dimension for + // now. + AffineDimExpr result = lbMap.getResult(0).dyn_cast(); + if (!result) + return llvm::None; + + // Retrieve dst loop bounds. + AffineForOp dstLoop = + getForInductionVarOwner(lbOperands[i][result.getPosition()]); + if (!dstLoop) + return llvm::None; + AffineMap dstLbMap = dstLoop.getLowerBoundMap(); + AffineMap dstUbMap = dstLoop.getUpperBoundMap(); + + // Retrieve src loop bounds. + AffineForOp srcLoop = getForInductionVarOwner(ivs[i]); + assert(srcLoop && "Expected affine for"); + AffineMap srcLbMap = srcLoop.getLowerBoundMap(); + AffineMap srcUbMap = srcLoop.getUpperBoundMap(); + + // Limited support: we expect simple src and dst loops with a single + // constant component per bound for now. 
+ if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 || + dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1) + return llvm::None; + + AffineExpr srcLbResult = srcLbMap.getResult(0); + AffineExpr dstLbResult = dstLbMap.getResult(0); + AffineExpr srcUbResult = srcUbMap.getResult(0); + AffineExpr dstUbResult = dstUbMap.getResult(0); + if (!srcLbResult.isa() || + !srcUbResult.isa() || + !dstLbResult.isa() || + !dstUbResult.isa()) + return llvm::None; + + // Check if src and dst loop bounds are the same. If not, we can guarantee + // that the slice is not maximal. + if (srcLbResult != dstLbResult || srcUbResult != dstUbResult) + return false; + } + + return true; +} + +/// Returns true if the computation slice encloses all the iterations of the +/// sliced loop nest. Returns false if it does not. Returns llvm::None if it +/// cannot determine if the slice is maximal or not. +Optional ComputationSliceState::isMaximal() const { + // Fast check to determine if the computation slice is maximal. If the result + // is inconclusive, we proceed with a more expensive analysis. + Optional isMaximalFastCheck = isSliceMaximalFastCheck(); + if (isMaximalFastCheck.hasValue()) + return isMaximalFastCheck; + + // Create constraints for the src loop nest being sliced. + FlatAffineConstraints srcConstraints; + srcConstraints.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, + /*numLocals=*/0, ivs); + for (Value iv : ivs) { + AffineForOp loop = getForInductionVarOwner(iv); + assert(loop && "Expected affine for"); + if (failed(srcConstraints.addAffineForOpDomain(loop))) + return llvm::None; + } + + // Create constraints for the slice using the dst loop nest information. We + // retrieve existing dst loops from the lbOperands. 
+ SmallVector consumerIVs; + for (Value lbOp : lbOperands[0]) + if (getForInductionVarOwner(lbOp)) + consumerIVs.push_back(lbOp); + + // Add empty IV Values for those new loops that are not equalities and, + // therefore, are not yet materialized in the IR. + for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i) + consumerIVs.push_back(Value()); + + FlatAffineConstraints sliceConstraints; + sliceConstraints.reset(/*numDims=*/consumerIVs.size(), /*numSymbols=*/0, + /*numLocals=*/0, consumerIVs); + + if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0]))) + return llvm::None; + + if (srcConstraints.getNumDimIds() != sliceConstraints.getNumDimIds()) + // Constraint dims are different. The integer set difference can't be + // computed so we don't know if the slice is maximal. + return llvm::None; + + // Compute the difference between the src loop nest and the slice integer + // sets. + PresburgerSet srcSet(srcConstraints); + PresburgerSet sliceSet(sliceConstraints); + PresburgerSet diffSet = srcSet.subtract(sliceSet); + return diffSet.isIntegerEmpty(); +} + unsigned MemRefRegion::getRank() const { return memref.getType().cast().getRank(); } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -270,64 +270,6 @@ return false; } - // Returns the unique AffineWriteOpInterface in `node` that meets all the - // following: - // *) store is the only one that writes to a function-local memref live out - // of `node`, - // *) store is not the source of a self-dependence on `node`. - // Otherwise, returns a null AffineWriteOpInterface. - AffineWriteOpInterface getUniqueOutgoingStore(Node *node) { - AffineWriteOpInterface uniqueStore; - - // Return null if `node` doesn't have any outgoing edges. 
- auto outEdgeIt = outEdges.find(node->id); - if (outEdgeIt == outEdges.end()) - return nullptr; - - const auto &nodeOutEdges = outEdgeIt->second; - for (auto *op : node->stores) { - auto storeOp = cast(op); - auto memref = storeOp.getMemRef(); - // Skip this store if there are no dependences on its memref. This means - // that store either: - // *) writes to a memref that is only read within the same loop nest - // (self-dependence edges are not represented in graph at the moment), - // *) writes to a function live out memref (function parameter), or - // *) is dead. - if (llvm::all_of(nodeOutEdges, [=](const Edge &edge) { - return (edge.value != memref); - })) - continue; - - if (uniqueStore) - // Found multiple stores to function-local live-out memrefs. - return nullptr; - // Found first store to function-local live-out memref. - uniqueStore = storeOp; - } - - return uniqueStore; - } - - // Returns true if node 'id' can be removed from the graph. Returns false - // otherwise. A node can be removed from the graph iff the following - // conditions are met: - // *) The node does not write to any memref which escapes (or is a - // function/block argument). - // *) The node has no successors in the dependence graph. - bool canRemoveNode(unsigned id) { - if (writesToLiveInOrEscapingMemrefs(id)) - return false; - Node *node = getNode(id); - for (auto *storeOpInst : node->stores) { - // Return false if there exist out edges from 'id' on 'memref'. - auto storeMemref = cast(storeOpInst).getMemRef(); - if (getOutEdgeCount(id, storeMemref) > 0) - return false; - } - return true; - } - // Returns true iff there is an edge from node 'srcId' to node 'dstId' which // is for 'value' if non-null, or for any value otherwise. Returns false // otherwise. 
@@ -495,42 +437,49 @@ return dstNodeInst; } - // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' - // has been replaced in node at 'dstId' by a private memref depending - // on the value of 'createPrivateMemRef'. - void updateEdges(unsigned srcId, unsigned dstId, Value oldMemRef, - bool createPrivateMemRef) { + // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, + // taking into account that: + // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion, + // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a + // private memref. + void updateEdges(unsigned srcId, unsigned dstId, + const DenseSet &privateMemRefs, bool removeSrcId) { // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'. if (inEdges.count(srcId) > 0) { SmallVector oldInEdges = inEdges[srcId]; for (auto &inEdge : oldInEdges) { - // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. - if (inEdge.value != oldMemRef) + // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref. + if (privateMemRefs.count(inEdge.value) == 0) addEdge(inEdge.id, dstId, inEdge.value); } } // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. + // If 'srcId' is going to be removed, remap all the out edges to 'dstId'. if (outEdges.count(srcId) > 0) { SmallVector oldOutEdges = outEdges[srcId]; for (auto &outEdge : oldOutEdges) { // Remove any out edges from 'srcId' to 'dstId' across memrefs. if (outEdge.id == dstId) removeEdge(srcId, outEdge.id, outEdge.value); + else if (removeSrcId) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(srcId, outEdge.id, outEdge.value); + } } } // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being // replaced by a private memref). These edges could come from nodes // other than 'srcId' which were removed in the previous step. 
- if (inEdges.count(dstId) > 0 && createPrivateMemRef) { + if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) { SmallVector oldInEdges = inEdges[dstId]; for (auto &inEdge : oldInEdges) - if (inEdge.value == oldMemRef) + if (privateMemRefs.count(inEdge.value) > 0) removeEdge(inEdge.id, dstId, inEdge.value); } } // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion - // of sibling node 'sidId' into node 'dstId'. + // of sibling node 'sibId' into node 'dstId'. void updateEdges(unsigned sibId, unsigned dstId) { // For each edge in 'inEdges[sibId]': // *) Add new edge from source node 'inEdge.id' to 'dstNode'. @@ -624,6 +573,141 @@ void dump() const { print(llvm::errs()); } }; +/// Returns true if node 'srcId' can be removed after fusing it with node +/// 'dstId'. The node can be removed if any of the following conditions are met: +/// 1. 'srcId' has no output dependences after fusion and no escaping memrefs. +/// 2. 'srcId' has no output dependences after fusion, has escaping memrefs +/// and the fusion slice is maximal. +/// 3. 'srcId' has output dependences after fusion, the fusion slice is +/// maximal and the fusion insertion point dominates all the dependences. +static bool canRemoveSrcNodeAfterFusion( + unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice, + Operation *fusedLoopInsPoint, const DenseSet &escapingMemRefs, + MemRefDependenceGraph *mdg) { + + Operation *dstNodeOp = mdg->getNode(dstId)->op; + bool hasOutDepsAfterFusion = false; + + for (auto &outEdge : mdg->outEdges[srcId]) { + Operation *depNodeOp = mdg->getNode(outEdge.id)->op; + // Skip dependence with dstOp since it will be removed after fusion. + if (depNodeOp == dstNodeOp) + continue; + + // Only fusion within the same block is supported. Use domination analysis + // when needed. + if (depNodeOp->getBlock() != dstNodeOp->getBlock()) + return false; + + // Check if the insertion point of the fused loop dominates the dependence. 
+ // Otherwise, the src loop can't be removed. + if (fusedLoopInsPoint != depNodeOp && + !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { + LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " + "dominate dependence\n"); + return false; + } + + hasOutDepsAfterFusion = true; + } + + // If src loop has dependences after fusion or it writes to an live-out or + // escaping memref, we can only remove it if the fusion slice is maximal so + // that all the dependences are preserved. + if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { + Optional isMaximal = fusionSlice.isMaximal(); + if (!isMaximal.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " + "if fusion is maximal\n"); + return false; + } + + if (!isMaximal.getValue()) { + LLVM_DEBUG(llvm::dbgs() + << "Src loop can't be removed: fusion is not maximal\n"); + return false; + } + } + + return true; +} + +/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer +/// 'dstId'. Candidates are sorted by node id order. This order corresponds to +/// the program order when the 'mdg' is created. However, program order is not +/// guaranteed and must not be required by the client. Program order won't be +/// held if the 'mdg' is reused from a previous fusion step or if the node +/// creation order changes in the future to support more advance cases. +// TODO: Move this to a loop fusion utility once 'mdg' is also moved. +static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg, + SmallVectorImpl &srcIdCandidates) { + // Skip if no input edges along which to fuse. + if (mdg->inEdges.count(dstId) == 0) + return; + + // Gather memrefs from loads in 'dstId'. 
+ auto *dstNode = mdg->getNode(dstId); + DenseSet consumedMemrefs; + for (Operation *load : dstNode->loads) + consumedMemrefs.insert(cast(load).getMemRef()); + + // Traverse 'dstId' incoming edges and gather the nodes that contain a store + // to one of the consumed memrefs. + for (auto &srcEdge : mdg->inEdges[dstId]) { + auto *srcNode = mdg->getNode(srcEdge.id); + // Skip if 'srcNode' is not a loop nest. + if (!isa(srcNode->op)) + continue; + + if (any_of(srcNode->stores, [&](Operation *op) { + auto storeOp = cast(op); + return consumedMemrefs.count(storeOp.getMemRef()) > 0; + })) + srcIdCandidates.push_back(srcNode->id); + } + + std::sort(srcIdCandidates.begin(), srcIdCandidates.end()); + srcIdCandidates.erase( + std::unique(srcIdCandidates.begin(), srcIdCandidates.end()), + srcIdCandidates.end()); +} + +/// Returns in 'producerConsumerMemrefs' the memrefs involved in a +/// producer-consumer dependence between 'srcId' and 'dstId'. +static void +gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId, + MemRefDependenceGraph *mdg, + DenseSet &producerConsumerMemrefs) { + auto *dstNode = mdg->getNode(dstId); + auto *srcNode = mdg->getNode(srcId); + gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads, + producerConsumerMemrefs); +} + +/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' +/// that escape the function. A memref escapes the function if either: +/// 1. It's a function argument, or +/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.) +void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, + DenseSet &escapingMemRefs) { + auto *node = mdg->getNode(id); + for (auto *storeOpInst : node->stores) { + auto memref = cast(storeOpInst).getMemRef(); + if (escapingMemRefs.count(memref)) + continue; + // Check if 'memref' escapes because it's a block argument. 
+ if (memref.isa()) { + escapingMemRefs.insert(memref); + continue; + } + // Check if 'memref' escapes through a non-affine op (e.g., std load/store, + // call op, etc.). + for (Operation *user : memref.getUsers()) + if (!isMemRefDereferencingOp(*user)) + escapingMemRefs.insert(memref); + } +} + } // end anonymous namespace // Initializes the data dependence graph by walking operations in 'f'. @@ -631,6 +715,7 @@ // TODO: Add support for taking a Block arg to construct the // dependence graph at a different depth. bool MemRefDependenceGraph::init(FuncOp f) { + LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); DenseMap> memrefAccesses; // TODO: support multi-block functions. @@ -686,6 +771,12 @@ } } + for (auto &idAndNode : nodes) { + LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n" + << *(idAndNode.second.op) << "\n"); + (void)idAndNode; + } + // Add dependence edges between nodes which produce SSA values and their // users. for (auto &idAndNode : nodes) { @@ -725,22 +816,6 @@ return true; } -// Removes load operations from 'srcLoads' which operate on 'memref', and -// adds them to 'dstLoads'. -static void moveLoadsAccessingMemrefTo(Value memref, - SmallVectorImpl *srcLoads, - SmallVectorImpl *dstLoads) { - dstLoads->clear(); - SmallVector srcLoadsToKeep; - for (auto *load : *srcLoads) { - if (cast(load).getMemRef() == memref) - dstLoads->push_back(load); - else - srcLoadsToKeep.push_back(load); - } - srcLoads->swap(srcLoadsToKeep); -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -932,75 +1007,6 @@ return false; } -// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId' -// may write to multiple memrefs but it is required that only one of them, -// 'srcLiveOutStoreOp', has output edges. 
-// Returns true if 'dstNode's read/write region to 'memref' is a super set of -// 'srcNode's write region to 'memref' and 'srcId' has only one output edge. -// TODO: Generalize this to handle more live in/out cases. -static bool -canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, - AffineWriteOpInterface srcLiveOutStoreOp, - MemRefDependenceGraph *mdg) { - assert(srcLiveOutStoreOp && "Expected a valid store op"); - auto *dstNode = mdg->getNode(dstId); - Value memref = srcLiveOutStoreOp.getMemRef(); - // Return false if 'srcNode' has more than one output edge on 'memref'. - if (mdg->getOutEdgeCount(srcId, memref) > 1) - return false; - - // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOp' on 'memref'. - MemRefRegion srcWriteRegion(srcLiveOutStoreOp.getLoc()); - if (failed(srcWriteRegion.compute(srcLiveOutStoreOp, /*loopDepth=*/0))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for source operation\n."); - return false; - } - SmallVector srcShape; - // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. - // by 'srcStoreOp' at depth 'dstLoopDepth'. - Optional srcNumElements = - srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); - if (!srcNumElements.hasValue()) - return false; - - // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'. - // TODO: Compute 'unionboundingbox' of all write regions (one for - // each store op in 'dstStoreOps'). - SmallVector dstStoreOps; - dstNode->getStoreOpsForMemref(memref, &dstStoreOps); - SmallVector dstLoadOps; - dstNode->getLoadOpsForMemref(memref, &dstLoadOps); - - auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0]; - MemRefRegion dstRegion(dstOpInst->getLoc()); - if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for dest operation\n."); - return false; - } - SmallVector dstShape; - // Query 'dstRegion' for 'dstShape' and 'dstNumElements'. 
- // by 'dstOpInst' at depth 'dstLoopDepth'. - Optional dstNumElements = - dstRegion.getConstantBoundingSizeAndShape(&dstShape); - if (!dstNumElements.hasValue()) - return false; - - // Return false if write region is not a superset of 'srcNodes' write - // region to 'memref'. - // TODO: Check the shape and lower bounds here too. - if (srcNumElements != dstNumElements) - return false; - - // Return false if 'memref' is used by a non-affine operation that is - // between node 'srcId' and node 'dstId'. - if (hasNonAffineUsersOnThePath(srcId, dstId, mdg)) - return false; - - return true; -} - // Checks the profitability of fusing a backwards slice of the loop nest // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. // The argument 'srcStoreOpInst' is used to calculate the storage reduction on @@ -1029,9 +1035,6 @@ // the largest computation slice at the maximal dst loop depth (closest to // the load) to minimize reuse distance and potentially enable subsequent // load/store forwarding. -// NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for -// the same memref as is written by 'srcOpInst', then the union of slice -// loop bounds is used to compute the slice and associated slice cost. // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop // nest, at which the src computation slice is inserted/fused. // NOTE: We attempt to maximize the dst loop depth, but there are cases @@ -1041,18 +1044,18 @@ // *) Compares the total cost of the unfused loop nests to the min cost fused // loop nest computed in the previous step, and returns true if the latter // is lower. +// TODO: Extend profitability analysis to support scenarios with multiple +// stores. 
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, - ArrayRef dstLoadOpInsts, + AffineForOp dstForOp, ArrayRef depthSliceUnions, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; - llvm::dbgs() << ' ' << *srcOpInst << " and destination op(s)\n"; - for (auto dstOpInst : dstLoadOpInsts) { - llvm::dbgs() << " " << *dstOpInst << "\n"; - }; + llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n"; + llvm::dbgs() << dstForOp << "\n"; }); if (maxLegalFusionDepth == 0) { @@ -1070,11 +1073,8 @@ return false; // Compute cost of dst loop nest. - SmallVector dstLoopIVs; - getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); - LoopNestStats dstLoopNestStats; - if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) + if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) return false; // Search for min cost value for 'dstLoopDepth'. At each value of @@ -1108,18 +1108,19 @@ int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); // Compute op instance count for the src loop nest. - uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats); + uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats); // Evaluate all depth choices for materializing the slice in the destination // loop nest. for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { + const ComputationSliceState &slice = depthSliceUnions[i - 1]; // Skip slice union if it wasn't computed for this depth. 
- if (depthSliceUnions[i - 1].isEmpty()) + if (slice.isEmpty()) continue; int64_t fusedLoopNestComputeCost; - if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], - dstLoopNestStats, depthSliceUnions[i - 1], + if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp, + dstLoopNestStats, slice, &fusedLoopNestComputeCost)) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; @@ -1131,11 +1132,11 @@ 1; // Determine what the slice write MemRefRegion would be, if the src loop - // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst - // loop nest at loop depth 'i'. + // nest slice 'slice' were to be inserted into the dst loop nest at loop + // depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &depthSliceUnions[i - 1]))) { + &slice))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); @@ -1218,7 +1219,7 @@ << "\n fused loop nest compute cost: " << minFusedLoopNestComputeCost << "\n"); - auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); + auto dstMemSize = getMemoryFootprintBytes(dstForOp); auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); Optional storageReduction = None; @@ -1322,8 +1323,6 @@ MemRefDependenceGraph *mdg; // Worklist of graph nodes visited during the fusion pass. SmallVector worklist; - // Set of graph nodes which are present on the worklist. - llvm::SmallDenseSet worklistSet; // Parameter for local buffer size threshold. unsigned localBufSizeThreshold; // Parameter for fast memory space. @@ -1344,16 +1343,14 @@ fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion), computeToleranceThreshold(computeToleranceThreshold) {} - // Initializes 'worklist' with nodes from 'mdg' + /// Initializes 'worklist' with nodes from 'mdg'. void init() { // TODO: Add a priority queue for prioritizing nodes by different // metrics (e.g. 
arithmetic intensity/flops-to-bytes ratio). worklist.clear(); - worklistSet.clear(); for (auto &idAndNode : mdg->nodes) { const Node &node = idAndNode.second; worklist.push_back(node.id); - worklistSet.insert(node.id); } } @@ -1372,11 +1369,11 @@ } void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { + LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); - worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) @@ -1386,114 +1383,85 @@ // Skip if 'dstNode' is not a loop nest. if (!isa(dstNode->op)) continue; + + LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. This can increase the maximum loop // depth at which we can fuse a slice of a producer loop nest into a // consumer loop nest. sinkSequentialLoops(dstNode); - - SmallVector loads = dstNode->loads; - SmallVector dstLoadOpInsts; - DenseSet visitedMemrefs; - while (!loads.empty()) { - // Get memref of load on top of the stack. - auto memref = cast(loads.back()).getMemRef(); - if (visitedMemrefs.count(memref) > 0) - continue; - visitedMemrefs.insert(memref); - // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'. - moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts); - // Skip if no input edges along which to fuse. - if (mdg->inEdges.count(dstId) == 0) - continue; - // Iterate through in-edges for 'dstId' and src node id for any - // edges on 'memref'. - SmallVector srcNodeIds; - for (auto &srcEdge : mdg->inEdges[dstId]) { - // Skip 'srcEdge' if not for 'memref'. - if (srcEdge.value != memref) - continue; - srcNodeIds.push_back(srcEdge.id); - } - for (unsigned srcId : srcNodeIds) { - // Skip if this node was removed (fused into another node). 
- if (mdg->nodes.count(srcId) == 0) - continue; + auto dstAffineForOp = cast(dstNode->op); + + // Try to fuse 'dstNode' with candidate producer loops until a fixed point + // is reached. Fusing two loops may expose new fusion opportunities. + bool dstNodeChanged; + do { + // Gather src loop candidates for 'dstNode' and visit them in "quasi" + // reverse program order to minimize the number of iterations needed to + // reach the fixed point. Note that this is a best effort approach since + // 'getProducerCandidates' does not always guarantee that program order + // in 'srcIdCandidates'. + dstNodeChanged = false; + SmallVector srcIdCandidates; + getProducerCandidates(dstId, mdg, srcIdCandidates); + + for (unsigned srcId : llvm::reverse(srcIdCandidates)) { // Get 'srcNode' from which to attempt fusion into 'dstNode'. auto *srcNode = mdg->getNode(srcId); - // Skip if 'srcNode' is not a loop nest. - if (!isa(srcNode->op)) + auto srcAffineForOp = cast(srcNode->op); + LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId + << " for dst loop " << dstId << "\n"); + + DenseSet producerConsumerMemrefs; + gatherProducerConsumerMemrefs(srcId, dstId, mdg, + producerConsumerMemrefs); + + // Skip if 'srcNode' out edge count on any memref is greater than + // 'maxSrcUserCount'. + if (any_of(producerConsumerMemrefs, [&](Value memref) { + return mdg->getOutEdgeCount(srcNode->id, memref) > + maxSrcUserCount; + })) continue; - // Skip if 'srcNode' has more than one live-out store to a - // function-local memref. - // TODO: Support more generic multi-output src loop nests - // fusion. - auto srcStoreOp = mdg->getUniqueOutgoingStore(srcNode); - if (!srcStoreOp) { - // Get the src store op at the deepest loop depth. - // We will use 'LoopFusionUtils::canFuseLoops' to check fusion - // feasibility for loops with multiple stores. 
- unsigned maxLoopDepth = 0; - for (auto *op : srcNode->stores) { - auto storeOp = cast(op); - if (storeOp.getMemRef() != memref) { - srcStoreOp = nullptr; - break; - } - unsigned loopDepth = getNestingDepth(storeOp); - if (loopDepth > maxLoopDepth) { - maxLoopDepth = loopDepth; - srcStoreOp = storeOp; - } - } - if (!srcStoreOp) - continue; - } - // Unique outgoing store found must write to 'memref' since 'memref' - // is the one that established the producer-consumer relationship - // between 'srcNode' and 'dstNode'. - assert(srcStoreOp.getMemRef() == memref && - "Found store to unexpected memref"); - - // Skip if 'srcNode' writes to any live in or escaping memrefs, - // and cannot be fused. - bool writesToLiveInOrOut = - mdg->writesToLiveInOrEscapingMemrefs(srcNode->id); - if (writesToLiveInOrOut && - !canFuseSrcWhichWritesToLiveOut(srcId, dstId, srcStoreOp, mdg)) + // Gather memrefs in 'srcNode' that are written and escape to the + // function (e.g., memref function arguments, returned memrefs, + // memrefs passed to function calls, etc.). + DenseSet srcEscapingMemRefs; + gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); + + // Skip if there are non-affine operations in between the 'srcNode' + // and 'dstNode' using their memrefs. If so, we wouldn't be able to + // compute a legal insertion point for now. 'srcNode' and 'dstNode' + // memrefs with non-affine operation users would be considered + // escaping memrefs so we can limit this check to only scenarios with + // escaping memrefs. + if (!srcEscapingMemRefs.empty() && + hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { + LLVM_DEBUG( + llvm::dbgs() + << "Can't fuse: non-affine users in between the loops\n."); continue; - - // Don't create a private memref if 'writesToLiveInOrOut'. - bool createPrivateMemref = !writesToLiveInOrOut; - // Don't create a private memref if 'srcNode' has in edges on - // 'memref', or if 'dstNode' has out edges on 'memref'. 
- if (mdg->getIncomingMemRefAccesses(srcNode->id, memref) > 0 || - mdg->getOutEdgeCount(dstNode->id, memref) > 0) { - createPrivateMemref = false; } - // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'. - if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount) - continue; - // Compute an operation list insertion point for the fused loop // nest which preserves dependences. - Operation *insertPointInst = + Operation *fusedLoopInsPoint = mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); - if (insertPointInst == nullptr) + if (fusedLoopInsPoint == nullptr) continue; - auto srcAffineForOp = cast(srcNode->op); - auto dstAffineForOp = cast(dstNode->op); - - // Compute the innermost common loop depth for dstNode loads/stores. + // Compute the innermost common loop depth for dstNode + // producer-consumer loads/stores. SmallVector dstMemrefOps; for (Operation *op : dstNode->loads) - if (cast(op).getMemRef() == memref) + if (producerConsumerMemrefs.count( + cast(op).getMemRef()) > 0) dstMemrefOps.push_back(op); for (Operation *op : dstNode->stores) - if (cast(op).getMemRef() == memref) + if (producerConsumerMemrefs.count( + cast(op).getMemRef())) dstMemrefOps.push_back(op); unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); @@ -1502,7 +1470,7 @@ unsigned maxLegalFusionDepth = 0; SmallVector depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); - FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); + FusionStrategy strategy(FusionStrategy::ProducerConsumer); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( srcAffineForOp, dstAffineForOp, @@ -1512,27 +1480,82 @@ maxLegalFusionDepth = i; } - // Skip if fusion is not feasible at any loop depths. - if (maxLegalFusionDepth == 0) + if (maxLegalFusionDepth == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: fusion is not legal at any depth\n"); continue; + } // Check if fusion would be profitable. 
We skip profitability analysis // for maximal fusion since we already know the maximal legal depth to // fuse. unsigned bestDstLoopDepth = maxLegalFusionDepth; - if (!maximalFusion && - !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, - depthSliceUnions, maxLegalFusionDepth, - &bestDstLoopDepth, computeToleranceThreshold)) - continue; + if (!maximalFusion) { + // Retrieve producer stores from the src loop. + SmallVector producerStores; + for (Operation *op : srcNode->stores) + if (producerConsumerMemrefs.count( + cast(op).getMemRef())) + producerStores.push_back(op); + + // TODO: Support multiple producer stores in profitability + // analysis. We limit profitability analysis to only scenarios with + // a single producer store for now. Note that some multi-store + // producer scenarios will still go through profitability analysis + // if only one of the stores is involved in the producer-consumer + // relationship of the candidate loops. + assert(producerStores.size() > 0 && "Expected producer store"); + if (producerStores.size() > 1) + LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " + "supported for this case\n"); + else if (!isFusionProfitable(producerStores[0], producerStores[0], + dstAffineForOp, depthSliceUnions, + maxLegalFusionDepth, &bestDstLoopDepth, + computeToleranceThreshold)) + continue; + } assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); - assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && - "Missing slice union for depth"); + ComputationSliceState &bestSlice = + depthSliceUnions[bestDstLoopDepth - 1]; + assert(!bestSlice.isEmpty() && "Missing slice union for depth"); + + // Determine if 'srcId' can be removed after fusion, taking into + // account remaining dependences, escaping memrefs and the fusion + // insertion point.
+ bool removeSrcNode = canRemoveSrcNodeAfterFusion( + srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, + mdg); + + DenseSet privateMemrefs; + for (Value memref : producerConsumerMemrefs) { + // Don't create a private memref if 'srcNode' writes to escaping + // memrefs. + if (srcEscapingMemRefs.count(memref) > 0) + continue; + + // Don't create a private memref if 'srcNode' has in edges on + // 'memref' or 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 || + mdg->getOutEdgeCount(dstId, memref) > 0) + continue; + + // If 'srcNode' will be removed but it has out edges on 'memref' to + // nodes other than 'dstNode', we have to preserve dependences and + // cannot create a private memref. + if (removeSrcNode && + any_of(mdg->outEdges[srcId], [&](const auto &edge) { + return edge.value == memref && edge.id != dstId; + })) + continue; + + // Create a private version of this memref. + privateMemrefs.insert(memref); + } // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - fuseLoops(srcAffineForOp, dstAffineForOp, - depthSliceUnions[bestDstLoopDepth - 1]); + fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); + dstNodeChanged = true; LLVM_DEBUG(llvm::dbgs() << "Fused src loop " << srcId << " into dst loop " << dstId @@ -1540,92 +1563,63 @@ << dstAffineForOp << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (insertPointInst != dstAffineForOp.getOperation()) - dstAffineForOp->moveBefore(insertPointInst); + if (fusedLoopInsPoint != dstAffineForOp.getOperation()) + dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint); // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref, - createPrivateMemref); - - // Collect slice loop stats. - LoopNestStateCollector dstForCollector; - dstForCollector.collect(dstAffineForOp); - if (createPrivateMemref) { - // Create private memref for 'memref' in 'dstAffineForOp'. 
- SmallVector storesForMemref; - for (auto *storeOpInst : dstForCollector.storeOpInsts) { - if (cast(storeOpInst).getMemRef() == - memref) - storesForMemref.push_back(storeOpInst); + mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, + removeSrcNode); + + // Create private memrefs. + if (!privateMemrefs.empty()) { + // Gather stores for all the private-to-be memrefs. + DenseMap> privateMemRefToStores; + dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { + Value storeMemRef = storeOp.getMemRef(); + if (privateMemrefs.count(storeMemRef) > 0) + privateMemRefToStores[storeMemRef].push_back( + storeOp.getOperation()); + }); + + // Replace original memrefs with private memrefs. Note that all the + // loads and stores on these memrefs will be replaced with new + // loads and stores. Any reference to the original ones becomes + // invalid after this point. + for (auto &memrefToStoresPair : privateMemRefToStores) { + // TODO: Use union of memref write regions to compute + // private memref footprint. + SmallVector &storesForMemref = + memrefToStoresPair.second; + Value newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = + mdg->addNode(newMemRef.getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } - // TODO: Use union of memref write regions to compute - // private memref footprint. - auto newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); - // Add edge from 'newMemRef' node to dstNode.
- mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } // Collect dst loop stats after memref privatization transformation. LoopNestStateCollector dstLoopCollector; dstLoopCollector.collect(dstAffineForOp.getOperation()); - // Add new load ops to current Node load op list 'loads' to continue - // fusing based on new operands. - for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - // NOTE: Change 'loads' to a hash set in case efficiency is an - // issue. We still use a vector since it's expected to be small. - if (!llvm::is_contained(loads, loadOpInst)) - loads.push_back(loadOpInst); - } - // Clear visited memrefs after fusion so that previously visited src - // nodes are considered for fusion again in the context of the new - // fused node. - // TODO: This shouldn't be necessary if we visited candidates in the - // dependence graph in post-order or once we fully support multi-store - // producers. Currently, in a multi-store producer scenario such as - // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing - // edges. However, after fusing B+C, A has a single outgoing edge and - // can be fused if we revisit it in the context of the new fused B+C - // node. - visitedMemrefs.clear(); - // Clear and add back loads and stores. mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has outgoing dependence - // edges, and if it does not write to a memref which escapes the - // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been - // fused into 'dstNode' and write region of 'dstNode' covers the write - // region of 'srcNode', and 'srcNode' has no other users so it is safe - // to remove. 
- if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { - mdg->removeNode(srcNode->id); - srcNode->op->erase(); - } else { - // Add remaining users of 'oldMemRef' back on the worklist (if not - // already there), as its replacement with a local/private memref - // has reduced dependences on 'oldMemRef' which may have created new - // fusion opportunities. - if (mdg->outEdges.count(srcNode->id) > 0) { - SmallVector oldOutEdges = - mdg->outEdges[srcNode->id]; - for (auto &outEdge : oldOutEdges) { - if (outEdge.value == memref && - worklistSet.count(outEdge.id) == 0) { - worklist.push_back(outEdge.id); - worklistSet.insert(outEdge.id); - } - } - } + + if (removeSrcNode) { + LLVM_DEBUG(llvm::dbgs() + << "Removing src loop " << srcId << " after fusion\n"); + // srcNode is no longer valid after it is removed from mdg. + srcAffineForOp.erase(); + mdg->removeNode(srcId); + srcNode = nullptr; } } - } + } while (dstNodeChanged); } } @@ -1636,7 +1630,6 @@ while (!worklist.empty()) { unsigned dstId = worklist.back(); worklist.pop_back(); - worklistSet.erase(dstId); // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) @@ -1698,7 +1691,7 @@ SmallVector depthSliceUnions; depthSliceUnions.resize(dstLoopDepthTest); unsigned maxLegalFusionDepth = 0; - FusionStrategy strategy(FusionStrategy::Sibling, memref); + FusionStrategy strategy(memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { FusionResult result = mlir::canFuseLoops( sibAffineForOp, dstAffineForOp, @@ -1712,10 +1705,10 @@ if (maxLegalFusionDepth == 0) continue; - unsigned bestDstLoopDepth = dstLoopDepthTest; + unsigned bestDstLoopDepth = maxLegalFusionDepth; if (!maximalFusion) { // Check if fusion would be profitable. 
- if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp, depthSliceUnions, maxLegalFusionDepth, &bestDstLoopDepth, computeToleranceThreshold)) continue; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -191,11 +191,8 @@ /// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. // TODO: Generalize this check for sibling and more generic fusion scenarios. // TODO: Support forward slice fusion. -static unsigned getMaxLoopDepth(ArrayRef dstOps, - FusionStrategy fusionStrategy) { - assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer && - "Fusion strategy not supported"); - +static unsigned getMaxLoopDepth(ArrayRef srcOps, + ArrayRef dstOps) { if (dstOps.empty()) // Expected at least one memory operation. // TODO: Revisit this case with a specific example. @@ -203,15 +200,14 @@ // Filter out ops in 'dstOps' that do not use the producer-consumer memref so // that they are not considered for analysis. - // TODO: Currently, we pass the producer-consumer memref through - // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we - // generalize the algorithm. + DenseSet producerConsumerMemrefs; + gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs); SmallVector targetDstOps; for (Operation *dstOp : dstOps) { auto loadOp = dyn_cast(dstOp); Value memref = loadOp ? loadOp.getMemRef() : cast(dstOp).getMemRef(); - if (memref == fusionStrategy.memref) + if (producerConsumerMemrefs.count(memref) > 0) targetDstOps.push_back(dstOp); } @@ -308,10 +304,10 @@ // loop dependences. // TODO: Enable this check for sibling and more generic loop fusion // strategies. 
- if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) { + if (fusionStrategy.getStrategy() == FusionStrategy::ProducerConsumer) { // TODO: 'getMaxLoopDepth' does not support forward slice fusion. assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); - if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) { + if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); return FusionResult::FailFusionDependence; } @@ -324,7 +320,7 @@ // Filter out ops in 'opsA' to compute the slice union based on the // assumptions made by the fusion strategy. SmallVector strategyOpsA; - switch (fusionStrategy.strategy) { + switch (fusionStrategy.getStrategy()) { case FusionStrategy::Generic: // Generic fusion. Take into account all the memory operations to compute // the slice union. @@ -332,10 +328,9 @@ break; case FusionStrategy::ProducerConsumer: // Producer-consumer fusion (AffineLoopFusion pass) only takes into - // account stores to 'memref' in 'srcForOp' to compute the slice union. + // account stores in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { - auto store = dyn_cast(op); - if (store && store.getMemRef() == fusionStrategy.memref) + if (isa(op)) strategyOpsA.push_back(op); } break; @@ -344,7 +339,7 @@ // to 'memref' in 'srcForOp' to compute the slice union. for (Operation *op : opsA) { auto load = dyn_cast(op); - if (load && load.getMemRef() == fusionStrategy.memref) + if (load && load.getMemRef() == fusionStrategy.getSiblingFusionMemRef()) strategyOpsA.push_back(op); } break; @@ -628,3 +623,23 @@ /*tripCountOverrideMap=*/nullptr, &computeCostMap); return true; } + +/// Returns in 'producerConsumerMemrefs' the memrefs involved in a +/// producer-consumer dependence between write ops in 'srcOps' and read ops in +/// 'dstOps'. 
+void mlir::gatherProducerConsumerMemrefs( + ArrayRef srcOps, ArrayRef dstOps, + DenseSet &producerConsumerMemrefs) { + // Gather memrefs from stores in 'srcOps'. + DenseSet srcStoreMemRefs; + for (Operation *op : srcOps) + if (auto storeOp = dyn_cast(op)) + srcStoreMemRefs.insert(storeOp.getMemRef()); + + // Compute the intersection between memrefs from stores in 'srcOps' and + // memrefs from loads in 'dstOps'. + for (Operation *op : dstOps) + if (auto loadOp = dyn_cast(op)) + if (srcStoreMemRefs.count(loadOp.getMemRef()) > 0) + producerConsumerMemrefs.insert(loadOp.getMemRef()); +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -364,8 +364,8 @@ // ----- -// CHECK-LABEL: func @should_fuse_with_private_memref_if_top_level_access() { -func @should_fuse_with_private_memref_if_top_level_access() { +// CHECK-LABEL: func @should_fuse_if_top_level_access() { +func @should_fuse_if_top_level_access() { %m = alloc() : memref<10xf32> %cf7 = constant 7.0 : f32 @@ -378,14 +378,45 @@ %c0 = constant 4 : index %v1 = affine.load %m[%c0] : memref<10xf32> - // Top-level load to '%{{.*}}' should prevent fusion. - // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // Top-level load to '%m' should prevent creating a private memref but + // loop nests should be fused and '%i0' should be removed. 
+ // CHECK: %[[m:.*]] = alloc() : memref<10xf32> + // CHECK-NOT: alloc + + // CHECK: affine.for %[[i1:.*]] = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %[[m]][%[[i1]]] : memref<10xf32> + // CHECK-NEXT: affine.load %[[m]][%[[i1]]] : memref<10xf32> // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK: affine.load %[[m]][%{{.*}}] : memref<10xf32> + return +} + +// ----- + +// CHECK-LABEL: func @should_fuse_but_not_remove_src() { +func @should_fuse_but_not_remove_src() { + %m = alloc() : memref<100xf32> + %cf7 = constant 7.0 : f32 + + affine.for %i0 = 0 to 100 { + affine.store %cf7, %m[%i0] : memref<100xf32> + } + affine.for %i1 = 0 to 17 { + %v0 = affine.load %m[%i1] : memref<100xf32> + } + %v1 = affine.load %m[99] : memref<100xf32> + + // Loop '%i0' and '%i1' should be fused but '%i0' shouldn't be removed to + // preserve the dependence with the top-level access. + // CHECK: affine.for %{{.*}} = 0 to 100 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<100xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 17 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } + // CHECK-NEXT: affine.load %{{.*}}[99] : memref<100xf32> + // CHECK-NEXT: return return } @@ -1110,8 +1141,8 @@ // ----- -// CHECK-LABEL: func @should_not_fuse_live_out_arg(%{{.*}}: memref<10xf32>) { -func @should_not_fuse_live_out_arg(%arg0: memref<10xf32>) { +// CHECK-LABEL: func @should_fuse_live_out_arg_but_preserve_src_loop(%{{.*}}: memref<10xf32>) { +func @should_fuse_live_out_arg_but_preserve_src_loop(%arg0: memref<10xf32>) { %cf7 = constant 7.0 : f32 affine.for %i0 = 0 to 10 { @@ -1129,6 +1160,7 @@ // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : 
memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -1160,8 +1192,8 @@ // ----- -// CHECK-LABEL: func @should_not_fuse_escaping_memref() -> memref<10xf32> -func @should_not_fuse_escaping_memref() -> memref<10xf32> { +// CHECK-LABEL: func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> +func @should_fuse_escaping_memref_but_preserve_src_loop() -> memref<10xf32> { %cf7 = constant 7.0 : f32 %m = alloc() : memref<10xf32> affine.for %i0 = 0 to 10 { @@ -1170,19 +1202,21 @@ affine.for %i1 = 0 to 9 { %v0 = affine.load %m[%i1] : memref<10xf32> } - // This tests that the loop nest '%{{.*}}' should not be removed after fusion - // because it writes to memref '%{{.*}}' which is returned by the function. + // This tests that the loop nest '%i0' should not be removed after fusion + // because it writes to memref '%m', which is returned by the function, and + // the '%i1' memory region does not cover '%i0' memory region. + // CHECK-DAG: alloc() : memref<10xf32> // CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return %{{.*}} : memref<10xf32> return %m : memref<10xf32> } - // ----- // This should fuse with the %in becoming a 1x1x1. 
@@ -1230,7 +1264,7 @@ // ----- -func @should_not_fuse_multi_output_producer() { +func @should_fuse_multi_output_producer() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> @@ -1246,12 +1280,10 @@ } // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1504,8 +1536,8 @@ // ----- -// CHECK-LABEL: func @should_fuse_after_private_memref_creation() { -func @should_fuse_after_private_memref_creation() { +// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() { +func @should_fuse_only_two_loops_and_remove_producer() { %a = alloc() : memref<10xf32> %b = alloc() : memref<10xf32> @@ -1525,18 +1557,21 @@ // On the first visit to '%i2', the fusion algorithm can not fuse loop nest // '%i0' into '%i2' because of the dependences '%i0' and '%i2' each have on - // '%i1'. However, once the loop nest '%i0' is fused into '%i1' with a - // private memref, the dependence between '%i0' and '%i1' on memref '%a' no - // longer exists, so '%i0' can now be fused into '%i2'. - + // '%i1'. Then, '%i0' is fused into '%i1' and no private memref is created for + // memref '%a' to be able to remove '%i0' and still preserve the dependence on + // '%a' with '%i2'.
+ // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private memref, + // the dependence between '%i0' and '%i1' on memref '%a' would no longer exist, + // and '%i0' could be fused into '%i2' as well. Note that this approach would + // duplicate the computation in loop nest '%i0' to loop nests '%i1' and '%i2', + // which would limit its profitability. // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: return @@ -2220,7 +2255,7 @@ } } - // CHECK: affine.for %{{.*}} = 0 to 1024 { + // CHECK: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.for %{{.*}} = 0 to 1024 { // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32> @@ -2311,8 +2346,8 @@ } // CHECK: affine.for %[[i0:.*]] = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%[[i0]]] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -2373,12 +2408,11 @@ // ----- -// Verify that 'fuseProducerConsumerNodes' doesn't fuse a producer loop with -// a store that has multiple outgoing edges. 
Sibling loop fusion should not fuse -// any of these loops due to dependencies on external memref '%a'. +// Verify that 'fuseProducerConsumerNodes' fuses a producer loop with a store +// that has multiple outgoing edges. -// CHECK-LABEL: func @should_not_fuse_multi_outgoing_edge_store_producer1 -func @should_not_fuse_multi_outgoing_edge_store_producer1(%a : memref<1xf32>) { +// CHECK-LABEL: func @should_fuse_multi_outgoing_edge_store_producer +func @should_fuse_multi_outgoing_edge_store_producer(%a : memref<1xf32>) { %cst = constant 0.000000e+00 : f32 affine.for %arg0 = 0 to 1 { affine.store %cst, %a[%arg0] : memref<1xf32> @@ -2391,9 +2425,12 @@ affine.for %arg0 = 0 to 1 { %0 = affine.load %a[%arg0] : memref<1xf32> } - // CHECK: affine.for %{{.*}} = 0 to 1 - // CHECK: affine.for %{{.*}} = 0 to 1 - // CHECK: affine.for %{{.*}} = 0 to 1 + // CHECK: affine.for %{{.*}} = 0 to 1 { + // CHECK-NEXT: affine.store + // CHECK-NEXT: affine.load + // CHECK-NEXT: affine.load + // CHECK-NEXT: } + return } @@ -2663,3 +2700,109 @@ // MAXIMAL: affine.for // MAXIMAL-NEXT: affine.for // MAXIMAL-NOT: affine.for +// MAXIMAL: return + +// ----- + +// CHECK-LABEL: func @should_fuse_multi_store_producer_and_privatize_memrefs +func @should_fuse_multi_store_producer_and_privatize_memrefs() { + %a = alloc() : memref<10xf32> + %b = alloc() : memref<10xf32> + %c = alloc() : memref<10xf32> + %cst = constant 0.000000e+00 : f32 + affine.for %arg0 = 0 to 10 { + affine.store %cst, %a[%arg0] : memref<10xf32> + affine.store %cst, %b[%arg0] : memref<10xf32> + affine.store %cst, %c[%arg0] : memref<10xf32> + %0 = affine.load %c[%arg0] : memref<10xf32> + } + + affine.for %arg0 = 0 to 10 { + %0 = affine.load %a[%arg0] : memref<10xf32> + } + + affine.for %arg0 = 0 to 10 { + %0 = affine.load %b[%arg0] : memref<10xf32> + } + + // All the memrefs should be privatized except '%c', which is not involved in + // the producer-consumer fusion.
+ // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // CHECK-NEXT: } + + return +} + +// ----- + +func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src( + %a : memref<10xf32>, %b : memref<10xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %a[%i0] : memref<10xf32> + affine.store %cst, %b[%i0] : memref<10xf32> + } + + affine.for %i1 = 0 to 10 { + %0 = affine.load %a[%i1] : memref<10xf32> + } + + affine.for %i2 = 0 to 10 { + %0 = affine.load %b[%i2] : memref<10xf32> + } + + // Producer loop '%i0' should be removed after fusion since fusion is maximal. + // No memref should be privatized since they escape the function.
+ // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NOT: affine.for + + return +} + +// ----- + +func @should_fuse_multi_store_producer_with_escaping_memrefs_and_preserve_src( + %a : memref<10xf32>, %b : memref<10xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cst, %a[%i0] : memref<10xf32> + affine.store %cst, %b[%i0] : memref<10xf32> + } + + affine.for %i1 = 0 to 5 { + %0 = affine.load %a[%i1] : memref<10xf32> + } + + affine.for %i2 = 0 to 10 { + %0 = affine.load %b[%i2] : memref<10xf32> + } + + // Loops '%i0' and '%i2' should be fused first and '%i0' should be removed + // since fusion is maximal. Then the fused loop and '%i1' should be fused + // and the fused loop shouldn't be removed since fusion is not maximal. + // CHECK: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK: affine.for %{{.*}} = 0 to 5 { + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NOT: affine.for + + return +}