diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -911,22 +911,14 @@
   }
   auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                        {}, newMemSpace);
-  // Gather alloc operands for the dynamic dimensions of the memref.
-  SmallVector<Value, 4> allocOperands;
-  unsigned dynamicDimCount = 0;
-  for (auto dimSize : oldMemRefType.getShape()) {
-    if (dimSize == -1)
-      allocOperands.push_back(
-          top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
-  }
 
-  // Create new private memref for fused loop 'forOp'.
+  // Create new private memref for fused loop 'forOp'. 'newShape' is always
+  // a constant shape.
   // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
   // consumer loop nests to reduce their live range. Currently they are added
   // at the beginning of the function, because loop nests can be reordered
   // during the fusion pass.
-  Value newMemRef =
-      top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
+  Value newMemRef = top.create<AllocOp>(forOp.getLoc(), newMemRefType);
 
   // Build an AffineMap to remap access functions based on lower bound offsets.
   SmallVector<AffineExpr, 4> remapExprs;
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2535,3 +2535,38 @@
 // CHECK:       mulf
 // CHECK-NOT:   affine.for
 // CHECK:       divf
+
+// -----
+
+// Test fusion when dynamically shaped memrefs are used with constant trip count loops.
+
+// CHECK-LABEL: func @calc
+func @calc(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %len: index) {
+  %c1 = constant 1 : index
+  %1 = alloc(%len) : memref<?xf32>
+  affine.for %arg4 = 1 to 10 {
+    %7 = affine.load %arg0[%arg4] : memref<?xf32>
+    %8 = affine.load %arg1[%arg4] : memref<?xf32>
+    %9 = addf %7, %8 : f32
+    affine.store %9, %1[%arg4] : memref<?xf32>
+  }
+  affine.for %arg4 = 1 to 10 {
+    %7 = affine.load %1[%arg4] : memref<?xf32>
+    %8 = affine.load %arg1[%arg4] : memref<?xf32>
+    %9 = mulf %7, %8 : f32
+    affine.store %9, %arg2[%arg4] : memref<?xf32>
+  }
+  return
+}
+// CHECK:       alloc() : memref<1xf32>
+// CHECK:       affine.for %arg{{.*}} = 1 to 10 {
+// CHECK-NEXT:    affine.load %arg{{.*}}
+// CHECK-NEXT:    affine.load %arg{{.*}}
+// CHECK-NEXT:    addf
+// CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+// CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
+// CHECK-NEXT:    affine.load %arg{{.*}}[%arg{{.*}}] : memref<?xf32>
+// CHECK-NEXT:    mulf
+// CHECK-NEXT:    affine.store %{{.*}}, %arg{{.*}}[%arg{{.*}}] : memref<?xf32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
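
Note on why dropping 'allocOperands' is safe: 'newShape' is computed from the
constant-bounded region of the original memref accessed by the fused nest, so
'newMemRefType' is always statically shaped. An 'alloc' takes exactly one index
operand per dynamic ('?') dimension of its result type; the removed code
gathered operands from the *old* (possibly dynamic) memref type, so for a
dynamically shaped producer it handed dynamic-dim operands to an alloc of a
static type, which the verifier rejects. A minimal sketch of that operand
convention (the @alloc_examples function and the %n/%dyn/%priv names are
illustrative, not from the patch):

  func @alloc_examples(%n: index) {
    // One index operand per '?' dimension in the result type.
    %dyn = alloc(%n) : memref<?xf32>
    // A fully static result type takes no dimension operands,
    // which is the case for every private memref created here.
    %priv = alloc() : memref<1xf32>
    return
  }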