diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -948,6 +948,65 @@
   return newMemRef;
 }
 
+/// Walk from node 'srcId' to node 'dstId' (exclusive of 'srcId' and 'dstId');
+/// return true if any non-affine operation on that path accesses 'memref',
+/// and false otherwise.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       Value memref,
+                                       MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+  Value::user_range users = memref.getUsers();
+  // For each MemRefDependenceGraph node that is between 'srcNode' and
+  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
+  // non-affine operation in the node accesses the 'memref'.
+  for (auto &idAndNode : mdg->nodes) {
+    Operation *op = idAndNode.second.op;
+    // Only consider operations between 'srcNode' and 'dstNode'.
+    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
+      // Walk inside the operation to find any use of the memref.
+      // Interrupt the walk if one is found.
+      auto walkResult = op->walk([&](Operation *user) {
+        // Skip affine ops.
+        if (isMemRefDereferencingOp(*user))
+          return WalkResult::advance();
+        // Found a non-affine op that uses the memref.
+        if (llvm::is_contained(users, user))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      });
+      if (walkResult.wasInterrupted())
+        return true;
+    }
+  }
+  return false;
+}
+
+/// Check whether a memref value in node 'srcId' has a non-affine user that
+/// lies between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
+/// 'dstNode').
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       MemRefDependenceGraph *mdg) {
+  // Collect memref values in node 'srcId'.
+  auto *srcNode = mdg->getNode(srcId);
+  llvm::SmallDenseSet<Value, 2> memRefValues;
+  srcNode->op->walk([&](Operation *op) {
+    // Skip affine ops.
+    if (isa<AffineForOp, AffineIfOp>(op))
+      return WalkResult::advance();
+    for (Value v : op->getOperands())
+      // Collect memref values only.
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
+    return WalkResult::advance();
+  });
+  // Look for users between node 'srcId' and node 'dstId'.
+  for (Value memref : memRefValues)
+    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+      return true;
+  return false;
+}
+
 // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
 // may write to multiple memrefs but it is required that only one of them,
 // 'srcLiveOutStoreOp', has output edges.
@@ -1008,6 +1067,12 @@
   // TODO(andydavis) Check the shape and lower bounds here too.
   if (srcNumElements != dstNumElements)
     return false;
+
+  // Return false if a memref value in node 'srcId' is used by a non-affine
+  // operation that lies between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
   return true;
 }
 
@@ -1793,6 +1858,12 @@
       }
       if (storeMemrefs.size() != 1)
        return false;
+
+      // Skip if a memref value in one node is used by a non-affine operation
+      // that lies between 'dstNode' and 'sibNode'.
+      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+        return false;
      return true;
    };
 
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2570,3 +2570,67 @@
 // CHECK-NEXT:   affine.store %{{.*}}, %arg{{.*}}[%arg{{.*}}] : memref
 // CHECK-NEXT: }
 // CHECK-NEXT: return
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+                                             %in1 : memref<32xf32>) {
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = load %in0[%d] : memref<32xf32>
+    %rhs = load %in1[%d] : memref<32xf32>
+    %add = subf %lhs, %rhs : f32
+    store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: subf
+// CHECK: affine.for
+// CHECK: mulf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+                                                       %in1 : memref<32xf32>) {
+  %sum = alloc() : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    store %add, %sum[] : memref<f32>
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  %load_sum = load %sum[] : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    %sub = subf %add, %load_sum : f32
+    affine.store %sub, %in0[%d] : memref<32xf32>
+  }
+  dealloc %sum : memref<f32>
+  return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: mulf
+// CHECK: subf
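Note (not part of the patch): the core of the new check is MLIR's interrupt-on-match walk over nested operations. The sketch below shows that pattern in isolation under stated assumptions; `hasUserUnder` is a hypothetical helper name, and unlike the patch it does not skip affine load/store ops, so it only illustrates the walk/interrupt mechanics used by `hasNonAffineUsersOnThePath`.

// Minimal sketch, assuming an MLIR build; 'hasUserUnder' is a hypothetical
// name and is not part of the patch above.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/STLExtras.h"

static bool hasUserUnder(mlir::Operation *parent, mlir::Value memref) {
  auto users = memref.getUsers();
  mlir::WalkResult result = parent->walk([&](mlir::Operation *op) {
    // Stop the traversal as soon as a walked operation uses 'memref'.
    // (The patch additionally skips affine load/store ops via
    // isMemRefDereferencingOp before doing this containment check.)
    if (llvm::is_contained(users, op))
      return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });
  // The walk was interrupted iff a user of 'memref' was found under 'parent'.
  return result.wasInterrupted();
}

Returning WalkResult::interrupt() lets the traversal stop at the first offending user instead of visiting every nested operation, which is the same early-exit the patch relies on when scanning the nodes between 'srcNode' and 'dstNode'.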