diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -752,6 +752,21 @@ Operation *op, ArrayRef operands, Value dataPtr, Value mask); +/// Converts `type` to an LLVM type and computes its preferred alignment from +/// the module data layout. On failure, `llvmType` and `align` are untouched. +static LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter, + Type type, LLVM::LLVMType &llvmType, + unsigned &align) { + auto convertedType = typeConverter.convertType(type); + if (!convertedType) + return failure(); + + llvmType = convertedType.template cast(); + auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); + align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType()); + return success(); +} + template <> void replaceTransferOp(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, @@ -764,10 +779,13 @@ Value fill = rewriter.create(loc, fillType, xferOp.padding()); fill = rewriter.create(loc, toLLVMTy(fillType), fill); - auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); - rewriter.replaceOpWithNewOp( - op, vecTy, dataPtr, mask, ValueRange{fill}, - rewriter.getI32IntegerAttr(1)); + LLVM::LLVMType vecTy; + unsigned align; + if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), + vecTy, align))) + rewriter.replaceOpWithNewOp( + op, vecTy, dataPtr, mask, ValueRange{fill}, + rewriter.getI32IntegerAttr(align)); } template <> @@ -777,8 +795,14 @@ ArrayRef operands, Value dataPtr, Value mask) { auto adaptor = TransferWriteOpOperandAdaptor(operands); - rewriter.replaceOpWithNewOp( - op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1)); + + auto xferOp = cast(op); + LLVM::LLVMType vecTy; + unsigned align; + if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), + vecTy, align))) + rewriter.replaceOpWithNewOp( + op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); } static 
TransferReadOpOperandAdaptor diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -818,7 +818,7 @@ // CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> : // CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>"> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], -// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} : +// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} : // CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>"> // @@ -850,7 +850,7 @@ // // 5. Rewrite as a masked write. // CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 1 : i32} : +// CHECK-SAME: {alignment = 128 : i32} : // CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*"> func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> {