diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -958,60 +958,63 @@
   return newMemRef;
 }
 
-// Walking from node 'srcId' to node 'dstId', if there is a use of 'memref' that
-// is a non-affine operation, return false.
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// true. Otherwise, return false.
 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                        Value memref,
                                        MemRefDependenceGraph *mdg) {
   auto *srcNode = mdg->getNode(srcId);
   auto *dstNode = mdg->getNode(dstId);
   auto users = memref.getUsers();
+  // For each MemRefDependenceGraph's node that is between 'srcNode' and
+  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
+  // non-affine operation in the node accesses the 'memref'.
   for (auto &idAndNode : mdg->nodes) {
-    auto *op = idAndNode.second.op;
-    // Take care of operations between srcNode and dstNode.
+    Operation *op = idAndNode.second.op;
+    // Take care of operations between 'srcNode' and 'dstNode'.
     if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
       // Walk inside the operation to find any use of the memref.
-      bool found = false;
-      op->walk([&](Operation *user) {
+      // Interrupt the walk if found.
+      auto walkResult = op->walk([&](Operation *user) {
         // Skip affine ops.
         if (isMemRefDereferencingOp(*user))
           return WalkResult::advance();
         // Find a non-affine op that uses the memref.
         if (std::find(users.begin(), users.end(), user) != users.end()) {
-          found = true;
           return WalkResult::interrupt();
         }
         return WalkResult::advance();
       });
-      if (found)
+      if (walkResult.wasInterrupted())
         return true;
     }
   }
   return false;
 }
 
-// Check whether a value in node 'srcId' has a non-affine consumer that is
-// between node 'srcId' and node 'dstId'.
+/// Check whether a memref value in node 'srcId' has a non-affine user that
+/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
+/// 'dstNode').
 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                        MemRefDependenceGraph *mdg) {
-  // Collect values in node 'srcId'.
+  // Collect memref values in node 'srcId'.
   auto *srcNode = mdg->getNode(srcId);
-  SmallVector<Value, 2> values;
+  llvm::SmallDenseSet<Value, 2> memRefValues;
   srcNode->op->walk([&](Operation *op) {
-    // Skip affine ops
+    // Skip affine ops.
     if (isa<AffineForOp>(op))
       return WalkResult::advance();
-    for (Value v : op->getOperands()) {
-      if (std::find(values.begin(), values.end(), v) == values.end())
-        values.push_back(v);
-    }
+    for (Value v : op->getOperands())
+      // Collect memref values only.
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
     return WalkResult::advance();
   });
   // Looking for users between node 'srcId' and node 'dstId'.
-  for (Value v : values) {
-    if (hasNonAffineUsersOnThePath(srcId, dstId, v, mdg))
+  for (Value memref : memRefValues)
+    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
       return true;
-  }
   return false;
 }
 
@@ -1867,8 +1870,8 @@
   if (storeMemrefs.size() != 1)
     return false;
 
-  // Skip if a value in one node is used by a non-affine operation that lies
-  // between 'dstNode' and 'sibNode'.
+  // Skip if a memref value in one node is used by a non-affine memref
+  // access that lies between 'dstNode' and 'sibNode'.
   if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
       hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
     return false;