diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -119,6 +120,42 @@
   return success();
 }
 
+// Return the minimal alignment value that satisfies all the AssumeAlignment
+// uses of `value`. If no such uses exist, return 1.
+static unsigned getAssumedAlignment(Value value) {
+  unsigned align = 1;
+  for (auto &u : value.getUses()) {
+    Operation *owner = u.getOwner();
+    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
+      align = mlir::lcm(align, op.alignment());
+  }
+  return align;
+}
+// Helper that returns the data layout alignment of a memref associated with a
+// transfer op, including additional information from assume_alignment calls
+// on the source of the transfer.
+LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter,
+                                     VectorTransferOpInterface xfer,
+                                     unsigned &align) {
+  if (failed(getMemRefAlignment(
+          typeConverter, xfer.getShapedType().cast<MemRefType>(), align)))
+    return failure();
+  align = std::max(align, getAssumedAlignment(xfer.source()));
+  return success();
+}
+
+// Helper that returns the data layout alignment of a memref associated with a
+// load, store, scatter, or gather op, including additional information from
+// assume_alignment calls on the base memref of the operation.
+template <typename OpAdaptor>
+LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
+                                   OpAdaptor op, unsigned &align) {
+  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
+    return failure();
+  align = std::max(align, getAssumedAlignment(op.base()));
+  return success();
+}
+
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -151,8 +188,7 @@
                                  TransferReadOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -171,10 +207,8 @@
   Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
-
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
       rewriter.getI32IntegerAttr(align));
@@ -187,8 +221,7 @@
                                  TransferWriteOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -202,8 +235,7 @@
                                  TransferWriteOp xferOp, ArrayRef<Value> operands,
                                  Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
 
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
@@ -337,7 +369,8 @@
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
+                                    align)))
       return failure();
 
     // Resolve address.
@@ -367,7 +400,7 @@
 
    // Resolve alignment.
    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
      return failure();
 
    // Resolve address.
@@ -402,7 +435,7 @@
 
    // Resolve alignment.
    unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
      return failure();
 
    // Resolve address.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1295,6 +1295,26 @@
 
 // -----
 
+func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  memref.assume_alignment %A, 32 : memref<?xf32>
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<17xf32>
+  vector.transfer_write %f, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<17xf32>, memref<?xf32>
+  return %f: vector<17xf32>
+}
+// CHECK: llvm.intr.masked.load
+// CHECK-SAME: {alignment = 32 : i32}
+// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
+// CHECK: llvm.intr.masked.store
+// CHECK-SAME: {alignment = 32 : i32}
+// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
+
+// -----
+
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1487,6 +1507,22 @@
 
 // -----
 
+func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  memref.assume_alignment %memref, 32 : memref<200x100xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op_aligned
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
+// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<8xf32>>
+
+// -----
+
 func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = constant dense<11.0> : vector<4xf32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1513,6 +1549,23 @@
 
 // -----
 
+func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  memref.assume_alignment %memref, 32 : memref<200x100xf32>
+  %val = constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_aligned
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
+// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<4xf32>>
+
+// -----
+
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1590,6 +1643,20 @@
 
 // -----
 
+func @gather_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  memref.assume_alignment %arg0, 32 : memref<?xf32>
+  %0 = constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_aligned
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+
+// -----
+
 func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1628,6 +1695,19 @@
 
 // -----
 
+func @scatter_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  memref.assume_alignment %arg0, 32 : memref<?xf32>
+  %0 = constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op_aligned
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
+
+// -----
+
 func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
   %0 = constant 3 : index
   vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>