diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1067,13 +1067,23 @@ if (failed(getStridesAndOffset(type, strides, offset))) return false; + // MemRef is contiguous if its inner dimensions are densely packed in + // row-major order and all remaining outer dimensions are size-1. int64_t runningStride = 1; - for (unsigned i = strides.size(); i > 0; --i) { - if (strides[i - 1] != runningStride) - return false; - runningStride *= type.getDimSize(i - 1); + int64_t curDim = strides.size() - 1; + // Consume inner dims whose stride equals the product of the inner sizes. + while (curDim >= 0 && strides[curDim] == runningStride) { + runningStride *= type.getDimSize(curDim); + --curDim; } - return true; + + // Remaining outer dims must be size-1 (their strides are not checked). + while (curDim >= 0 && type.getDimSize(curDim) == 1) { + --curDim; + } + + // Contiguous iff every dim was consumed by one of the two loops above. + return curDim < 0; }; auto isContiguousMemrefType = [&](BaseMemRefType type) { diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -455,15 +455,15 @@ // ----- // CHECK-LABEL: func @memref_copy_contiguous -func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) { +func.func @memref_copy_contiguous(%in: memref<16x4xi32>, %offset: index) { %buf = memref.alloc() : memref<1x2xi32> - %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>> - memref.copy %sub, %buf : memref<1x2xi32, strided<[2, 1], offset: ?>> to memref<1x2xi32> + %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>> + memref.copy %sub, %buf : memref<1x2xi32, strided<[4, 1], offset: ?>> to memref<1x2xi32> // 
Skip the memref descriptor of the alloc. // CHECK: llvm.insertvalue {{%.*}}, {{%.*}}[4, 1] // Get the memref for the subview. - // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>> - // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[2, 1], offset: ?>> to !llvm.struct<(ptr + // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>> + // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[4, 1], offset: ?>> to !llvm.struct<(ptr // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue %[[DESC]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64 // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue %[[DESC]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>