diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1209,6 +1209,12 @@ static bool isReshapableDimBand(unsigned dim, unsigned extent, ArrayRef sizes, ArrayRef strides) { + // Bands of extent one can be reshaped, as they are not reshaped at all. + if (extent == 1) + return true; + // Otherwise, the size of the first dimension needs to be known. + if (ShapedType::isDynamic(sizes[dim])) + return false; assert(sizes.size() == strides.size() && "mismatched ranks"); // off by 1 indexing to avoid out of bounds // V @@ -1217,7 +1223,7 @@ // there is no relation between dynamic sizes and dynamic strides: we do not // have enough information to know whether a "-1" size corresponds to the // proper symbol in the AffineExpr of a stride. - if (ShapedType::isDynamic(sizes[dim + 1])) + if (ShapedType::isDynamic(sizes[idx + 1])) return false; // TODO: Refine this by passing the proper nDims and nSymbols so we can // simplify on the fly and catch more reshapable cases. diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -5,6 +5,7 @@ // CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> // CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)> // CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)> +// CHECK-DAG: #[[$strided2D42:.*]] = affine_map<(d0, d1) -> (d0 * 42 + d1)> // CHECK-LABEL: func @memref_reinterpret_cast func @memref_reinterpret_cast(%in: memref) @@ -143,7 +144,8 @@ func @expand_collapse_shape_dynamic(%arg0: memref, %arg1: memref, - %arg2: memref) { + %arg2: memref, + %arg3: memref) { %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref %r0 = memref.expand_shape %0 [[0, 1], [2]] : @@ -160,6 +162,12 @@ %r2 = memref.expand_shape %2 [[0, 1], [2]] : memref into memref + %3 = memref.collapse_shape %arg3 [[0, 1]] : + memref into + memref + %r3 = memref.expand_shape %3 [[0, 1]] : + memref into + memref return } // CHECK-LABEL: func @expand_collapse_shape_dynamic @@ -175,6 +183,10 @@ // CHECK-SAME: memref into memref // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] +// CHECK-SAME: memref into memref +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] +// CHECK-SAME: memref into memref func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref) -> (memref, memref<1x1xf32>) {