diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Affine/Passes.h" -#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" @@ -224,7 +223,7 @@ // Returns the graph node for 'forOp'. Node *getForOpNode(AffineForOp forOp) { for (auto &idAndNode : nodes) - if (idAndNode.second.op == forOp.getOperation()) + if (idAndNode.second.op == forOp) return &idAndNode.second; return nullptr; } @@ -711,27 +710,35 @@ producerConsumerMemrefs); } +/// A memref escapes the function if either: +/// 1. it is a function argument, or +/// 2. it is used by a non-affine op (e.g., std load/store, std +/// call, etc.) +/// FIXME: Support alias creating ops like memref view ops. +static bool isEscapingMemref(Value memref) { + // Check if 'memref' escapes because it's a block argument. + if (memref.isa()) + return true; + + // Check if 'memref' escapes through a non-affine op (e.g., std load/store, + // call op, etc.). This already covers aliases created from this. + for (Operation *user : memref.getUsers()) + if (!isa(*user)) + return true; + return false; +} + /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' -/// that escape the function. A memref escapes the function if either: -/// 1. It's a function argument, or -/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.) +/// that escape the function. void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet &escapingMemRefs) { auto *node = mdg->getNode(id); - for (auto *storeOpInst : node->stores) { - auto memref = cast(storeOpInst).getMemRef(); + for (Operation *storeOp : node->stores) { + auto memref = cast(storeOp).getMemRef(); if (escapingMemRefs.count(memref)) continue; - // Check if 'memref' escapes because it's a block argument. - if (memref.isa()) { + if (isEscapingMemref(memref)) escapingMemRefs.insert(memref); - continue; - } - // Check if 'memref' escapes through a non-affine op (e.g., std load/store, - // call op, etc.). - for (Operation *user : memref.getUsers()) - if (!isa(*user)) - escapingMemRefs.insert(memref); } } @@ -743,6 +750,8 @@ // dependence graph at a different depth. bool MemRefDependenceGraph::init(func::FuncOp f) { LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); + // Map from a memref to the set of ids of the nodes that have ops accessing + // the memref. DenseMap> memrefAccesses; // TODO: support multi-block functions. @@ -832,8 +841,8 @@ getLoopIVs(*user, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0].getOperation()) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()]; + assert(forToNodeMap.count(loops[0]) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]]; addEdge(node.id, userLoopNestId, value); } } @@ -866,7 +875,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(isa(node->op)); AffineForOp newRootForOp = sinkSequentialLoops(cast(node->op)); - node->op = newRootForOp.getOperation(); + node->op = newRootForOp; } // TODO: improve/complete this when we have target data. @@ -893,7 +902,7 @@ unsigned dstLoopDepth, Optional fastMemorySpace, uint64_t localBufSizeThreshold) { - auto *forInst = forOp.getOperation(); + Operation *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); @@ -1418,6 +1427,10 @@ eraseUnusedMemRefAllocations(); } + /// Visit each node in the graph, and for each node, attempt to fuse it with + /// producer-consumer candidates. No fusion is performed when producers with a + /// user count greater than `maxSrcUserCount` for any of the memrefs involved + /// are encountered. void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); init(); @@ -1628,8 +1641,8 @@ << dstAffineForOp << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (fusedLoopInsPoint != dstAffineForOp.getOperation()) - dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint); + if (fusedLoopInsPoint != dstAffineForOp) + dstAffineForOp->moveBefore(fusedLoopInsPoint); // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, @@ -1642,8 +1655,7 @@ dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { Value storeMemRef = storeOp.getMemRef(); if (privateMemrefs.count(storeMemRef) > 0) - privateMemRefToStores[storeMemRef].push_back( - storeOp.getOperation()); + privateMemRefToStores[storeMemRef].push_back(storeOp); }); // Replace original memrefs with private memrefs. Note that all the @@ -1672,7 +1684,7 @@ // Collect dst loop stats after memref privatization transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getOperation()); + dstLoopCollector.collect(dstAffineForOp); // Clear and add back loads and stores. mdg->clearNodeLoadAndStores(dstNode->id); @@ -1798,7 +1810,7 @@ auto dstForInst = cast(dstNode->op); // Update operation position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getOperation()) { + if (insertPointInst != dstForInst) { dstForInst->moveBefore(insertPointInst); } // Update data dependence graph state post fusion. @@ -1939,7 +1951,7 @@ // Collect dst loop stats after memref privatization transformation. auto dstForInst = cast(dstNode->op); LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstForInst.getOperation()); + dstLoopCollector.collect(dstForInst); // Clear and add back loads and stores mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -13,20 +13,13 @@ #include "mlir/Dialect/Affine/LoopFusionUtils.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" -#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -113,7 +106,7 @@ } return WalkResult::advance(); } - for (auto value : op->getResults()) { + for (Value value : op->getResults()) { for (Operation *user : value.getUsers()) { SmallVector loops; // Check if any loop in loop nest surrounding 'user' is 'opB'. @@ -137,15 +130,12 @@ // dependences. Returns nullptr if no such insertion point is found. static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, AffineForOp dstForOp) { - bool isSrcForOpBeforeDstForOp = - srcForOp->isBeforeInBlock(dstForOp.getOperation()); + bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp); auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; - auto *firstDepOpA = - getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); - auto *lastDepOpB = - getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); + Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB); + Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB); // Block: // ... // |-- opA @@ -170,7 +160,7 @@ } // No dependences from 'opA' to operation in range ('opA', 'opB'), return // 'opB' insertion point. - return forOpB.getOperation(); + return forOpB; } // Gathers all load and store ops in loop nest rooted at 'forOp' into @@ -281,8 +271,7 @@ } // Check if 'srcForOp' precedes 'dstForOp' in 'block'. - bool isSrcForOpBeforeDstForOp = - srcForOp->isBeforeInBlock(dstForOp.getOperation()); + bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp); // 'forOpA' executes before 'forOpB' in 'block'. auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; @@ -315,8 +304,8 @@ } // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. - unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( - *srcForOp.getOperation(), *dstForOp.getOperation()); + unsigned numCommonLoops = + mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp); // Filter out ops in 'opsA' to compute the slice union based on the // assumptions made by the fusion strategy. @@ -539,8 +528,8 @@ int64_t opCount = stats.opCountMap[forOp] - 1; if (stats.loopMap.count(forOp) > 0) { for (auto childForOp : stats.loopMap[forOp]) { - opCount += getComputeCostHelper(childForOp.getOperation(), stats, - tripCountOverrideMap, computeCostMap); + opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap, + computeCostMap); } } // Add in additional op instances from slice (if specified in map). @@ -567,7 +556,7 @@ /// instance count (i.e. total number of operations in the loop body * loop /// trip count) for the entire loop nest. int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { - return getComputeCostHelper(forOp.getOperation(), stats, + return getComputeCostHelper(forOp, stats, /*tripCountOverrideMap=*/nullptr, /*computeCostMap=*/nullptr); } @@ -611,8 +600,8 @@ computeCostMap[insertPointParent] = -storeCount; // Subtract out any load users of 'storeMemrefs' nested below // 'insertPointParent'. - for (auto value : storeMemrefs) { - for (auto *user : value.getUsers()) { + for (Value memref : storeMemrefs) { + for (auto *user : memref.getUsers()) { if (auto loadOp = dyn_cast(user)) { SmallVector loops; // Check if any loop in loop nest surrounding 'user' is @@ -633,13 +622,13 @@ // Compute op instance count for the src loop nest with iteration slicing. int64_t sliceComputeCost = getComputeCostHelper( - srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); + srcForOp, srcStats, &sliceTripCountMap, &computeCostMap); // Compute cost of fusion for this depth. computeCostMap[insertPointParent] = sliceComputeCost; *computeCost = - getComputeCostHelper(dstForOp.getOperation(), dstStats, + getComputeCostHelper(dstForOp, dstStats, /*tripCountOverrideMap=*/nullptr, &computeCostMap); return true; }