diff --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -131,6 +131,24 @@
     return { MinSize / RHS, IsScalable };
   }
 
+  // Subtract RHS from this size in place. Both sizes must agree on whether
+  // they are scalable; mixing fixed and scalable sizes asserts.
+  TypeSize &operator-=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Subtraction using mixed scalable and fixed types");
+    MinSize -= RHS.MinSize;
+    return *this;
+  }
+
+  // Add RHS to this size in place. Both sizes must agree on whether they are
+  // scalable; mixing fixed and scalable sizes asserts.
+  TypeSize &operator+=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Addition using mixed scalable and fixed types");
+    MinSize += RHS.MinSize;
+    return *this;
+  }
+
   // Return the minimum size with the assumption that the size is exact.
   // Use in places where a scalable size doesn't make sense (e.g. non-vector
   // types, or vectors in backends which don't support scalable vectors).
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -780,8 +780,8 @@
 
   // Helper function for incrementing the pointer when splitting
   // memory operations
-  void IncrementPointer(MemSDNode *N, EVT MemVT,
-                        MachinePointerInfo &MPI, SDValue &Ptr);
+  void IncrementPointer(MemSDNode *N, EVT MemVT, MachinePointerInfo &MPI,
+                        SDValue &Ptr, uint64_t *ScaledOffset = nullptr);
 
   // Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>.
   void SplitVectorResult(SDNode *N, unsigned ResNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -984,16 +984,23 @@
 }
 
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
-                                        MachinePointerInfo &MPI,
-                                        SDValue &Ptr) {
+                                        MachinePointerInfo &MPI, SDValue &Ptr,
+                                        uint64_t *ScaledOffset) {
   SDLoc DL(N);
   unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
 
   if (MemVT.isScalableVector()) {
+    SDNodeFlags Flags;
     SDValue BytesIncrement = DAG.getVScale(
         DL, Ptr.getValueType(),
         APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
     MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
+    Flags.setNoUnsignedWrap(true);
+    // Report the increment to the caller in known-minimum bytes, i.e. before
+    // it is multiplied by vscale.
+    if (ScaledOffset)
+      *ScaledOffset += IncrementSize;
-    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
+    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
+                      Flags);
   } else {
     MPI = N->getPointerInfo().getWithOffset(IncrementSize);
@@ -4844,7 +4851,7 @@
 
   // If we have one element to load/store, return it.
   EVT RetVT = WidenEltVT;
-  if (Width == WidenEltWidth)
+  if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
   // See if there is larger legal integer than the element type to load/store.
@@ -5139,55 +5143,64 @@
   SDLoc dl(ST);
 
   EVT StVT = ST->getMemoryVT();
-  unsigned StWidth = StVT.getSizeInBits();
+  TypeSize StWidth = StVT.getSizeInBits();
   EVT ValVT = ValOp.getValueType();
-  unsigned ValWidth = ValVT.getSizeInBits();
+  TypeSize ValWidth = ValVT.getSizeInBits();
   EVT ValEltVT = ValVT.getVectorElementType();
-  unsigned ValEltWidth = ValEltVT.getSizeInBits();
+  unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
   assert(StVT.getVectorElementType() == ValEltVT);
+  assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
+         "Mismatch between store and value types");
 
   int Idx = 0;          // current index to store
-  unsigned Offset = 0;  // offset from base to store
-  while (StWidth != 0) {
+
+  MachinePointerInfo MPI = ST->getPointerInfo();
+  // Accumulates the known-minimum byte offset of scalable parts (updated by
+  // IncrementPointer); used to derive the alignment of later parts.
+  uint64_t ScaledOffset = 0;
+  while (StWidth.isNonZero()) {
     // Find the largest vector type we can store with.
-    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
-    unsigned NewVTWidth = NewVT.getSizeInBits();
-    unsigned Increment = NewVTWidth / 8;
+    EVT NewVT = FindMemType(DAG, TLI, StWidth.getKnownMinSize(), ValVT);
+    TypeSize NewVTWidth = NewVT.getSizeInBits();
+
     if (NewVT.isVector()) {
-      unsigned NumVTElts = NewVT.getVectorNumElements();
+      unsigned NumVTElts = NewVT.getVectorMinNumElements();
       do {
+        Align NewAlign = ScaledOffset == 0
+                             ? ST->getOriginalAlign()
+                             : commonAlignment(ST->getAlign(), ScaledOffset);
         SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
                                   DAG.getVectorIdxConstant(Idx, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore = DAG.getStore(Chain, dl, EOp, BasePtr, MPI, NewAlign,
+                                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
         Idx += NumVTElts;
 
-        BasePtr =
-            DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr,
+                         &ScaledOffset);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
     } else {
       // Cast the vector to the scalar type we can store.
-      unsigned NumElts = ValWidth / NewVTWidth;
+      unsigned NumElts = ValWidth.getFixedSize() / NewVTWidth.getFixedSize();
       EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts);
       SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp);
       // Readjust index position based on new vector type.
-      Idx = Idx * ValEltWidth / NewVTWidth;
+      Idx = Idx * ValEltWidth / NewVTWidth.getFixedSize();
       do {
         SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp,
                                   DAG.getVectorIdxConstant(Idx++, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore =
+            DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
+                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
-        BasePtr =
-            DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
       // Restore index back to be relative to the original widen element type.
-      Idx = Idx * NewVTWidth / ValEltWidth;
+      Idx = Idx * NewVTWidth.getFixedSize() / ValEltWidth;
     }
   }
 }
@@ -5210,8 +5223,13 @@
   // It must be true that the wide vector type is bigger than where we need to
   // store.
   assert(StVT.isVector() && ValOp.getValueType().isVector());
+  assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector());
   assert(StVT.bitsLT(ValOp.getValueType()));
 
+  if (StVT.isScalableVector())
+    report_fatal_error("Generating widen scalable vector truncating stores not "
+                       "yet supported");
+
   // For truncating stores, we can not play the tricks of chopping legal vector
   // types and bitcast it to the right type. Instead, we unroll the store.
   EVT StEltVT = StVT.getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
@@ -437,6 +437,69 @@
 }
 
 
+; Stores (tuples)
+
+define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
+; CHECK-LABEL: store_i64_tuple3
+; CHECK: st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
+  store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
+  ret void
+}
+
+define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
+; CHECK-LABEL: store_i64_tuple4
+; CHECK: st1d { z3.d }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
+  store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
+  ret void
+}
+
+define void @store_i16_tuple2(<vscale x 16 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2) {
+; CHECK-LABEL: store_i16_tuple2
+; CHECK: st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2)
+  store <vscale x 16 x i16> %tuple, <vscale x 16 x i16>* %out
+  ret void
+}
+
+define void @store_i16_tuple3(<vscale x 24 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3) {
+; CHECK-LABEL: store_i16_tuple3
+; CHECK: st1h { z2.h }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
+  store <vscale x 24 x i16> %tuple, <vscale x 24 x i16>* %out
+  ret void
+}
+
+define void @store_f32_tuple3(<vscale x 12 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3) {
+; CHECK-LABEL: store_f32_tuple3
+; CHECK: st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3)
+  store <vscale x 12 x float> %tuple, <vscale x 12 x float>* %out
+  ret void
+}
+
+define void @store_f32_tuple4(<vscale x 16 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4) {
+; CHECK-LABEL: store_f32_tuple4
+; CHECK: st1w { z3.s }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4)
+  store <vscale x 16 x float> %tuple, <vscale x 16 x float>* %out
+  ret void
+}
+
 declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
 declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
@@ -473,5 +536,14 @@
 declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
 declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
 
+declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+
+declare <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
 ; +bf16 is required for the bfloat version.
 attributes #0 = { "target-features"="+sve,+bf16" }
diff --git a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
--- a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
+++ b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
@@ -133,3 +133,37 @@
   store <vscale x 4 x half> %splat, <vscale x 4 x half>* %out
   ret void
 }
+
+; Splat stores of unusual FP scalable vector types
+
+define void @store_nxv6f32(<vscale x 6 x float>* %out) {
+; CHECK-LABEL: store_nxv6f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmov z0.s, #1.00000000
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: uunpklo z0.d, z0.s
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: st1w { z0.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 6 x float> undef, float 1.0, i32 0
+  %splat = shufflevector <vscale x 6 x float> %ins, <vscale x 6 x float> undef, <vscale x 6 x i32> zeroinitializer
+  store <vscale x 6 x float> %splat, <vscale x 6 x float>* %out
+  ret void
+}
+
+define void @store_nxv12f16(<vscale x 12 x half>* %out) {
+; CHECK-LABEL: store_nxv12f16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: fmov z0.h, #1.00000000
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: st1h { z0.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: ret
+  %ins = insertelement <vscale x 12 x half> undef, half 1.0, i32 0
+  %splat = shufflevector <vscale x 12 x half> %ins, <vscale x 12 x half> undef, <vscale x 12 x i32> zeroinitializer
+  store <vscale x 12 x half> %splat, <vscale x 12 x half>* %out
+  ret void
+}