diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1848,6 +1848,12 @@ return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); + // Bail out before building the descriptor: the target must be contiguous + // (innermost stride 1) or empty (a dimension is 0; innermost stride is 0). + if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) + return viewOp.emitWarning("cannot cast to non-contiguous shape"), + failure(); + // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); @@ -1884,9 +1890,6 @@ return rewriter.replaceOp(viewOp, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. - if (strides.back() != 1) - return viewOp.emitWarning("cannot cast to non-contiguous shape"), - failure(); Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size.
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -89,6 +89,29 @@ // ----- +// CHECK-LABEL: func @view_empty_memref( +// CHECK: %[[ARG0:.*]]: index, +// CHECK: %[[ARG1:.*]]: memref<0xi8>) +func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) { + + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(0 : index) : i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(4 : index) : i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(0 : index) : i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(0 : index) : i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.mlir.constant(0 : index) : i64 + // CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)> + %0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32> + + return +} + +// ----- + // CHECK-LABEL: func @subview( // CHECK: %[[MEM:.*]]: memref<{{.*}}>, // CHECK: %[[ARG0f:[a-zA-Z0-9]*]]: index,