diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -207,7 +207,8 @@ // Check if src and dst loop bounds are the same. If not, we can guarantee // that the slice is not maximal. - if (srcLbResult != dstLbResult || srcUbResult != dstUbResult) + if (srcLbResult != dstLbResult || srcUbResult != dstUbResult || + srcLoop.getStep() != dstLoop.getStep()) return false; } diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -71,6 +71,38 @@ // ----- +// Expects the producer to be fused into the consumer at depth 1 while the source +// loop is kept: the two loops have different steps, so the slice is not maximal. +// PRODUCER-CONSUMER-LABEL: func @check_src_dst_step +func @check_src_dst_step(%m : memref<100xf32>, + %src: memref<100xf32>, + %out: memref<100xf32>) { + affine.for %i0 = 0 to 100 { + %r1 = affine.load %src[%i0]: memref<100xf32> + affine.store %r1, %m[%i0] : memref<100xf32> + } + affine.for %i2 = 0 to 100 step 2 { + %r2 = affine.load %m[%i2] : memref<100xf32> + affine.store %r2, %out[%i2] : memref<100xf32> + } + return +} + +// Verify that fusion took place and that the source loop was kept. The source +// loop is matched first; fusion is then detected by finding its load operation +// (from the memref captured as arr1) inside the fused `step 2` loop.
+// +// PRODUCER-CONSUMER: affine.for %[[idx_0:.*]] = 0 to 100 { +// PRODUCER-CONSUMER-NEXT: %[[result_0:.*]] = affine.load %[[arr1:.*]][%[[idx_0]]] : memref<100xf32> +// PRODUCER-CONSUMER-NEXT: affine.store %[[result_0]], %{{.*}}[%[[idx_0]]] : memref<100xf32> +// PRODUCER-CONSUMER-NEXT: } +// PRODUCER-CONSUMER: affine.for %[[idx_1:.*]] = 0 to 100 step 2 { +// PRODUCER-CONSUMER: affine.load %[[arr1]][%{{.*}}] : memref<100xf32> +// PRODUCER-CONSUMER: } +// PRODUCER-CONSUMER: return + +// ----- + // SIBLING-MAXIMAL-LABEL: func @reduce_add_non_maximal_f32_f32( func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) { %cst_0 = arith.constant 0.000000e+00 : f32