diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -223,7 +223,7 @@ // Returns the graph node for 'forOp'. Node *getForOpNode(AffineForOp forOp) { for (auto &idAndNode : nodes) - if (idAndNode.second.op == forOp.getOperation()) + if (idAndNode.second.op == forOp) return &idAndNode.second; return nullptr; } @@ -710,10 +710,26 @@ producerConsumerMemrefs); } +/// A memref escapes the function if either: +/// 1. it is a function argument, or +/// 2. it is used by a non-affine op (e.g., std load/store, std +/// call, etc.) +/// FIXME: Support alias creating ops like memref view ops. +static bool isEscapingMemref(Value memref) { + // Check if 'memref' escapes because it's a block argument. + if (memref.isa()) + return true; + + // Check if 'memref' escapes through a non-affine op (e.g., std load/store, + // call op, etc.). This already covers aliases created from this. + for (Operation *user : memref.getUsers()) + if (!isa(*user)) + return true; + return false; +} + /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' -/// that escape the function. A memref escapes the function if either: -/// 1. It's a function argument, or -/// 2. It's used by a non-affine op (e.g., std load/store, std call, etc.) +/// that escape the function. void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet &escapingMemRefs) { auto *node = mdg->getNode(id); @@ -721,16 +737,8 @@ auto memref = cast(storeOpInst).getMemRef(); if (escapingMemRefs.count(memref)) continue; - // Check if 'memref' escapes because it's a block argument. - if (memref.isa()) { + if (isEscapingMemref(memref)) escapingMemRefs.insert(memref); - continue; - } - // Check if 'memref' escapes through a non-affine op (e.g., std load/store, - // call op, etc.). - for (Operation *user : memref.getUsers()) - if (!isa(*user)) - escapingMemRefs.insert(memref); } } @@ -831,8 +839,8 @@ getLoopIVs(*user, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0].getOperation()) > 0); - unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()]; + assert(forToNodeMap.count(loops[0]) > 0); + unsigned userLoopNestId = forToNodeMap[loops[0]]; addEdge(node.id, userLoopNestId, value); } } @@ -865,7 +873,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { assert(isa(node->op)); AffineForOp newRootForOp = sinkSequentialLoops(cast(node->op)); - node->op = newRootForOp.getOperation(); + node->op = newRootForOp; } // TODO: improve/complete this when we have target data. @@ -892,7 +900,7 @@ unsigned dstLoopDepth, Optional fastMemorySpace, uint64_t localBufSizeThreshold) { - auto *forInst = forOp.getOperation(); + Operation *forInst = forOp.getOperation(); // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); @@ -1627,8 +1635,8 @@ << dstAffineForOp << "\n"); // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (fusedLoopInsPoint != dstAffineForOp.getOperation()) - dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint); + if (fusedLoopInsPoint != dstAffineForOp) + dstAffineForOp->moveBefore(fusedLoopInsPoint); // Update edges between 'srcNode' and 'dstNode'. mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, @@ -1641,8 +1649,7 @@ dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { Value storeMemRef = storeOp.getMemRef(); if (privateMemrefs.count(storeMemRef) > 0) - privateMemRefToStores[storeMemRef].push_back( - storeOp.getOperation()); + privateMemRefToStores[storeMemRef].push_back(storeOp); }); // Replace original memrefs with private memrefs. Note that all the @@ -1671,7 +1678,7 @@ // Collect dst loop stats after memref privatization transformation. LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getOperation()); + dstLoopCollector.collect(dstAffineForOp); // Clear and add back loads and stores. mdg->clearNodeLoadAndStores(dstNode->id); @@ -1797,7 +1804,7 @@ auto dstForInst = cast(dstNode->op); // Update operation position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getOperation()) { + if (insertPointInst != dstForInst) { dstForInst->moveBefore(insertPointInst); } // Update data dependence graph state post fusion. @@ -1938,7 +1945,7 @@ // Collect dst loop stats after memref privatization transformation. auto dstForInst = cast(dstNode->op); LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstForInst.getOperation()); + dstLoopCollector.collect(dstForInst); // Clear and add back loads and stores mdg->clearNodeLoadAndStores(dstNode->id); mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,