diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1155,40 +1155,44 @@ extractFromI64ArrayAttr(op.static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); - if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) + if (!ShapedType::isDynamic(resultSize) && + !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) return op.emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); } - // Match offset and strides in static_offset and static_strides attributes if - // result memref type has an affine map specified. - if (!resultType.getLayout().isIdentity()) { - int64_t resultOffset; - SmallVector resultStrides; - if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) - return failure(); - - // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = - extractFromI64ArrayAttr(op.static_offsets()).front(); - if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && - resultOffset != expectedOffset) - return op.emitError("expected result type with offset = ") - << resultOffset << " instead of " << expectedOffset; - - // Match strides in result memref type and in static_strides attribute. - for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { - int64_t resultStride = std::get<0>(en.value()); - int64_t expectedStride = std::get<1>(en.value()); - if (!ShapedType::isDynamicStrideOrOffset(resultStride) && - resultStride != expectedStride) - return op.emitError("expected result type with stride = ") - << expectedStride << " instead of " << resultStride - << " in dim = " << en.index(); - } + // Match offset and strides in static_offset and static_strides attributes. 
If + // result memref type has no affine map specified, this will assume an + // identity layout. + int64_t resultOffset; + SmallVector<int64_t, 4> resultStrides; + if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + return op.emitError( + "expected result type to have strided layout but found ") + << resultType; + + // Match offset in result memref type and in static_offsets attribute. + int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); + if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && + !ShapedType::isDynamicStrideOrOffset(expectedOffset) && + resultOffset != expectedOffset) + return op.emitError("expected result type with offset = ") + << expectedOffset << " instead of " << resultOffset; + + // Match strides in result memref type and in static_strides attribute. + for (auto &en : llvm::enumerate(llvm::zip( + resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { + int64_t resultStride = std::get<0>(en.value()); + int64_t expectedStride = std::get<1>(en.value()); + if (!ShapedType::isDynamicStrideOrOffset(resultStride) && + !ShapedType::isDynamicStrideOrOffset(expectedStride) && + resultStride != expectedStride) + return op.emitError("expected result type with stride = ") + << expectedStride << " instead of " << resultStride + << " in dim = " << en.index(); } + return success(); } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -151,7 +151,7 @@ // CHECK: return %[[SIZE]] : index func @dim_of_sized_view(%arg : memref, %size: index) -> index { %c0 = arith.constant 0 : index - %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref to memref + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref to memref %1 = memref.dim %0, %c0 : memref return %1 : index } diff --git
a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -208,6 +208,44 @@ // ----- +func @memref_reinterpret_cast_no_map_but_offset(%in: memref) { + // expected-error @+1 {{expected result type with offset = 2 instead of 0}} + %out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1] + : memref to memref<10xf32> + return +} + +// ----- + +func @memref_reinterpret_cast_no_map_but_stride(%in: memref) { + // expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}} + %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10] + : memref to memref<10xf32> + return +} + +// ----- + +func @memref_reinterpret_cast_no_map_but_strides(%in: memref) { + // expected-error @+1 {{expected result type with stride = 42 instead of 10 in dim = 0}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [9, 10], strides: [42, 1] + : memref to memref<9x10xf32> + return +} + +// ----- + +func @memref_reinterpret_cast_non_strided_layout(%in: memref) { + // expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [9, 10], strides: [42, 1] + : memref to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>> + return +} + +// ----- + func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -27,6 +27,15 @@ return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> } +// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset +func @memref_reinterpret_cast_dynamic_offset(%in: memref, %offset:
index) + -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + %out = memref.reinterpret_cast %in to + offset: [%offset], sizes: [10, 10], strides: [1, 1] + : memref to memref<10x?xf32, offset: ?, strides: [?, 1]> + return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> +} + // CHECK-LABEL: func @memref_reshape( func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>, %shape2: memref<2xi32>, %shape3: memref) -> memref<*xf32> { diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir --- a/mlir/test/mlir-cpu-runner/copy.mlir +++ b/mlir/test/mlir-cpu-runner/copy.mlir @@ -35,9 +35,9 @@ // CHECK-NEXT: [3, 4, 5] %copy_two = memref.alloc() : memref<3x2xf32> - %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2] - : memref<3x2xf32> to memref<2x3xf32> - memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32> + %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2] + : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]> + memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]> %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32> call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]