diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1298,7 +1298,8 @@ /// copies. static bool isReshapableDimBand(unsigned dim, unsigned extent, ArrayRef sizes, - ArrayRef strides) { + ArrayRef strides, + bool isIdentityLayout) { // Bands of extent one can be reshaped, as they are not reshaped at all. if (extent == 1) return true; @@ -1316,8 +1317,10 @@ if (ShapedType::isDynamic(sizes[idx + 1])) return false; // TODO: Refine this by passing the proper nDims and nSymbols so we can - // simplify on the fly and catch more reshapable cases. - if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) + // simplify on the fly and catch more reshapable cases. For now, if we + // know it is an identity layout, ignore strides to support dynamic cases. + if (!isIdentityLayout && + (strides[idx] != strides[idx + 1] * sizes[idx + 1])) return false; } return true; @@ -1334,6 +1337,7 @@ AffineExpr offset; SmallVector strides; auto status = getStridesAndOffset(type, strides, offset); + auto isIdentityLayout = type.getLayout().isIdentity(); (void)status; assert(succeeded(status) && "expected strided memref"); @@ -1350,7 +1354,8 @@ unsigned dim = m.getNumResults(); int64_t size = 1; AffineExpr stride = strides[currentDim + dim - 1]; - if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { + if (!isReshapableDimBand(currentDim, dim, sizes, strides, + isIdentityLayout)) { size = ShapedType::kDynamicSize; stride = AffineExpr(); } else { diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -406,6 +406,8 @@ return %collapsed : memref } +// ----- + // CHECK-LABEL: func @collapse_after_memref_cast( // CHECK-SAME: %[[INPUT:.*]]: memref) -> memref { // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape 
%[[INPUT]] @@ -419,6 +421,21 @@ // ----- +// CHECK-LABEL: func @collapse_after_memref_cast_type_change_dynamic( +// CHECK-SAME: %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref { +// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]] +// CHECK-SAME: {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64> +// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] : +// CHECK-SAME: memref<1x?xi64> to memref +// CHECK: return %[[DYNAMIC]] : memref +func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref { + %casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64> + %collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref + return %collapsed : memref +} + +// ----- + func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index) -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> { %c0 = arith.constant 0 : index