diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -921,21 +921,17 @@ // Build an AffineMap to remap access functions based on lower bound offsets. SmallVector remapExprs; remapExprs.reserve(rank); - unsigned zeroOffsetCount = 0; for (unsigned i = 0; i < rank; i++) { - if (auto constExpr = offsets[i].dyn_cast()) - if (constExpr.getValue() == 0) - ++zeroOffsetCount; auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); auto remapExpr = simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); remapExprs.push_back(remapExpr); } - auto indexRemap = zeroOffsetCount == rank - ? AffineMap() - : AffineMap::get(outerIVs.size() + rank, 0, remapExprs, - forOp.getContext()); + + auto indexRemap = + AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext()); + // Replace all users of 'oldMemRef' with 'newMemRef'. LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2634,3 +2634,32 @@ // CHECK: affine.for // CHECK: mulf // CHECK: subf + +// ----- + +// MAXIMAL-LABEL: func @fuse_minor_affine_map +func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) { + %tmp = alloc() : memref<128xf32> + + affine.for %arg4 = 0 to 128 { + %ld = affine.load %in[%arg4] : memref<128xf32> + affine.store %ld, %tmp[%arg4] : memref<128xf32> + } + + affine.for %arg3 = 0 to 20 { + affine.for %arg4 = 0 to 512 { + %ld = affine.load %tmp[%arg4 mod 128] : memref<128xf32> + affine.store %ld, %out[%arg3, %arg4] : memref<20x512xf32> + } + } + + return +} + +// TODO: The size of the private memref is not properly computed in the presence +// of the 'mod' operation. It should be memref<1xf32> instead of +// memref<128xf32>: https://bugs.llvm.org/show_bug.cgi?id=46973 +// MAXIMAL: alloc() : memref<128xf32> +// MAXIMAL: affine.for +// MAXIMAL-NEXT: affine.for +// MAXIMAL-NOT: affine.for