diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1227,7 +1227,9 @@ return false; // TODO: Refine this by passing the proper nDims and nSymbols so we can // simplify on the fly and catch more reshapable cases. - if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) + // Size-1 dims only take index 0, so any stride is accepted for them. + // TODO: a unit dim at idx + 1 still needs strides[idx] == strides[idx + 1]. + if (sizes[idx] != 1 && strides[idx] != strides[idx + 1] * sizes[idx + 1]) return false; } return true; diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -207,3 +207,12 @@ // CHECK: func @collapse_shape_to_dynamic // CHECK: memref.collapse_shape // CHECK-SAME: [0], [1], [2, 3, 4] + +func @collapse_static_shape_with_unit_dims_with_dynamic_strides(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> { + %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> +} + +// CHECK: func @collapse_static_shape_with_unit_dims_with_dynamic_strides +// CHECK: memref.collapse_shape +// CHECK-SAME: into memref<64xf32,