diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1575,9 +1575,11 @@ DenseSet privateMemrefs; for (Value memref : producerConsumerMemrefs) { - // Don't create a private memref if 'srcNode' writes to escaping - // memrefs. - if (srcEscapingMemRefs.count(memref) > 0) + // If `memref` is an escaping one, do not create a private memref + // for it if the source is to be removed after fusion, or if the + // destination writes to `memref`. + if (srcEscapingMemRefs.count(memref) > 0 && + (removeSrcNode || dstNode->getStoreOpCount(memref) > 0)) continue; // Don't create a private memref if 'srcNode' has in edges on diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1157,11 +1157,11 @@ // in the fused loop nest, so complete live out data region would not // be written). // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}} : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return return @@ -1206,13 +1206,13 @@ // because it writes to memref '%m', which is returned by the function, and // the '%i1' memory region does not cover '%i0' memory region. 
- // CHECK-DAG: memref.alloc() : memref<10xf32> + // CHECK-DAG: memref.alloc() : memref<1xf32> // CHECK: affine.for %{{.*}} = 0 to 10 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<10xf32> // CHECK-NEXT: } // CHECK-NEXT: affine.for %{{.*}} = 0 to 9 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<1xf32> + // CHECK-NEXT: affine.load %{{.*}} : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return %{{.*}} : memref<10xf32> return %m : memref<10xf32> @@ -2770,7 +2770,7 @@ // ----- -func @should_fuse_multi_store_producer_with_scaping_memrefs_and_remove_src( +func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src( %a : memref<10xf32>, %b : memref<10xf32>) { %cst = constant 0.000000e+00 : f32 affine.for %i0 = 0 to 10 { @@ -2787,7 +2787,8 @@ } // Producer loop '%i0' should be removed after fusion since fusion is maximal. - // No memref should be privatized since they escape the function. + // No memref should be privatized since they escape the function, and the + // producer is removed after fusion. 
// CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> @@ -2801,7 +2802,7 @@ // ----- -func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src( +func @should_fuse_multi_store_producer_with_escaping_memrefs_and_preserve_src( %a : memref<10xf32>, %b : memref<10xf32>) { %cst = constant 0.000000e+00 : f32 affine.for %i0 = 0 to 10 { @@ -2826,10 +2827,10 @@ // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> // CHECK-NEXT: } // CHECK: affine.for %{{.*}} = 0 to 5 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> - // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<1xf32> + // CHECK-NEXT: affine.store %{{.*}} : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}} : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}} : memref<1xf32> // CHECK-NEXT: } // CHECK-NOT: affine.for @@ -3071,6 +3072,37 @@ // ----- +// Test for source that writes to an escaping memref and has two consumers. +// Fusion should create private memrefs in place of `%arg0` since the source is +// not to be removed after fusion and the destinations do not write to `%arg0`. +// This should enable both the consumers to benefit from fusion, which would not +// be possible if private memrefs were not created.
+func @should_fuse_with_both_consumers_separately(%arg0: memref<10xf32>) { + %cf7 = constant 7.0 : f32 + affine.for %i0 = 0 to 10 { + affine.store %cf7, %arg0[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 7 { + %v0 = affine.load %arg0[%i1] : memref<10xf32> + } + affine.for %i1 = 5 to 9 { + %v0 = affine.load %arg0[%i1] : memref<10xf32> + } + return +} + +// CHECK-LABEL: func @should_fuse_with_both_consumers_separately +// CHECK: affine.for %{{.*}} = 0 to 10 { +// CHECK-NEXT: affine.store %{{.*}} : memref<10xf32> +// CHECK: affine.for %{{.*}} = 0 to 7 { +// CHECK-NEXT: affine.store %{{.*}} : memref<1xf32> +// CHECK-NEXT: affine.load %{{.*}} : memref<1xf32> +// CHECK: affine.for %{{.*}} = 5 to 9 { +// CHECK-NEXT: affine.store %{{.*}} : memref<1xf32> +// CHECK-NEXT: affine.load %{{.*}} : memref<1xf32> + +// ----- + // Fusion is avoided when the slice computed is invalid. Comments below describe // incorrect backward slice computation. Similar logic applies for forward slice // as well.