diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -879,10 +879,9 @@ auto sourcePtr = promote(unrankedSource); auto targetPtr = promote(unrankedTarget); - unsigned typeSize = - mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType()); - auto elemSize = rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(typeSize)); + // Derive size from llvm.getelementptr which will account for any + // potential alignment + auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( op->getParentOfType(), getIndexType(), sourcePtr.getType()); rewriter.create(loc, copyFn, diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -558,7 +558,8 @@ // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.struct<(i64, ptr)>, !llvm.ptr // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.struct<(i64, ptr)>, !llvm.ptr - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr) -> !llvm.ptr, i1 + // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr to i64 // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr, !llvm.ptr) -> () // CHECK: llvm.intr.stackrestore [[STACKSAVE]] return diff --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir --- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir @@ -82,7 +82,8 @@ // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr)>> // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr)>> // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr)>> - // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr) -> !llvm.ptr + // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr to i64 // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr)>>, !llvm.ptr)>>) -> () // CHECK: llvm.intr.stackrestore [[STACKSAVE]] return