diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1628,14 +1628,22 @@
         // Add new load ops to current Node load op list 'loads' to
         // continue fusing based on new operands.
         for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-          auto loadMemRef =
-              cast<AffineReadOpInterface>(loadOpInst).getMemRef();
           // NOTE: Change 'loads' to a hash set in case efficiency is an
           // issue. We still use a vector since it's expected to be small.
-          if (visitedMemrefs.count(loadMemRef) == 0 &&
-              !llvm::is_contained(loads, loadOpInst))
+          if (!llvm::is_contained(loads, loadOpInst))
             loads.push_back(loadOpInst);
         }
+        // Clear visited memrefs after fusion so that previously visited src
+        // nodes are considered for fusion again in the context of the new
+        // fused node.
+        // TODO: This shouldn't be necessary if we visited candidates in the
+        // dependence graph in post-order or once we fully support
+        // multi-store producers. Currently, in a multi-store producer
+        // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the
+        // multiple outgoing edges. However, after fusing B+C, A has a
+        // single outgoing edge and can be fused if we revisit it in the
+        // context of the new fused B+C node.
+        visitedMemrefs.clear();
 
         // Clear and add back loads and stores.
         mdg->clearNodeLoadAndStores(dstNode->id);
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2493,3 +2493,45 @@
 // CHECK-NEXT: affine.vector_load
 // CHECK-NEXT: affine.vector_store
 // CHECK-NOT: affine.for
+
+// -----
+
+// CHECK-LABEL: func @multi_outgoing_edges
+func @multi_outgoing_edges(%in0 : memref<32xf32>,
+                           %in1 : memref<32xf32>) {
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = subf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = divf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK-NOT: affine.for
+// CHECK: addf
+// CHECK-NOT: affine.for
+// CHECK: subf
+// CHECK-NOT: affine.for
+// CHECK: mulf
+// CHECK-NOT: affine.for
+// CHECK: divf
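
For illustration only, not part of the patch: a minimal MLIR sketch of the multi-store producer scenario A->B, A->C, B->C described in the TODO comment above. The function, memref, and value names below are hypothetical; only the dependence pattern matters. Loop A stores to %m0, which is read by both B and C, and B stores to %m1, which is read by C. With two outgoing edges from A, A+B fusion is rejected; once B and C are fused, A has a single outgoing edge and becomes fusable when revisited.

// Hypothetical example (not part of the patch): multi-store producer
// scenario A->B, A->C, B->C.
func @multi_store_producer_sketch(%in : memref<32xf32>,
                                  %out : memref<32xf32>) {
  %m0 = alloc() : memref<32xf32>
  %m1 = alloc() : memref<32xf32>
  // Loop A: writes %m0, which is read by both B and C (two outgoing edges).
  affine.for %i = 0 to 32 {
    %v = affine.load %in[%i] : memref<32xf32>
    affine.store %v, %m0[%i] : memref<32xf32>
  }
  // Loop B: reads %m0 and writes %m1 (edge B->C on %m1).
  affine.for %i = 0 to 32 {
    %v = affine.load %m0[%i] : memref<32xf32>
    %s = addf %v, %v : f32
    affine.store %s, %m1[%i] : memref<32xf32>
  }
  // Loop C: reads %m0 and %m1. After B+C are fused, A has a single outgoing
  // edge to the fused node and can be fused when revisited.
  affine.for %i = 0 to 32 {
    %a = affine.load %m0[%i] : memref<32xf32>
    %b = affine.load %m1[%i] : memref<32xf32>
    %p = mulf %a, %b : f32
    affine.store %p, %out[%i] : memref<32xf32>
  }
  return
}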