diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -960,10 +960,21 @@
   if (originalType.getRank() == reducedType.getRank())
     return unusedDims;
 
-  for (const auto &dim : llvm::enumerate(sizes))
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
-      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
-        unusedDims.set(dim.index());
+  ArrayRef<int64_t> reducedShape = reducedType.getShape();
+  size_t reducedShapePos = reducedShape.size();
+  for (size_t ri = 0, re = sizes.size(); ri < re; ++ri) {
+    size_t index = sizes.size() - 1 - ri;
+    OpFoldResult dim = sizes[index];
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim)) {
+      if (llvm::cast<IntegerAttr>(attr).getInt() == 1) {
+        if (!(reducedShapePos > 0 && reducedShape[reducedShapePos - 1] == 1)) {
+          unusedDims.set(index);
+          continue;
+        }
+      }
+    }
+    reducedShapePos--;
+  }
 
   // Early exit for the case where the number of unused dims matches the number
   // of ranks reduced.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -955,3 +955,16 @@
 // CHECK: return %[[cast]]
   return %0 : memref>
 }
+
+// -----
+
+func.func @keep_preserved_unit_dimensions(%arg0: tensor<?x?x?xf32>, %arg1: index) -> index {
+  %0 = bufferization.to_memref %arg0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %c1 = arith.constant 1 : index
+  %subview = memref.subview %0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+  %dim = memref.dim %subview, %c1 : memref<1x?xf32, strided<[?, ?], offset: ?>>
+  return %dim : index
+}
+// CHECK-LABEL: func @keep_preserved_unit_dimensions
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK: return %[[ARG1]]
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -173,8 +173,8 @@
 // CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
 // CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
 // CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
-// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
-// CHECK: memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
+// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG16]], %[[ARG12]]]
+// CHECK: memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[ARG5]], %[[I4]]]
 
 // -----
 
@@ -588,9 +588,9 @@
 // CHECK: func.func @fold_src_nvgpu_device_async_copy
 // CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index)
 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_1]], %[[SRC_SUB_IDX_0]]]
 // CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
-// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[SRC_IDX_0]], %[[RESOLVED_SRC_IDX_0]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
 
 // -----
 
@@ -607,11 +607,11 @@
 // CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK: func.func @fold_src_fold_dest_nvgpu_device_async_copy
 // CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index, %[[DEST_IDX_0:.+]]: index, %[[DEST_IDX_1:.+]]: index, %[[DEST_IDX_2:.+]]: index, %[[DEST_IDX_3:.+]]: index, %[[DEST_SUB_IDX_0:.+]]: index, %[[DEST_SUB_IDX_1:.+]]: index)
-// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_1]], %[[SRC_SUB_IDX_0]]]
 // CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
-// CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
+// CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_2]], %[[DEST_SUB_IDX_0]]]
 // CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
-// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+// CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[SRC_IDX_0]], %[[RESOLVED_SRC_IDX_0]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[DEST_IDX_1]], %[[RESOLVED_DST_IDX_1]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
 
 // -----