Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -983,11 +983,14 @@
   unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
 
   if (MemVT.isScalableVector()) {
+    SDNodeFlags Flags;
     SDValue BytesIncrement = DAG.getVScale(
        DL, Ptr.getValueType(),
        APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
     MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
-    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
+    Flags.setNoUnsignedWrap(true);
+    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
+                      Flags);
   } else {
     MPI = N->getPointerInfo().getWithOffset(IncrementSize);
     // Increment the pointer to the other half.
@@ -4820,7 +4823,7 @@
 
   // If we have one element to load/store, return it.
   EVT RetVT = WidenEltVT;
-  if (Width == WidenEltWidth)
+  if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
   // See if there is larger legal integer than the element type to load/store.
@@ -5114,19 +5117,45 @@
   SDLoc dl(ST);
 
   EVT StVT = ST->getMemoryVT();
-  unsigned StWidth = StVT.getSizeInBits();
   EVT ValVT = ValOp.getValueType();
-  unsigned ValWidth = ValVT.getSizeInBits();
   EVT ValEltVT = ValVT.getVectorElementType();
-  unsigned ValEltWidth = ValEltVT.getSizeInBits();
   assert(StVT.getVectorElementType() == ValEltVT);
+  assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
+         "Mismatch between store and value types");
+
+  unsigned StWidth = StVT.getSizeInBits().getKnownMinSize();
+
+  if (StVT.isScalableVector()) {
+    // Find the largest vector type we can store with.
+    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
+    assert(NewVT.isScalableVector() &&
+           "Using fixed types to store scalable vectors");
+    TypeSize Increment = NewVT.getSizeInBits();
+    assert(StWidth % Increment.getKnownMinSize() == 0 &&
+           "Gaps in widen scalable vector stores");
+    unsigned NumStores = StWidth / Increment.getKnownMinSize();
+
+    MachinePointerInfo MPI = ST->getPointerInfo();
+    for (unsigned i = 0; i < NumStores; i++) {
+      unsigned Idx = i * NewVT.getVectorMinNumElements();
+      SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
+                                DAG.getVectorIdxConstant(Idx, dl));
+      SDValue TmpSt = DAG.getStore(Chain, dl, EOp, BasePtr, MPI,
+                                   ST->getOriginalAlign(), MMOFlags, AAInfo);
+      StChain.push_back(TmpSt);
+      IncrementPointer(cast<StoreSDNode>(TmpSt), NewVT, MPI, BasePtr);
+    }
+
+    return;
+  }
+
+  unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
+  unsigned ValWidth = ValVT.getSizeInBits().getFixedSize();
   int Idx = 0;          // current index to store
   unsigned Offset = 0;  // offset from base to store
   while (StWidth != 0) {
     // Find the largest vector type we can store with.
     EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
-    unsigned NewVTWidth = NewVT.getSizeInBits();
+    unsigned NewVTWidth = NewVT.getSizeInBits().getFixedSize();
     unsigned Increment = NewVTWidth / 8;
     if (NewVT.isVector()) {
       unsigned NumVTElts = NewVT.getVectorNumElements();
@@ -5183,8 +5212,13 @@
   // It must be true that the wide vector type is bigger than where we need to
   // store.
   assert(StVT.isVector() && ValOp.getValueType().isVector());
+  assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector());
   assert(StVT.bitsLT(ValOp.getValueType()));
 
+  if (StVT.isScalableVector())
+    report_fatal_error("Generating widen scalable vector truncating stores not "
+                       "yet supported");
+
   // For truncating stores, we can not play the tricks of chopping legal vector
   // types and bitcast it to the right type. Instead, we unroll the store.
   EVT StEltVT = StVT.getVectorElementType();
Index: llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
+++ llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
@@ -437,6 +437,38 @@
 }
 
+; Stores (tuples)
+
+define void @store_i64_tuple2(<vscale x 4 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2) {
+; CHECK-LABEL: store_i64_tuple2
+; CHECK:      st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 4 x i64> @llvm.aarch64.sve.tuple.create2.nxv4i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2)
+  store <vscale x 4 x i64> %tuple, <vscale x 4 x i64>* %out
+  ret void
+}
+
+define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
+; CHECK-LABEL: store_i64_tuple3
+; CHECK:      st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
+  store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
+  ret void
+}
+
+define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
+; CHECK-LABEL: store_i64_tuple4
+; CHECK:      st1d { z3.d }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
+  store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
+  ret void
+}
+
 declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
 declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
@@ -473,5 +505,9 @@
 declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
 declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
 
+declare <vscale x 4 x i64> @llvm.aarch64.sve.tuple.create2.nxv4i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
 ; +bf16 is required for the bfloat version.
 attributes #0 = { "target-features"="+sve,+bf16" }
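
Note on the scalable branch added to the widened-store path above: the memory type is
split into StWidth / NewVTWidth stores of the largest legal scalable type, each piece
extracted from the widened value with EXTRACT_SUBVECTOR and written through a pointer
that IncrementPointer advances by a vscale-scaled byte amount (tagged no-unsigned-wrap
by the first hunk). The sketch below is a minimal standalone C++ program, not LLVM API
code; it only illustrates the index/offset arithmetic of that loop. The constants model
the store_i64_tuple4 test and every name in it is made up for illustration.

  #include <cassert>
  #include <cstdio>

  int main() {
    // Hypothetical example: a value widened to <vscale x 8 x i64> (known-min
    // 512 bits) stored through legal <vscale x 2 x i64> pieces (known-min
    // 128 bits), mirroring the store_i64_tuple4 test above.
    const unsigned StWidthMinBits = 8 * 64; // known-min size of the stored type
    const unsigned NewVTMinBits = 2 * 64;   // known-min size of the legal store type
    const unsigned NewVTMinElts = 2;        // min element count of the legal type

    assert(StWidthMinBits % NewVTMinBits == 0 && "gaps in widened scalable store");
    const unsigned NumStores = StWidthMinBits / NewVTMinBits;

    for (unsigned i = 0; i < NumStores; ++i) {
      // Element index a piece would pass to EXTRACT_SUBVECTOR.
      unsigned SubVecIdx = i * NewVTMinElts;
      // Byte offset of the piece from the base pointer, per unit of vscale.
      unsigned BytesPerVscale = i * (NewVTMinBits / 8);
      std::printf("store %u: subvector index %u, offset %u * vscale bytes\n",
                  i, SubVecIdx, BytesPerVscale);
    }
    return 0;
  }

With these constants the sketch reports four pieces at vscale-scaled byte offsets 0, 16,
32 and 48, which lines up with the [x0], [x0, #1, mul vl], [x0, #2, mul vl] and
[x0, #3, mul vl] addresses checked in store_i64_tuple4.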