Vector -> LLVM: take masked transfer alignment from the memref element type

Summary of the changes in this patch (review notes, ignored by git apply):

* VectorOps.h
  - VectorTransformsOptions gains a chainable setter,
    setVectorTransformsOptions(VectorContractLowering opt), which assigns
    vectorContractLowering and returns *this.

* ConvertVectorToLLVM.cpp
  - getLLVMTypeAndAlignment(typeConverter, type, llvmType, align) is
    replaced by a templated getVectorTransferAlignment(typeConverter,
    xferOp, align). The alignment becomes the data layout's preferred
    alignment of the converted element type of xferOp.getMemRefType(),
    instead of the preferred alignment of the converted vector type.
    NOTE(review): presumably because the transfer's base address is only
    known to be element-aligned; confirm against the LLVM masked
    load/store alignment semantics.
  - The unmasked read/write paths drop a type conversion and alignment
    query whose results they never used. Side effect: those paths no
    longer return failure() when converting xferOp.getVectorType() fails.
  - The masked read path converts xferOp.getVectorType() itself (returning
    failure() if that conversion fails) and gets its alignment from the
    new helper; the masked write path gets its alignment from the new
    helper and builds the operand adaptor after that check.

* vector-to-llvm.mlir
  - The expected alignment on llvm.intr.masked.load and
    llvm.intr.masked.store for vector<17xf32> changes from 128 to 4.

NOTE(review): this copy of the patch is damaged and will not apply as-is.
  Newlines have been collapsed, and text inside angle brackets that looks
  like an HTML tag appears to have been stripped: a bare `template` with
  no parameter list, `cast()` / `.template cast()` with no target type,
  `ArrayRef operands` with no element type, `replaceOpWithNewOp(` and
  `rewriter.create(` with no op type, and `memref` with no shape in the
  test signature. Regenerate the patch from its source commit rather than
  hand-repairing it.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -56,6 +56,11 @@ /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; + VectorTransformsOptions & + setVectorTransformsOptions(VectorContractLowering opt) { + vectorContractLowering = opt; + return *this; + } }; /// Collect a set of transformation patterns that are related to contracting diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -767,16 +767,17 @@ } }; -LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter, - Type type, LLVM::LLVMType &llvmType, - unsigned &align) { - auto convertedType = typeConverter.convertType(type); - if (!convertedType) +template +LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter, + TransferOp xferOp, unsigned &align) { + Type elementTy = + typeConverter.convertType(xferOp.getMemRefType().getElementType()); + if (!elementTy) return failure(); - llvmType = convertedType.template cast(); auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); - align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType()); + align = dataLayout.getPrefTypeAlignment( + elementTy.cast().getUnderlyingType()); return success(); } @@ -785,11 +786,6 @@ LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, ArrayRef operands, Value dataPtr) { - LLVM::LLVMType vecTy; - unsigned align; - if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), - vecTy, align))) - return failure(); rewriter.replaceOpWithNewOp(xferOp, dataPtr); 
return success(); } @@ -804,10 +800,12 @@ Value fill = rewriter.create(loc, fillType, xferOp.padding()); fill = rewriter.create(loc, toLLVMTy(fillType), fill); - LLVM::LLVMType vecTy; + Type vecTy = typeConverter.convertType(xferOp.getVectorType()); + if (!vecTy) + return failure(); + unsigned align; - if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), - vecTy, align))) + if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) return failure(); rewriter.replaceOpWithNewOp( @@ -822,11 +820,6 @@ TransferWriteOp xferOp, ArrayRef operands, Value dataPtr) { auto adaptor = TransferWriteOpOperandAdaptor(operands); - LLVM::LLVMType vecTy; - unsigned align; - if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), - vecTy, align))) - return failure(); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr); return success(); } @@ -836,13 +829,11 @@ Location loc, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { - auto adaptor = TransferWriteOpOperandAdaptor(operands); - LLVM::LLVMType vecTy; unsigned align; - if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(), - vecTy, align))) + if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) return failure(); + auto adaptor = TransferWriteOpOperandAdaptor(operands); rewriter.replaceOpWithNewOp( xferOp, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -818,7 +818,7 @@ // CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> : // CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>"> // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]], -// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} : +// CHECK-SAME: 
%[[PASS_THROUGH]] {alignment = 4 : i32} : // CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>"> // @@ -850,7 +850,7 @@ // // 5. Rewrite as a masked write. // CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]] -// CHECK-SAME: {alignment = 128 : i32} : +// CHECK-SAME: {alignment = 4 : i32} : // CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*"> func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> {