Index: llvm/include/llvm/Support/TypeSize.h
===================================================================
--- llvm/include/llvm/Support/TypeSize.h
+++ llvm/include/llvm/Support/TypeSize.h
@@ -131,6 +131,25 @@
     return { MinSize / RHS, IsScalable };
   }
 
+  /// Subtract \p RHS from this size in place. Both sizes must have the same
+  /// scalability and \p RHS must not be larger than this size.
+  TypeSize &operator-=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Subtraction using mixed scalable and fixed types");
+    assert(MinSize >= RHS.MinSize && "TypeSize subtraction underflow");
+    MinSize -= RHS.MinSize;
+    return *this;
+  }
+
+  /// Add \p RHS to this size in place. Both sizes must have the same
+  /// scalability.
+  TypeSize &operator+=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Addition using mixed scalable and fixed types");
+    MinSize += RHS.MinSize;
+    return *this;
+  }
+
   // Return the minimum size with the assumption that the size is exact.
   // Use in places where a scalable size doesn't make sense (e.g. non-vector
   // types, or vectors in backends which don't support scalable vectors).
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -983,10 +983,15 @@
   unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
 
   if (MemVT.isScalableVector()) {
+    SDNodeFlags Flags;
     SDValue BytesIncrement = DAG.getVScale(
         DL, Ptr.getValueType(),
         APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
     MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
-    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
+    // Moving to the next part of the same access is assumed not to wrap, so
+    // mark the ADD as NUW (the flags must be passed to getNode to apply).
+    Flags.setNoUnsignedWrap(true);
+    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
+                      Flags);
   } else {
     MPI = N->getPointerInfo().getWithOffset(IncrementSize);
@@ -4820,7 +4825,9 @@
 
   // If we have one element to load/store, return it.
   EVT RetVT = WidenEltVT;
-  if (Width == WidenEltWidth)
+  // For scalable types Width is only a known minimum size, so a single
+  // element is not guaranteed to cover it.
+  if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
   // See if there is larger legal integer than the element type to load/store.
@@ -5114,53 +5116,61 @@
   SDLoc dl(ST);
 
   EVT StVT = ST->getMemoryVT();
-  unsigned StWidth = StVT.getSizeInBits();
+  TypeSize StWidth = StVT.getSizeInBits();
   EVT ValVT = ValOp.getValueType();
-  unsigned ValWidth = ValVT.getSizeInBits();
+  TypeSize ValWidth = ValVT.getSizeInBits();
   EVT ValEltVT = ValVT.getVectorElementType();
-  unsigned ValEltWidth = ValEltVT.getSizeInBits();
+  unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
   assert(StVT.getVectorElementType() == ValEltVT);
+  assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
+         "Mismatch between store and value types");
 
   int Idx = 0;          // current index to store
-  unsigned Offset = 0;  // offset from base to store
-  while (StWidth != 0) {
+
+  // Pointer info for the next part store. IncrementPointer advances it
+  // together with BasePtr after every part; for scalable parts it keeps only
+  // the address space, since the byte offset is not a compile-time constant.
+  MachinePointerInfo MPI = ST->getPointerInfo();
+  while (StWidth.isNonZero()) {
     // Find the largest vector type we can store with.
-    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
-    unsigned NewVTWidth = NewVT.getSizeInBits();
-    unsigned Increment = NewVTWidth / 8;
+    EVT NewVT = FindMemType(DAG, TLI, StWidth.getKnownMinSize(), ValVT);
+    TypeSize NewVTWidth = NewVT.getSizeInBits();
     if (NewVT.isVector()) {
-      unsigned NumVTElts = NewVT.getVectorNumElements();
+      unsigned NumVTElts = NewVT.getVectorMinNumElements();
       do {
         SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
                                   DAG.getVectorIdxConstant(Idx, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore =
+            DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
+                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
         Idx += NumVTElts;
-
-        BasePtr = DAG.getObjectPtrOffset(dl, BasePtr, Increment);
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
     } else {
       // Cast the vector to the scalar type we can store.
-      unsigned NumElts = ValWidth / NewVTWidth;
+      // Only fixed-length sizes are used here (getFixedSize), so this path
+      // is not expected to be taken for scalable vectors.
+      unsigned NumElts = ValWidth.getFixedSize() / NewVTWidth.getFixedSize();
       EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts);
       SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp);
       // Readjust index position based on new vector type.
-      Idx = Idx * ValEltWidth / NewVTWidth;
+      Idx = Idx * ValEltWidth / NewVTWidth.getFixedSize();
       do {
         SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp,
                                   DAG.getVectorIdxConstant(Idx++, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore =
+            DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
+                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
-        BasePtr = DAG.getObjectPtrOffset(dl, BasePtr, Increment);
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
       // Restore index back to be relative to the original widen element type.
-      Idx = Idx * NewVTWidth / ValEltWidth;
+      Idx = Idx * NewVTWidth.getFixedSize() / ValEltWidth;
     }
   }
 }
@@ -5183,8 +5193,16 @@
   // It must be true that the wide vector type is bigger than where we need to
   // store.
   assert(StVT.isVector() && ValOp.getValueType().isVector());
+  assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector() &&
+         "Mismatch between store and value types");
   assert(StVT.bitsLT(ValOp.getValueType()));
 
+  // The element-by-element unrolling below is not implemented for scalable
+  // vectors yet.
+  if (StVT.isScalableVector())
+    report_fatal_error("Generating widen scalable vector truncating stores not "
+                       "yet supported");
+
   // For truncating stores, we can not play the tricks of chopping legal vector
   // types and bitcast it to the right type. Instead, we unroll the store.
   EVT StEltVT = StVT.getVectorElementType();
Index: llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
+++ llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
@@ -437,6 +437,71 @@
 }
 
 
+; Stores (tuples)
+; Each store of an SVE tuple type is split into one st1 per part vector.
+
+define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
+; CHECK-LABEL: store_i64_tuple3
+; CHECK:      st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
+  store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
+  ret void
+}
+
+define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
+; CHECK-LABEL: store_i64_tuple4
+; CHECK:      st1d { z3.d }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
+  store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
+  ret void
+}
+
+define void @store_i16_tuple2(<vscale x 16 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2) {
+; CHECK-LABEL: store_i16_tuple2
+; CHECK:      st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2)
+  store <vscale x 16 x i16> %tuple, <vscale x 16 x i16>* %out
+  ret void
+}
+
+define void @store_i16_tuple3(<vscale x 24 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3) {
+; CHECK-LABEL: store_i16_tuple3
+; CHECK:      st1h { z2.h }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
+  store <vscale x 24 x i16> %tuple, <vscale x 24 x i16>* %out
+  ret void
+}
+
+define void @store_f32_tuple3(<vscale x 12 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3) {
+; CHECK-LABEL: store_f32_tuple3
+; CHECK:      st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3)
+  store <vscale x 12 x float> %tuple, <vscale x 12 x float>* %out
+  ret void
+}
+
+define void @store_f32_tuple4(<vscale x 16 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4) {
+; CHECK-LABEL: store_f32_tuple4
+; CHECK:      st1w { z3.s }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4)
+  store <vscale x 16 x float> %tuple, <vscale x 16 x float>* %out
+  ret void
+}
+
+
 declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
 declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
@@ -473,5 +538,14 @@
 declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
 declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
 
+declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+
+declare <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
 ; +bf16 is required for the bfloat version.
 attributes #0 = { "target-features"="+sve,+bf16" }