diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1141,7 +1141,7 @@ extractFromI64ArrayAttr(op.static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); - if (resultSize != expectedSize) + if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) return op.emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); @@ -1158,7 +1158,8 @@ // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); - if (resultOffset != expectedOffset) + if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && + resultOffset != expectedOffset) return op.emitError("expected result type with offset = ") << resultOffset << " instead of " << expectedOffset; @@ -1167,7 +1168,8 @@ resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); - if (resultStride != expectedStride) + if (!ShapedType::isDynamicStrideOrOffset(resultStride) && + resultStride != expectedStride) return op.emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride << " in dim = " << en.index(); diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -208,18 +208,16 @@ // ----- -func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) { - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}} - %out = memref.reinterpret_cast %in to - offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
- : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]> - return -} - -// ----- - +func @memref_reinterpret_cast_size_mismatch(%in: memref<?xf32>) { + // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref<?xf32> to memref<1xf32> + return +} + +// ----- + func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -17,6 +17,15 @@ return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> } +// CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes +func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref<?xf32>) + -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + %out = memref.reinterpret_cast %in to + offset: [1], sizes: [10, 10], strides: [1, 1] + : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]> + return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> +} + // CHECK-LABEL: func @memref_reshape( func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>, %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {