diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -1785,9 +1785,6 @@ // Currently findSiblingNodeToFuse searches for siblings with one load. assert(sibLoadOpInsts.size() == 1); Operation *sibLoadOpInst = sibLoadOpInsts[0]; - assert(!sibNode->stores.empty()); - // TODO: Choose the store which postdominates all other stores. - auto *sibStoreOpInst = sibNode->stores.back(); // Gather 'dstNode' load ops to 'memref'. SmallVector dstLoadOpInsts; @@ -1818,8 +1815,11 @@ unsigned bestDstLoopDepth = maxLegalFusionDepth; if (!maximalFusion) { - // Check if fusion would be profitable. - if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstAffineForOp, + // Check if fusion would be profitable. For sibling fusion, the sibling + // load op is treated as the src "store" op for fusion profitability + // purposes. The footprint of the load in the slice relative to the + // unfused source's determines reuse. + if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp, depthSliceUnions, maxLegalFusionDepth, &bestDstLoopDepth, computeToleranceThreshold)) continue; @@ -1875,13 +1875,13 @@ })) return false; - // Check that all stores are to the same memref. + // Check that all stores are to the same memref if any. DenseSet storeMemrefs; for (auto *storeOpInst : sibNode->stores) { storeMemrefs.insert( cast(storeOpInst).getMemRef()); } - if (storeMemrefs.size() != 1) + if (storeMemrefs.size() > 1) return false; // Skip if a memref value in one node is used by a non-affine memref diff --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir --- a/mlir/test/Transforms/loop-fusion-2.mlir +++ b/mlir/test/Transforms/loop-fusion-2.mlir @@ -587,32 +587,32 @@ // MAXIMAL-NEXT: memref.alloc() : memref<2x2x3x3x16x1xf32> // MAXIMAL-NEXT: memref.alloc() : memref<144x4xf32> // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 { -// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 { -// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 { -// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { -// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 { -// MAXIMAL-NEXT: affine.apply [[$MAP0]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.apply [[$MAP1]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.apply [[$MAP2]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<2x2x3x3x16x1xf32> -// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32> -// MAXIMAL-NEXT: } +// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 { +// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { +// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 { +// MAXIMAL-NEXT: affine.apply [[$MAP0]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply [[$MAP1]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply [[$MAP2]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply [[$MAP3]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.apply [[$MAP4]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<2x2x3x3x16x1xf32> +// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32> +// MAXIMAL-NEXT: } +// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32> +// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 9 { // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 { // MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { -// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32> -// MAXIMAL-NEXT: } -// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { -// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<144x4xf32> +// MAXIMAL-NEXT: affine.apply [[$MAP8]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, 0] : memref<64x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } -// MAXIMAL-NEXT: affine.apply [[$MAP8]](%{{.*}}, %{{.*}}) -// MAXIMAL-NEXT: affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, 0] : memref<64x1xf32> // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } +// MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 16 { +// MAXIMAL-NEXT: affine.apply [[$MAP7]](%{{.*}}, %{{.*}}) +// MAXIMAL-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<144x4xf32> +// MAXIMAL-NEXT: } // MAXIMAL-NEXT: } // MAXIMAL-NEXT: } diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -144,6 +144,22 @@ // ----- +// SIBLING-MAXIMAL-LABEL: func @sibling_load_only +func.func @sibling_load_only(%arg0: memref<10xf32>) { + affine.for %arg1 = 0 to 10 { + %0 = affine.load %arg0[%arg1] : memref<10xf32> + } + affine.for %arg1 = 0 to 10 { + %0 = affine.load %arg0[%arg1] : memref<10xf32> + } + // SIBLING-MAXIMAL-NEXT: affine.for + // SIBLING-MAXIMAL-NEXT: affine.load + // SIBLING-MAXIMAL-NEXT: affine.load + return +} + +// ----- + // PRODUCER-CONSUMER-LABEL: func @fusion_for_multiple_blocks() { func.func @fusion_for_multiple_blocks() { ^bb0: diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1206,13 +1206,9 @@ // Should create two new private memrefs customized to the shapes accessed // by loops %{{.*}} and %{{.*}}. // CHECK-DAG: memref.alloc() : memref<1xf32> - // CHECK-DAG: memref.alloc() : memref<1xf32> // CHECK: affine.for %{{.*}} = 0 to 17 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> - // CHECK-NEXT: } - // CHECK-NEXT: affine.for %{{.*}} = 0 to 82 { - // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32> // CHECK-NEXT: } // CHECK-NEXT: return