// Patch overview (unified diff; two files touched):
//   1. mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
//      - Adds a file-static helper, getLLVMTypeAndAlignment(typeConverter, type).
//        It converts `type` through typeConverter.convertType(), casts the
//        result, obtains the data layout via
//        typeConverter.getDialect()->getLLVMModule().getDataLayout(), and
//        returns std::make_pair(llvmType, align), where `align` is
//        dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType()).
//      - In both replaceTransferOp specializations shown in the hunks, the
//        alignment attribute passed to rewriter.replaceOpWithNewOp changes from
//        the constant rewriter.getI32IntegerAttr(1) to
//        rewriter.getI32IntegerAttr(align), with `align` computed by the helper
//        from xferOp.getVectorType().
//   2. mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
//      - The CHECK-SAME expectations for llvm.intr.masked.load and
//        llvm.intr.masked.store on <17 x float> change from
//        `alignment = 1 : i32` to `alignment = 128 : i32`.
//
// NOTE(review): `align` is the *preferred* alignment of the converted vector
// type; nothing in the visible hunks relates it to how `dataPtr` is actually
// aligned. Confirm that the underlying buffer is guaranteed to be at least
// this aligned before relying on it as the masked load/store alignment.
// NOTE(review): in the write hunk, `vecTy` is only assigned through std::tie in
// the visible lines; the rewritten op still takes adaptor.vector(). Verify
// whether `vecTy` is needed there.
// NOTE(review): this text has no line breaks inside hunks, and several spots
// look like they lost angle-bracket template arguments (e.g. `std::pair`,
// `.template cast()`, `rewriter.create(`, `cast(op)`, `memref`) - presumably an
// extraction artifact. Restore from the upstream patch before applying.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -752,6 +752,15 @@ Operation *op, ArrayRef operands, Value dataPtr, Value mask); +static std::pair +getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter, Type type) { + auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + auto llvmType = toLLVMTy(type).template cast(); + auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); + auto align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType()); + return std::make_pair(llvmType, align); +} + template <> void replaceTransferOp(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, @@ -764,10 +773,13 @@ Value fill = rewriter.create(loc, fillType, xferOp.padding()); fill = rewriter.create(loc, toLLVMTy(fillType), fill); - auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); + LLVM::LLVMType vecTy; + unsigned align; + std::tie(vecTy, align) = + getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType()); rewriter.replaceOpWithNewOp( op, vecTy, dataPtr, mask, ValueRange{fill}, - rewriter.getI32IntegerAttr(1)); + rewriter.getI32IntegerAttr(align)); } template <> @@ -777,8 +789,14 @@ ArrayRef operands, Value dataPtr, Value mask) { auto adaptor = TransferWriteOpOperandAdaptor(operands); + + auto xferOp = cast(op); + LLVM::LLVMType vecTy; + unsigned align; + std::tie(vecTy, align) = + getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType()); rewriter.replaceOpWithNewOp( - op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1)); + op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); } static TransferReadOpOperandAdaptor diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- 
a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -803,7 +803,7 @@ // CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> : // CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>"> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], -// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} : +// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} : // CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>"> // @@ -835,7 +835,7 @@ // // 5. Rewrite as a masked write. // CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 1 : i32} : +// CHECK-SAME: {alignment = 128 : i32} : // CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*"> func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> {