diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1162,8 +1162,28 @@ } // Match offset and strides in static_offset and static_strides attributes if - // result memref type has an affine map specified. - if (!resultType.getLayout().isIdentity()) { + // result memref type has an affine map specified. Otherwise expect that the + // layout is the identity. + if (resultType.getLayout().isIdentity()) { + if (!op.hasZeroOffset()) + return op.emitError("expected 0 offset for result type ") + << resultType << " with identity layout"; + + int64_t stride = 1; + for (auto const &en : llvm::reverse( + llvm::zip(resultType.getShape(), + extractFromI64ArrayAttr(op.static_strides())))) { + int64_t currentSize = std::get<0>(en); + int64_t currentStride = std::get<1>(en); + if (ShapedType::isDynamicStrideOrOffset(currentStride) || + ShapedType::isDynamic(currentSize)) + break; + if (stride != currentStride) + return op.emitError("expected stride of ") + << stride << " but found static stride of " << currentStride; + stride *= currentSize; + } + } else { int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -151,7 +151,7 @@ // CHECK: return %[[SIZE]] : index func @dim_of_sized_view(%arg : memref, %size: index) -> index { %c0 = arith.constant 0 : index - %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref to memref + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref to memref %1 = memref.dim %0, %c0 : memref return %1 : index } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -208,6 +208,34 @@ // ----- +func @memref_reinterpret_cast_no_map_but_offset(%in: memref) { + // expected-error @+1 {{expected 0 offset for result type 'memref<10xf32>' with identity layout}} + %out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [1] + : memref to memref<10xf32> + return +} + +// ----- + +func @memref_reinterpret_cast_no_map_but_stride(%in: memref) { + // expected-error @+1 {{expected stride of 1 but found static stride of 10}} + %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10] + : memref to memref<10xf32> + return +} + +// ----- + +func @memref_reinterpret_cast_no_map_but_strides(%in: memref) { + // expected-error @+1 {{expected stride of 10 but found static stride of 42}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [9, 10], strides: [42, 1] + : memref to memref<9x10xf32> + return +} + +// ----- + func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir --- a/mlir/test/mlir-cpu-runner/copy.mlir +++ b/mlir/test/mlir-cpu-runner/copy.mlir @@ -35,9 +35,9 @@ // CHECK-NEXT: [3, 4, 5] %copy_two = memref.alloc() : memref<3x2xf32> - %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2,3], strides:[1, 2] - : memref<3x2xf32> to memref<2x3xf32> - memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32> + %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2] + : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]> + memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]> %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32> call @print_memref_f32(%unranked_copy_two) : (memref<*xf32>) -> () // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]