diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -119,6 +119,18 @@
   return success();
 }
 
+// Raise `align` to the largest memref.assume_alignment applied to `value`.
+static void updateAlignment(Value value, unsigned &align) {
+  for (auto &u : value.getUses()) {
+    Operation *owner = u.getOwner();
+    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner)) {
+      unsigned newAlignment = op.alignment();
+      if (newAlignment > align)
+        align = newAlignment;
+    }
+  }
+}
+
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -154,6 +166,7 @@
   if (failed(getMemRefAlignment(
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
+  updateAlignment(xferOp.source(), align);
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
 }
@@ -174,7 +187,7 @@
   if (failed(getMemRefAlignment(
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
-
+  updateAlignment(xferOp.source(), align);
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
       rewriter.getI32IntegerAttr(align));
@@ -190,6 +203,7 @@
   if (failed(getMemRefAlignment(
           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
     return failure();
+  updateAlignment(xferOp.source(), align);
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                              align);
@@ -207,6 +221,7 @@
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
+  updateAlignment(xferOp.source(), align);
   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
       xferOp, adaptor.vector(), dataPtr, mask,
       rewriter.getI32IntegerAttr(align));
   return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1295,6 +1295,27 @@
 
 // -----
 
+func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  memref.assume_alignment %A, 32 : memref<?xf32>
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<17xf32>
+  vector.transfer_write %f, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<17xf32>, memref<?xf32>
+  return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_aligned
+// CHECK: llvm.intr.masked.load
+// CHECK-SAME: {alignment = 32 : i32}
+// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
+// CHECK: llvm.intr.masked.store
+// CHECK-SAME: {alignment = 32 : i32}
+// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
+
+// -----
+
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
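
Note on test coverage: updateAlignment is threaded through all four 1-D lowering paths, but the new test only exercises the masked load/store pair. A companion test for the unmasked paths (llvm.load/llvm.store) might look like the sketch below. This sketch is hypothetical and not part of the patch: it assumes the unmasked 1-D form is spelled with the masked = [false] attribute as in this file's other unmasked transfer tests, and it uses deliberately loose CHECK lines because the exact printed form of the alignment attribute on llvm.load/llvm.store should be confirmed against the actual lowering output.

  func @transfer_read_1d_unmasked_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
    memref.assume_alignment %A, 32 : memref<?xf32>
    %f7 = constant 7.0: f32
    %f = vector.transfer_read %A[%base], %f7
        {masked = [false], permutation_map = affine_map<(d0) -> (d0)>} :
      memref<?xf32>, vector<17xf32>
    vector.transfer_write %f, %A[%base]
        {masked = [false], permutation_map = affine_map<(d0) -> (d0)>} :
      vector<17xf32>, memref<?xf32>
    return %f: vector<17xf32>
  }
  // CHECK-LABEL: func @transfer_read_1d_unmasked_aligned
  // CHECK: llvm.load
  // CHECK-SAME: alignment = 32
  // CHECK: llvm.store
  // CHECK-SAME: alignment = 32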