diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -790,6 +790,25 @@
                  LLVMPointerToElt<0>],
                 [IntrArgMemOnly, NoCapture<2>]>;
 
+  class AdvSIMD_2Vec_PredStore_Intrinsic
+    : Intrinsic<[],
+                [llvm_anyvector_ty, LLVMMatchType<0>,
+                 LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMPointerTo<0>],
+                [IntrArgMemOnly, NoCapture<3>]>;
+
+  class AdvSIMD_3Vec_PredStore_Intrinsic
+    : Intrinsic<[],
+                [llvm_anyvector_ty, LLVMMatchType<0>, LLVMMatchType<0>,
+                 LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMPointerTo<0>],
+                [IntrArgMemOnly, NoCapture<4>]>;
+
+  class AdvSIMD_4Vec_PredStore_Intrinsic
+    : Intrinsic<[],
+                [llvm_anyvector_ty, LLVMMatchType<0>, LLVMMatchType<0>,
+                 LLVMMatchType<0>,
+                 LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMPointerTo<0>],
+                [IntrArgMemOnly, NoCapture<5>]>;
+
   class AdvSIMD_SVE_Index_Intrinsic
     : Intrinsic<[llvm_anyvector_ty],
                 [LLVMVectorElementType<0>,
@@ -1292,7 +1311,10 @@
 // Stores
 //
 
-def int_aarch64_sve_st1 : AdvSIMD_1Vec_PredStore_Intrinsic;
+def int_aarch64_sve_st1  : AdvSIMD_1Vec_PredStore_Intrinsic;
+def int_aarch64_sve_st2  : AdvSIMD_2Vec_PredStore_Intrinsic;
+def int_aarch64_sve_st3  : AdvSIMD_3Vec_PredStore_Intrinsic;
+def int_aarch64_sve_st4  : AdvSIMD_4Vec_PredStore_Intrinsic;
 
 def int_aarch64_sve_stnt1 : AdvSIMD_1Vec_PredStore_Intrinsic;
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -223,6 +223,9 @@
   /// unchanged; otherwise a REG_SEQUENCE value is returned.
   SDValue createDTuple(ArrayRef<SDValue> Vecs);
   SDValue createQTuple(ArrayRef<SDValue> Vecs);
+  /// Form a sequence of SVE registers for instructions that take a vector
+  /// list, e.g. structured loads and stores (ldN, stN).
+  SDValue createZTuple(ArrayRef<SDValue> Vecs);
 
   /// Generic helper for the createDTuple/createQTuple
   /// functions. Those should almost always be called instead.
@@ -258,6 +261,7 @@
   void SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc);
   void SelectStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc);
   void SelectPostStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc);
+  void SelectPredicatedStore(SDNode *N, unsigned NumVecs, unsigned Opc);
 
   bool tryBitfieldExtractOp(SDNode *N);
   bool tryBitfieldExtractOpFromSExt(SDNode *N);
@@ -1192,6 +1196,16 @@
   return createTuple(Regs, RegClassIDs, SubRegs);
 }
 
+SDValue AArch64DAGToDAGISel::createZTuple(ArrayRef<SDValue> Regs) {
+  static const unsigned RegClassIDs[] = {AArch64::ZPR2RegClassID,
+                                         AArch64::ZPR3RegClassID,
+                                         AArch64::ZPR4RegClassID};
+  static const unsigned SubRegs[] = {AArch64::zsub0, AArch64::zsub1,
+                                     AArch64::zsub2, AArch64::zsub3};
+
+  return createTuple(Regs, RegClassIDs, SubRegs);
+}
+
 SDValue AArch64DAGToDAGISel::createTuple(ArrayRef<SDValue> Regs,
                                          const unsigned RegClassIDs[],
                                          const unsigned SubRegs[]) {
@@ -1414,6 +1428,25 @@
   ReplaceNode(N, St);
 }
 
+void AArch64DAGToDAGISel::SelectPredicatedStore(SDNode *N, unsigned NumVecs,
+                                                unsigned Opc) {
+  // N is an INTRINSIC_VOID node with operands: chain, intrinsic ID,
+  // NumVecs data vectors, governing predicate, base address.
+  SDLoc dl(N);
+
+  // Form a REG_SEQUENCE to force register allocation.
+  SmallVector<SDValue, 4> Regs(N->op_begin() + 2, N->op_begin() + 2 + NumVecs);
+  SDValue RegSeq = createZTuple(Regs);
+
+  SDValue Ops[] = {RegSeq, N->getOperand(NumVecs + 2), // predicate
+                   N->getOperand(NumVecs + 3),         // address
+                   CurDAG->getTargetConstant(0, dl, MVT::i64), // offset
+                   N->getOperand(0)};                  // chain
+  SDNode *St = CurDAG->getMachineNode(Opc, dl, N->getValueType(0), Ops);
+
+  ReplaceNode(N, St);
+}
+
 bool AArch64DAGToDAGISel::SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base,
                                                       SDValue &OffImm) {
   SDLoc dl(N);
@@ -3877,6 +3910,54 @@
       }
       break;
     }
+    case Intrinsic::aarch64_sve_st2: {
+      if (VT == MVT::nxv16i8) {
+        SelectPredicatedStore(Node, 2, AArch64::ST2B_IMM);
+        return;
+      } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) {
+        SelectPredicatedStore(Node, 2, AArch64::ST2H_IMM);
+        return;
+      } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
+        SelectPredicatedStore(Node, 2, AArch64::ST2W_IMM);
+        return;
+      } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
+        SelectPredicatedStore(Node, 2, AArch64::ST2D_IMM);
+        return;
+      }
+      break;
+    }
+    case Intrinsic::aarch64_sve_st3: {
+      if (VT == MVT::nxv16i8) {
+        SelectPredicatedStore(Node, 3, AArch64::ST3B_IMM);
+        return;
+      } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) {
+        SelectPredicatedStore(Node, 3, AArch64::ST3H_IMM);
+        return;
+      } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
+        SelectPredicatedStore(Node, 3, AArch64::ST3W_IMM);
+        return;
+      } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
+        SelectPredicatedStore(Node, 3, AArch64::ST3D_IMM);
+        return;
+      }
+      break;
+    }
+    case Intrinsic::aarch64_sve_st4: {
+      if (VT == MVT::nxv16i8) {
+        SelectPredicatedStore(Node, 4, AArch64::ST4B_IMM);
+        return;
+      } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) {
+        SelectPredicatedStore(Node, 4, AArch64::ST4H_IMM);
+        return;
+      } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
+        SelectPredicatedStore(Node, 4, AArch64::ST4W_IMM);
+        return;
+      } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
+        SelectPredicatedStore(Node, 4, AArch64::ST4D_IMM);
+        return;
+      }
+      break;
+    }
     }
     break;
   }
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
@@ -1,6 +1,306 @@
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
 
 ;
+; ST2B
+;
+
+define void @st2b_i8(<vscale x 16 x i8> %v0, <vscale x 16 x i8> %v1, <vscale x 16 x i1> %pred, <vscale x 16 x i8>* %addr) {
+; CHECK-LABEL: st2b_i8:
+; CHECK: st2b { z0.b, z1.b }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8> %v0,
+                                          <vscale x 16 x i8> %v1,
+                                          <vscale x 16 x i1> %pred,
+                                          <vscale x 16 x i8>* %addr)
+  ret void
+}
+
+;
+; ST2H
+;
+
+define void @st2h_i16(<vscale x 8 x i16> %v0, <vscale x 8 x i16> %v1, <vscale x 8 x i1> %pred, <vscale x 8 x i16>* %addr) {
+; CHECK-LABEL: st2h_i16:
+; CHECK: st2h { z0.h, z1.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16> %v0,
+                                         <vscale x 8 x i16> %v1,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x i16>* %addr)
+  ret void
+}
+
+define void @st2h_f16(<vscale x 8 x half> %v0, <vscale x 8 x half> %v1, <vscale x 8 x i1> %pred, <vscale x 8 x half>* %addr) {
+; CHECK-LABEL: st2h_f16:
+; CHECK: st2h { z0.h, z1.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv8f16(<vscale x 8 x half> %v0,
+                                         <vscale x 8 x half> %v1,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x half>* %addr)
+  ret void
+}
+
+;
+; ST2W
+;
+
+define void @st2w_i32(<vscale x 4 x i32> %v0, <vscale x 4 x i32> %v1, <vscale x 4 x i1> %pred, <vscale x 4 x i32>* %addr) {
+; CHECK-LABEL: st2w_i32:
+; CHECK: st2w { z0.s, z1.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32> %v0,
+                                         <vscale x 4 x i32> %v1,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x i32>* %addr)
+  ret void
+}
+
+define void @st2w_f32(<vscale x 4 x float> %v0, <vscale x 4 x float> %v1, <vscale x 4 x i1> %pred, <vscale x 4 x float>* %addr) {
+; CHECK-LABEL: st2w_f32:
+; CHECK: st2w { z0.s, z1.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv4f32(<vscale x 4 x float> %v0,
+                                         <vscale x 4 x float> %v1,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x float>* %addr)
+  ret void
+}
+
+;
+; ST2D
+;
+
+define void @st2d_i64(<vscale x 2 x i64> %v0, <vscale x 2 x i64> %v1, <vscale x 2 x i1> %pred, <vscale x 2 x i64>* %addr) {
+; CHECK-LABEL: st2d_i64:
+; CHECK: st2d { z0.d, z1.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv2i64(<vscale x 2 x i64> %v0,
+                                         <vscale x 2 x i64> %v1,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x i64>* %addr)
+  ret void
+}
+
+define void @st2d_f64(<vscale x 2 x double> %v0, <vscale x 2 x double> %v1, <vscale x 2 x i1> %pred, <vscale x 2 x double>* %addr) {
+; CHECK-LABEL: st2d_f64:
+; CHECK: st2d { z0.d, z1.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st2.nxv2f64(<vscale x 2 x double> %v0,
+                                         <vscale x 2 x double> %v1,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x double>* %addr)
+  ret void
+}
+
+;
+; ST3B
+;
+
+define void @st3b_i8(<vscale x 16 x i8> %v0, <vscale x 16 x i8> %v1, <vscale x 16 x i8> %v2, <vscale x 16 x i1> %pred, <vscale x 16 x i8>* %addr) {
+; CHECK-LABEL: st3b_i8:
+; CHECK: st3b { z0.b, z1.b, z2.b }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv16i8(<vscale x 16 x i8> %v0,
+                                          <vscale x 16 x i8> %v1,
+                                          <vscale x 16 x i8> %v2,
+                                          <vscale x 16 x i1> %pred,
+                                          <vscale x 16 x i8>* %addr)
+  ret void
+}
+
+;
+; ST3H
+;
+
+define void @st3h_i16(<vscale x 8 x i16> %v0, <vscale x 8 x i16> %v1, <vscale x 8 x i16> %v2, <vscale x 8 x i1> %pred, <vscale x 8 x i16>* %addr) {
+; CHECK-LABEL: st3h_i16:
+; CHECK: st3h { z0.h, z1.h, z2.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv8i16(<vscale x 8 x i16> %v0,
+                                         <vscale x 8 x i16> %v1,
+                                         <vscale x 8 x i16> %v2,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x i16>* %addr)
+  ret void
+}
+
+define void @st3h_f16(<vscale x 8 x half> %v0, <vscale x 8 x half> %v1, <vscale x 8 x half> %v2, <vscale x 8 x i1> %pred, <vscale x 8 x half>* %addr) {
+; CHECK-LABEL: st3h_f16:
+; CHECK: st3h { z0.h, z1.h, z2.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv8f16(<vscale x 8 x half> %v0,
+                                         <vscale x 8 x half> %v1,
+                                         <vscale x 8 x half> %v2,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x half>* %addr)
+  ret void
+}
+
+;
+; ST3W
+;
+
+define void @st3w_i32(<vscale x 4 x i32> %v0, <vscale x 4 x i32> %v1, <vscale x 4 x i32> %v2, <vscale x 4 x i1> %pred, <vscale x 4 x i32>* %addr) {
+; CHECK-LABEL: st3w_i32:
+; CHECK: st3w { z0.s, z1.s, z2.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv4i32(<vscale x 4 x i32> %v0,
+                                         <vscale x 4 x i32> %v1,
+                                         <vscale x 4 x i32> %v2,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x i32>* %addr)
+  ret void
+}
+
+define void @st3w_f32(<vscale x 4 x float> %v0, <vscale x 4 x float> %v1, <vscale x 4 x float> %v2, <vscale x 4 x i1> %pred, <vscale x 4 x float>* %addr) {
+; CHECK-LABEL: st3w_f32:
+; CHECK: st3w { z0.s, z1.s, z2.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv4f32(<vscale x 4 x float> %v0,
+                                         <vscale x 4 x float> %v1,
+                                         <vscale x 4 x float> %v2,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x float>* %addr)
+  ret void
+}
+
+;
+; ST3D
+;
+
+define void @st3d_i64(<vscale x 2 x i64> %v0, <vscale x 2 x i64> %v1, <vscale x 2 x i64> %v2, <vscale x 2 x i1> %pred, <vscale x 2 x i64>* %addr) {
+; CHECK-LABEL: st3d_i64:
+; CHECK: st3d { z0.d, z1.d, z2.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv2i64(<vscale x 2 x i64> %v0,
+                                         <vscale x 2 x i64> %v1,
+                                         <vscale x 2 x i64> %v2,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x i64>* %addr)
+  ret void
+}
+
+define void @st3d_f64(<vscale x 2 x double> %v0, <vscale x 2 x double> %v1, <vscale x 2 x double> %v2, <vscale x 2 x i1> %pred, <vscale x 2 x double>* %addr) {
+; CHECK-LABEL: st3d_f64:
+; CHECK: st3d { z0.d, z1.d, z2.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st3.nxv2f64(<vscale x 2 x double> %v0,
+                                         <vscale x 2 x double> %v1,
+                                         <vscale x 2 x double> %v2,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x double>* %addr)
+  ret void
+}
+
+;
+; ST4B
+;
+
+define void @st4b_i8(<vscale x 16 x i8> %v0, <vscale x 16 x i8> %v1, <vscale x 16 x i8> %v2, <vscale x 16 x i8> %v3, <vscale x 16 x i1> %pred, <vscale x 16 x i8>* %addr) {
+; CHECK-LABEL: st4b_i8:
+; CHECK: st4b { z0.b, z1.b, z2.b, z3.b }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv16i8(<vscale x 16 x i8> %v0,
+                                          <vscale x 16 x i8> %v1,
+                                          <vscale x 16 x i8> %v2,
+                                          <vscale x 16 x i8> %v3,
+                                          <vscale x 16 x i1> %pred,
+                                          <vscale x 16 x i8>* %addr)
+  ret void
+}
+
+;
+; ST4H
+;
+
+define void @st4h_i16(<vscale x 8 x i16> %v0, <vscale x 8 x i16> %v1, <vscale x 8 x i16> %v2, <vscale x 8 x i16> %v3, <vscale x 8 x i1> %pred, <vscale x 8 x i16>* %addr) {
+; CHECK-LABEL: st4h_i16:
+; CHECK: st4h { z0.h, z1.h, z2.h, z3.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv8i16(<vscale x 8 x i16> %v0,
+                                         <vscale x 8 x i16> %v1,
+                                         <vscale x 8 x i16> %v2,
+                                         <vscale x 8 x i16> %v3,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x i16>* %addr)
+  ret void
+}
+
+define void @st4h_f16(<vscale x 8 x half> %v0, <vscale x 8 x half> %v1, <vscale x 8 x half> %v2, <vscale x 8 x half> %v3, <vscale x 8 x i1> %pred, <vscale x 8 x half>* %addr) {
+; CHECK-LABEL: st4h_f16:
+; CHECK: st4h { z0.h, z1.h, z2.h, z3.h }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv8f16(<vscale x 8 x half> %v0,
+                                         <vscale x 8 x half> %v1,
+                                         <vscale x 8 x half> %v2,
+                                         <vscale x 8 x half> %v3,
+                                         <vscale x 8 x i1> %pred,
+                                         <vscale x 8 x half>* %addr)
+  ret void
+}
+
+;
+; ST4W
+;
+
+define void @st4w_i32(<vscale x 4 x i32> %v0, <vscale x 4 x i32> %v1, <vscale x 4 x i32> %v2, <vscale x 4 x i32> %v3, <vscale x 4 x i1> %pred, <vscale x 4 x i32>* %addr) {
+; CHECK-LABEL: st4w_i32:
+; CHECK: st4w { z0.s, z1.s, z2.s, z3.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv4i32(<vscale x 4 x i32> %v0,
+                                         <vscale x 4 x i32> %v1,
+                                         <vscale x 4 x i32> %v2,
+                                         <vscale x 4 x i32> %v3,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x i32>* %addr)
+  ret void
+}
+
+define void @st4w_f32(<vscale x 4 x float> %v0, <vscale x 4 x float> %v1, <vscale x 4 x float> %v2, <vscale x 4 x float> %v3, <vscale x 4 x i1> %pred, <vscale x 4 x float>* %addr) {
+; CHECK-LABEL: st4w_f32:
+; CHECK: st4w { z0.s, z1.s, z2.s, z3.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv4f32(<vscale x 4 x float> %v0,
+                                         <vscale x 4 x float> %v1,
+                                         <vscale x 4 x float> %v2,
+                                         <vscale x 4 x float> %v3,
+                                         <vscale x 4 x i1> %pred,
+                                         <vscale x 4 x float>* %addr)
+  ret void
+}
+
+;
+; ST4D
+;
+
+define void @st4d_i64(<vscale x 2 x i64> %v0, <vscale x 2 x i64> %v1, <vscale x 2 x i64> %v2, <vscale x 2 x i64> %v3, <vscale x 2 x i1> %pred, <vscale x 2 x i64>* %addr) {
+; CHECK-LABEL: st4d_i64:
+; CHECK: st4d { z0.d, z1.d, z2.d, z3.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv2i64(<vscale x 2 x i64> %v0,
+                                         <vscale x 2 x i64> %v1,
+                                         <vscale x 2 x i64> %v2,
+                                         <vscale x 2 x i64> %v3,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x i64>* %addr)
+  ret void
+}
+
+define void @st4d_f64(<vscale x 2 x double> %v0, <vscale x 2 x double> %v1, <vscale x 2 x double> %v2, <vscale x 2 x double> %v3, <vscale x 2 x i1> %pred, <vscale x 2 x double>* %addr) {
+; CHECK-LABEL: st4d_f64:
+; CHECK: st4d { z0.d, z1.d, z2.d, z3.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.aarch64.sve.st4.nxv2f64(<vscale x 2 x double> %v0,
+                                         <vscale x 2 x double> %v1,
+                                         <vscale x 2 x double> %v2,
+                                         <vscale x 2 x double> %v3,
+                                         <vscale x 2 x i1> %pred,
+                                         <vscale x 2 x double>* %addr)
+  ret void
+}
+
+;
 ; STNT1B
 ;
 
@@ -86,6 +386,31 @@
   ret void
 }
 
+
+declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, <vscale x 16 x i8>*)
+declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, <vscale x 8 x i16>*)
+declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, <vscale x 4 x i32>*)
+declare void @llvm.aarch64.sve.st2.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, <vscale x 2 x i64>*)
+declare void @llvm.aarch64.sve.st2.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>*)
+declare void @llvm.aarch64.sve.st2.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>*)
+declare void @llvm.aarch64.sve.st2.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>*)
+
+declare void @llvm.aarch64.sve.st3.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, <vscale x 16 x i8>*)
+declare void @llvm.aarch64.sve.st3.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, <vscale x 8 x i16>*)
+declare void @llvm.aarch64.sve.st3.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, <vscale x 4 x i32>*)
+declare void @llvm.aarch64.sve.st3.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, <vscale x 2 x i64>*)
+declare void @llvm.aarch64.sve.st3.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>*)
+declare void @llvm.aarch64.sve.st3.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>*)
+declare void @llvm.aarch64.sve.st3.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>*)
+
+declare void @llvm.aarch64.sve.st4.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, <vscale x 16 x i8>*)
+declare void @llvm.aarch64.sve.st4.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, <vscale x 8 x i16>*)
+declare void @llvm.aarch64.sve.st4.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, <vscale x 4 x i32>*)
+declare void @llvm.aarch64.sve.st4.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, <vscale x 2 x i64>*)
+declare void @llvm.aarch64.sve.st4.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>*)
+declare void @llvm.aarch64.sve.st4.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>*)
+declare void @llvm.aarch64.sve.st4.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>*)
+
 declare void @llvm.aarch64.sve.stnt1.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
 declare void @llvm.aarch64.sve.stnt1.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 declare void @llvm.aarch64.sve.stnt1.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i1>, i32*)