diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1334,6 +1334,7 @@ AffineExpr offset; SmallVector strides; auto status = getStridesAndOffset(type, strides, offset); + bool isIdentityLayout = type.getLayout().isIdentity(); (void)status; assert(succeeded(status) && "expected strided memref"); @@ -1350,12 +1351,20 @@ unsigned dim = m.getNumResults(); int64_t size = 1; AffineExpr stride = strides[currentDim + dim - 1]; - if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { + // Bands of an identity (contiguous) layout are always collapsible, even with dynamic sizes. + if (isIdentityLayout || + isReshapableDimBand(currentDim, dim, sizes, strides)) { + for (unsigned d = 0; d < dim; ++d) { + int64_t currentSize = sizes[currentDim + d]; + if (ShapedType::isDynamic(currentSize)) { + size = ShapedType::kDynamicSize; + break; + } + size *= currentSize; + } + } else { size = ShapedType::kDynamicSize; stride = AffineExpr(); - } else { - for (unsigned d = 0; d < dim; ++d) - size *= sizes[currentDim + d]; } newSizes.push_back(size); newStrides.push_back(stride); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -406,6 +406,8 @@ return %collapsed : memref } +// ----- + // CHECK-LABEL: func @collapse_after_memref_cast( // CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] @@ -419,6 +421,21 @@ // ----- +// CHECK-LABEL: func @collapse_after_memref_cast_type_change_dynamic( +// CHECK-SAME: %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> { +// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] +// CHECK-SAME: {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64> +// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] : +// CHECK-SAME: memref<1x?xi64> to memref<?x?xi64> +// CHECK: return %[[DYNAMIC]] : memref<?x?xi64> 
+func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> { + %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64> + %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64> + return %collapsed : memref<?x?xi64> +} + +// ----- + func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index) -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> { %c0 = arith.constant 0 : index