Index: mlir/lib/Transforms/LoopFusion.cpp
===================================================================
--- mlir/lib/Transforms/LoopFusion.cpp
+++ mlir/lib/Transforms/LoopFusion.cpp
@@ -958,6 +958,54 @@
   return newMemRef;
 }
 
+// Walking from node 'srcId' to node 'dstId' (exclusive of both), return true
+// if an operation nested in some node that lies strictly between them is a
+// non-affine user of 'memref'; otherwise return false. Operations accepted by
+// isMemRefDereferencingOp (the affine memory accesses) are not counted.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       Value memref,
+                                       MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+  for (auto &idAndNode : mdg->nodes) {
+    Operation *op = idAndNode.second.op;
+    // Only consider operations strictly between 'srcNode' and 'dstNode'.
+    if (!srcNode->op->isBeforeInBlock(op) || !op->isBeforeInBlock(dstNode->op))
+      continue;
+    // Walk inside the operation and stop at the first non-affine op that
+    // takes 'memref' as an operand.
+    WalkResult walkResult = op->walk([&](Operation *user) {
+      if (!isMemRefDereferencingOp(*user) &&
+          llvm::is_contained(user->getOperands(), memref))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted())
+      return true;
+  }
+  return false;
+}
+
+// Return true if some memref used inside node 'srcId' has a non-affine user
+// in a node that lies strictly between node 'srcId' and node 'dstId'.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       MemRefDependenceGraph *mdg) {
+  // Collect the distinct memref operands of all operations nested in node
+  // 'srcId'; non-memref operands such as loop bounds and scalars are skipped.
+  auto *srcNode = mdg->getNode(srcId);
+  llvm::SmallDenseSet<Value, 2> memRefValues;
+  srcNode->op->walk([&](Operation *op) {
+    for (Value v : op->getOperands())
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
+  });
+  // Look for non-affine users of these memrefs between the two nodes.
+  for (Value memref : memRefValues)
+    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+      return true;
+  return false;
+}
+
 // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
 // may write to multiple memrefs but it is required that only one of them,
 // 'srcLiveOutStoreOp', has output edges.
@@ -1018,6 +1066,12 @@
   // TODO(andydavis) Check the shape and lower bounds here too.
   if (srcNumElements != dstNumElements)
     return false;
+
+  // Return false if a memref accessed in node 'srcId' is also used by a
+  // non-affine operation between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
   return true;
 }
 
@@ -1803,6 +1857,12 @@
       }
       if (storeMemrefs.size() != 1)
         return false;
+
+      // Skip if a memref accessed in either node is used by a non-affine
+      // operation that lies between 'dstNode' and 'sibNode'.
+      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+        return false;
       return true;
     };
 
Index: mlir/test/Transforms/loop-fusion.mlir
===================================================================
--- mlir/test/Transforms/loop-fusion.mlir
+++ mlir/test/Transforms/loop-fusion.mlir
@@ -2535,3 +2535,67 @@
 // CHECK: mulf
 // CHECK-NOT: affine.for
 // CHECK: divf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+                                             %in1 : memref<32xf32>) {
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = load %in0[%d] : memref<32xf32>
+    %rhs = load %in1[%d] : memref<32xf32>
+    %add = subf %lhs, %rhs : f32
+    store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: subf
+// CHECK: affine.for
+// CHECK: mulf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+                                                       %in1 : memref<32xf32>) {
+  %sum = alloc() : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    store %add, %sum[] : memref<f32>
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  %load_sum = load %sum[] : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    %sub = subf %add, %load_sum : f32
+    affine.store %sub, %in0[%d] : memref<32xf32>
+  }
+  dealloc %sum : memref<f32>
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: mulf
+// CHECK: subf