diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -37,6 +37,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -732,20 +733,6 @@
     return Changed;
   }
 
-  LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, bool IsVolatile,
-                             IRBuilder<> &Builder) {
-    return Builder.CreateAlignedLoad(ColumnPtr,
-                                     Align(DL.getABITypeAlignment(EltType)),
-                                     IsVolatile, "col.load");
-  }
-
-  StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr,
-                               Type *EltType, bool IsVolatile,
-                               IRBuilder<> &Builder) {
-    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
-                                      DL.getABITypeAlign(EltType), IsVolatile);
-  }
-
   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
@@ -777,10 +764,30 @@
     return true;
   }
 
+  /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
+  /// ConstantInt, reduce the initial alignment based on the byte offset. For
+  /// non-ConstantInt strides, return the common alignment of the initial
+  /// alignment and the element size in bytes.
+  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+                         MaybeAlign A) const {
+    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+    if (Idx == 0)
+      return InitialAlign;
+
+    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
+    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+      uint64_t StrideInBytes =
+          ConstStride->getZExtValue() * ElementSizeInBits / 8;
+      return commonAlignment(InitialAlign, Idx * StrideInBytes);
+    }
+    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
+  }
+
   /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, bool IsVolatile,
-                      ShapeInfo Shape, IRBuilder<> &Builder) {
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     MatrixTy Result;
@@ -788,8 +795,10 @@
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
                                      Shape.getStride(), VType->getElementType(),
                                      Builder);
-      Value *Vector =
-          createVectorLoad(GEP, VType->getElementType(), IsVolatile, Builder);
+      Value *Vector = Builder.CreateAlignedLoad(
+          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+          IsVolatile, "col.load");
+
       Result.addVector(Vector);
     }
     return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
@@ -798,8 +807,9 @@
 
   /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
   /// starting at \p MatrixPtr[I][J].
-  MatrixTy loadMatrix(Value *MatrixPtr, bool IsVolatile, ShapeInfo MatrixShape,
-                      Value *I, Value *J, ShapeInfo ResultShape, Type *EltTy,
+  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+                      ShapeInfo MatrixShape, Value *I, Value *J,
+                      ShapeInfo ResultShape, Type *EltTy,
                       IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
@@ -815,19 +825,19 @@
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    return loadMatrix(TileTy, TilePtr,
+    return loadMatrix(TileTy, TilePtr, Align,
                       Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                       ResultShape, Builder);
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, bool IsVolatile,
-                 ShapeInfo Shape) {
+  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+                 bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(
-        Inst,
-        loadMatrix(Inst->getType(), Ptr, Stride, IsVolatile, Shape, Builder),
-        Builder);
+    finalizeLowering(Inst,
+                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+                                Shape, Builder),
+                     Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
@@ -838,16 +848,16 @@
            "Intrinsic only supports column-major layout!");
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
-    LowerLoad(Inst, Ptr, Stride,
+    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
               cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
   /// MatrixPtr[I][J].
-  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, bool IsVolatile,
-                   ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy,
-                   IRBuilder<> &Builder) {
+  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
     Value *Offset = Builder.CreateAdd(
         Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
@@ -861,34 +871,38 @@
     Value *TilePtr =
         Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
 
-    storeMatrix(TileTy, StoreVal, TilePtr,
+    storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
                 Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
   }
 
   /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
   /// vectors.
-  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride,
-                       bool IsVolatile, IRBuilder<> &Builder) {
+  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+                       IRBuilder<> &Builder) {
     auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
                                      Stride, StoreVal.getStride(),
                                      VType->getElementType(), Builder);
-      createVectorStore(Vec.value(), GEP, VType->getElementType(), IsVolatile,
-                        Builder);
+      Builder.CreateAlignedStore(Vec.value(), GEP,
+                                 getAlignForIndex(Vec.index(), Stride,
+                                                  VType->getElementType(),
+                                                  MAlign),
+                                 IsVolatile);
     }
     return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                    StoreVal.getNumVectors());
   }
 
   /// Lower a store instruction with shape information.
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  bool IsVolatile, ShapeInfo Shape) {
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
     finalizeLowering(Inst,
-                     storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride,
+                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                  IsVolatile, Builder),
                      Builder);
   }
@@ -902,7 +916,7 @@
     Value *Matrix = Inst->getArgOperand(0);
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
-    LowerStore(Inst, Matrix, Ptr, Stride,
+    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
                cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
                {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
@@ -1215,16 +1229,18 @@
         for (unsigned K = 0; K < M; K += TileSize) {
           const unsigned TileM = std::min(M - K, unsigned(TileSize));
-          MatrixTy A = loadMatrix(APtr, LoadOp0->isVolatile(), LShape,
-                                  Builder.getInt64(I), Builder.getInt64(K),
-                                  {TileR, TileM}, EltType, Builder);
-          MatrixTy B = loadMatrix(BPtr, LoadOp1->isVolatile(), RShape,
-                                  Builder.getInt64(K), Builder.getInt64(J),
-                                  {TileM, TileC}, EltType, Builder);
+          MatrixTy A =
+              loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
+                         LShape, Builder.getInt64(I), Builder.getInt64(K),
+                         {TileR, TileM}, EltType, Builder);
+          MatrixTy B =
+              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
+                         RShape, Builder.getInt64(K), Builder.getInt64(J),
+                         {TileM, TileC}, EltType, Builder);
           emitMatrixMultiply(Res, A, B, AllowContract, Builder, true);
         }
-        storeMatrix(Res, CPtr, Store->isVolatile(), {R, M}, Builder.getInt64(I),
-                    Builder.getInt64(J), EltType, Builder);
+        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
+                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder);
       }
 
     // Mark eliminated instructions as fused and remove them.
@@ -1337,8 +1353,9 @@
     if (I == ShapeMap.end())
       return false;
 
-    LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()),
-              Inst->isVolatile(), I->second);
+    LowerLoad(Inst, Ptr, Inst->getAlign(),
+              Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+              I->second);
     return true;
   }
 
@@ -1348,8 +1365,9 @@
     if (I == ShapeMap.end())
      return false;
 
-    LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()),
-               Inst->isVolatile(), I->second);
+    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+               Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+               I->second);
     return true;
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/const-gep.ll
@@ -76,9 +76,9 @@
   %c.addr = alloca i32, align 4
   store i32 %r, i32* %r.addr, align 4
   store i32 %c, i32* %c.addr, align 4
-  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 16
+  %0 = load <4 x double>, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 0), align 8
   %mul = call <4 x double> @llvm.matrix.multiply(<4 x double> %0, <4 x double> %0, i32 2, i32 2, i32 2)
-  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 16
+  store <4 x double> %0, <4 x double>* getelementptr inbounds ([5 x <4 x double>], [5 x <4 x double>]* @foo, i64 0, i64 2), align 8
   ret void
 }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/load-align-volatile.ll
@@ -51,7 +51,7 @@
 ; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST]], align 32
 ; CHECK-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT: [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
@@ -74,15 +74,15 @@
 ; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START]]
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[VEC_GEP]] to <3 x double>*
-; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT: [[VEC_START1:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START1]]
 ; CHECK-NEXT: [[VEC_CAST3:%.*]] = bitcast double* [[VEC_GEP2]] to <3 x double>*
-; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST3]], align 8
+; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST3]], align 2
 ; CHECK-NEXT: [[VEC_START5:%.*]] = mul i64 2, [[STRIDE]]
 ; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 [[VEC_START5]]
 ; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <3 x double>*
-; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT: load <3 x double>, <3 x double>* [[VEC_CAST7]], align 2
 ; CHECK-NOT: = load
 ;
 entry:
@@ -95,10 +95,10 @@
 ; CHECK-LABEL: @load_align2_multiply(
 ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x double>* [[IN:%.*]] to double*
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[TMP1]] to <2 x double>*
-; CHECK-NEXT: load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT: load <2 x double>, <2 x double>* [[VEC_CAST]], align 2
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP1]], i64 2
 ; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
-; CHECK-NEXT: load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT: load <2 x double>, <2 x double>* [[VEC_CAST1]], align 2
 ; CHECK-NOT: = load
 ;
 %in.m = load <4 x double>, <4 x double>* %in, align 2
@@ -111,13 +111,13 @@
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: [[TMP0:%.*]] = bitcast <6 x float>* [[IN:%.*]] to float*
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast float* [[TMP0]] to <2 x float>*
-; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 4
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST]], align 16
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, float* [[TMP0]], i64 2
 ; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast float* [[VEC_GEP]] to <2 x float>*
-; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 4
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST1]], align 8
 ; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, float* [[TMP0]], i64 4
 ; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast float* [[VEC_GEP3]] to <2 x float>*
-; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 4
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, <2 x float>* [[VEC_CAST4]], align 16
 ; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x float> [[COL_LOAD5]], <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
 ; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> [[TMP2]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/store-align-volatile.ll
@@ -43,7 +43,7 @@
 ; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT: store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT: store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 32
 ; CHECK-NEXT: [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
@@ -61,11 +61,11 @@
 ; CHECK-NEXT: [[VEC_START:%.*]] = mul i64 0, [[STRIDE:%.*]]
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT:%.*]], i64 [[VEC_START]]
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast i32* [[VEC_GEP]] to <3 x i32>*
-; CHECK-NEXT: store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT: store volatile <3 x i32> [[SPLIT]], <3 x i32>* [[VEC_CAST]], align 2
 ; CHECK-NEXT: [[VEC_START2:%.*]] = mul i64 1, [[STRIDE]]
 ; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i32, i32* [[OUT]], i64 [[VEC_START2]]
 ; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast i32* [[VEC_GEP3]] to <3 x i32>*
-; CHECK-NEXT: store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 4
+; CHECK-NEXT: store volatile <3 x i32> [[SPLIT1]], <3 x i32>* [[VEC_CAST4]], align 2
 ; CHECK-NEXT: ret void
 ;
 call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 2 %out, i64 %stride, i1 true, i32 3, i32 2)
@@ -76,10 +76,10 @@
 ; CHECK-LABEL: @multiply_store_align16_stride8(
 ; CHECK: [[TMP29:%.*]] = bitcast <4 x i32>* %out to i32*
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast i32* [[TMP29]] to <2 x i32>*
-; CHECK-NEXT: store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT: store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST]], align 16
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, i32* [[TMP29]], i64 2
 ; CHECK-NEXT: [[VEC_CAST25:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
-; CHECK-NEXT: store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 4
+; CHECK-NEXT: store <2 x i32> {{.*}}, <2 x i32>* [[VEC_CAST25]], align 8
 ; CHECK-NEXT: ret void
 ;
 %res = call <4 x i32> @llvm.matrix.multiply(<4 x i32> %in, <4 x i32> %in, i32 2, i32 2, i32 2)
@@ -93,13 +93,13 @@
 ; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x i32> [[IN]], <6 x i32> undef, <2 x i32> <i32 4, i32 5>
 ; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast i32* [[OUT:%.*]] to <2 x i32>*
-; CHECK-NEXT: store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 4
+; CHECK-NEXT: store <2 x i32> [[SPLIT]], <2 x i32>* [[VEC_CAST]], align 8
 ; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, i32* [[OUT]], i64 3
 ; CHECK-NEXT: [[VEC_CAST3:%.*]] = bitcast i32* [[VEC_GEP]] to <2 x i32>*
 ; CHECK-NEXT: store <2 x i32> [[SPLIT1]], <2 x i32>* [[VEC_CAST3]], align 4
 ; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr i32, i32* [[OUT]], i64 6
 ; CHECK-NEXT: [[VEC_CAST5:%.*]] = bitcast i32* [[VEC_GEP4]] to <2 x i32>*
-; CHECK-NEXT: store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 4
+; CHECK-NEXT: store <2 x i32> [[SPLIT2]], <2 x i32>* [[VEC_CAST5]], align 8
 ; CHECK-NEXT: ret void
 ;
 call void @llvm.matrix.column.major.store(<6 x i32> %in, i32* align 8 %out, i64 3, i1 false, i32 2, i32 3)
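
Note for reviewers: the core of this change is getAlignForIndex, which replaces
the unconditional ABI-type alignment used by the removed createVectorLoad /
createVectorStore helpers. The rule it implements: the vector at column/row Idx
sits Idx * Stride * EltSize bytes past the base pointer, so the alignment we
may claim for it is the largest power of two dividing both the base alignment
and that byte offset (the patch computes this with commonAlignment from
llvm/Support/Alignment.h, hence the new include). Below is a minimal standalone
sketch of that arithmetic, with hypothetical helper names and no LLVM
dependencies, purely for illustration:

#include <cstdint>
#include <cstdio>

// Largest power of two dividing both the base alignment and the byte offset
// -- a gcd restricted to powers of two, mirroring what llvm::commonAlignment
// computes. An offset of 0 keeps the full base alignment.
static uint64_t commonAlign(uint64_t BaseAlign, uint64_t OffsetInBytes) {
  if (OffsetInBytes == 0)
    return BaseAlign;
  // Isolate the lowest set bit: the largest power of two dividing the offset.
  uint64_t OffsetAlign = OffsetInBytes & (~OffsetInBytes + 1);
  return OffsetAlign < BaseAlign ? OffsetAlign : BaseAlign;
}

// Sketch of getAlignForIndex for a ConstantInt stride. The real function also
// handles non-constant strides by falling back to
// commonAlignment(InitialAlign, element size in bytes), since an unknown
// stride only guarantees that later vectors start on element boundaries.
static uint64_t alignForIndex(unsigned Idx, uint64_t StrideInElts,
                              uint64_t EltSizeInBytes, uint64_t BaseAlign) {
  if (Idx == 0)
    return BaseAlign;
  return commonAlign(BaseAlign, Idx * StrideInElts * EltSizeInBytes);
}

int main() {
  // Mirrors multiply_store_align16_stride8 above: i32 elements (4 bytes),
  // stride 2 elements, base alignment 16. Column 0 keeps align 16; column 1
  // sits at byte offset 8 and drops to align 8 -- exactly the "align 16" /
  // "align 8" pair the updated CHECK lines expect.
  for (unsigned Idx = 0; Idx < 2; ++Idx)
    printf("column %u -> align %llu\n", Idx,
           (unsigned long long)alignForIndex(Idx, 2, 4, 16));
  return 0;
}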