Index: include/llvm/CodeGen/ISDOpcodes.h
===================================================================
--- include/llvm/CodeGen/ISDOpcodes.h
+++ include/llvm/CodeGen/ISDOpcodes.h
@@ -898,6 +898,12 @@
   /// known nonzero constant. The only operand here is the chain.
   GET_DYNAMIC_AREA_OFFSET,
 
+  /// VSCALE(IMM) - Returns the runtime scaling factor used to calculate the
+  /// number of elements within a scalable vector. IMM is a constant integer
+  /// multiplier that is applied to the runtime value and is usually some
+  /// multiple of MVT.getVectorNumElements().
+  VSCALE,
+
   /// Generic reduction nodes. These nodes represent horizontal vector
   /// reduction operations, producing a scalar result.
   /// The STRICT variants perform reductions in sequential order. The first
Index: include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- include/llvm/CodeGen/SelectionDAG.h
+++ include/llvm/CodeGen/SelectionDAG.h
@@ -875,6 +875,11 @@
     return getNode(ISD::UNDEF, SDLoc(), VT);
   }
 
+  /// Return a node that represents the runtime scaling 'MulImm * RuntimeVL'.
+  SDValue getVScale(const SDLoc &DL, EVT VT, int64_t MulImm = 1) {
+    return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
+  }
+
   /// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc.
   SDValue getGLOBAL_OFFSET_TABLE(EVT VT) {
     return getNode(ISD::GLOBAL_OFFSET_TABLE, SDLoc(), VT);
Index: include/llvm/Target/TargetSelectionDAG.td
===================================================================
--- include/llvm/Target/TargetSelectionDAG.td
+++ include/llvm/Target/TargetSelectionDAG.td
@@ -313,6 +313,7 @@
 def bb         : SDNode<"ISD::BasicBlock", SDTOther   , [], "BasicBlockSDNode">;
 def cond       : SDNode<"ISD::CONDCODE"  , SDTOther   , [], "CondCodeSDNode">;
 def undef      : SDNode<"ISD::UNDEF"     , SDTUNDEF   , []>;
+def vscale     : SDNode<"ISD::VSCALE"    , SDTIntUnaryOp, []>;
 def globaladdr : SDNode<"ISD::GlobalAddress",         SDTPtrLeaf, [],
                         "GlobalAddressSDNode">;
 def tglobaladdr : SDNode<"ISD::TargetGlobalAddress",  SDTPtrLeaf, [],
Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -89,6 +89,7 @@
   case ISD::TRUNCATE:    Res = PromoteIntRes_TRUNCATE(N); break;
   case ISD::UNDEF:       Res = PromoteIntRes_UNDEF(N); break;
   case ISD::VAARG:       Res = PromoteIntRes_VAARG(N); break;
+  case ISD::VSCALE:      Res = PromoteIntRes_VSCALE(N); break;
 
   case ISD::EXTRACT_SUBVECTOR:
                          Res = PromoteIntRes_EXTRACT_SUBVECTOR(N); break;
@@ -1049,6 +1050,13 @@
                                                N->getValueType(0)));
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_VSCALE(SDNode *N) {
+  EVT VT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+
+  int64_t MulImm = cast<ConstantSDNode>(N->getOperand(0))->getSExtValue();
+  return DAG.getVScale(SDLoc(N), VT, MulImm);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_VAARG(SDNode *N) {
   SDValue Chain = N->getOperand(0); // Get the chain.
   SDValue Ptr = N->getOperand(1);   // Get the pointer.
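With the getVScale helper above, the runtime element count of a scalable type
such as <vscale x 4 x i32> can be built as a single node. A minimal usage
sketch, assuming only that a SelectionDAG &DAG and an SDLoc DL are in scope
(this snippet is illustrative and not part of the patch):

  // <vscale x 4 x i32> holds vscale * 4 elements, i.e. VSCALE(4).
  SDValue EltCount = DAG.getVScale(DL, MVT::i64, 4);
  // With the default multiplier of 1, this is the bare runtime value.
  SDValue VScale = DAG.getVScale(DL, MVT::i64);

PromoteIntRes_VSCALE above relies on the same helper: because VSCALE carries
its multiplier as an immediate, promotion simply rebuilds the node in the
wider type with the immediate unchanged.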
Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -342,6 +342,7 @@
   SDValue PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_UNDEF(SDNode *N);
   SDValue PromoteIntRes_VAARG(SDNode *N);
+  SDValue PromoteIntRes_VSCALE(SDNode *N);
   SDValue PromoteIntRes_XMULO(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_ADDSUBSAT(SDNode *N);
   SDValue PromoteIntRes_MULFIX(SDNode *N);
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5089,11 +5089,19 @@
     if (N2C && N2C->isNullValue())
       return N1;
     break;
+  case ISD::MUL:
+    assert(VT.isInteger() && "This operator does not apply to FP types!");
+    assert(N1.getValueType() == N2.getValueType() &&
+           N1.getValueType() == VT && "Binary operator types must match!");
+    if (N2C && (N1.getOpcode() == ISD::VSCALE)) {
+      int64_t MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getSExtValue();
+      return getVScale(DL, VT, MulImm * N2C->getSExtValue());
+    }
+    break;
   case ISD::UDIV:
   case ISD::UREM:
   case ISD::MULHU:
   case ISD::MULHS:
-  case ISD::MUL:
   case ISD::SDIV:
   case ISD::SREM:
   case ISD::SMIN:
@@ -5126,6 +5134,11 @@
            "Invalid FCOPYSIGN!");
     break;
   case ISD::SHL:
+    if (N2C && (N1.getOpcode() == ISD::VSCALE)) {
+      int64_t MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getSExtValue();
+      return getVScale(DL, VT, MulImm << N2C->getSExtValue());
+    }
+    LLVM_FALLTHROUGH;
   case ISD::SRA:
   case ISD::SRL:
     if (SDValue V = simplifyShift(N1, N2))
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1461,6 +1461,9 @@
   if (isa<UndefValue>(C) && !V->getType()->isAggregateType())
     return DAG.getUNDEF(VT);
 
+  if (isa<VScaleValue>(C))
+    return DAG.getVScale(getCurSDLoc(), VT);
+
   if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
     visit(CE->getOpcode(), *CE);
     SDValue N1 = NodeMap[V];
Index: lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -170,6 +170,7 @@
   case ISD::CopyToReg:                  return "CopyToReg";
   case ISD::CopyFromReg:                return "CopyFromReg";
   case ISD::UNDEF:                      return "undef";
+  case ISD::VSCALE:                     return "vscale";
   case ISD::MERGE_VALUES:               return "merge_values";
   case ISD::INLINEASM:                  return "inlineasm";
   case ISD::INLINEASM_BR:               return "inlineasm_br";
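The two getNode cases above mean that scaling a VSCALE node by a constant
never materialises a separate MUL or SHL; the constant is folded into the
VSCALE immediate instead. A minimal sketch of the effect, again assuming a
SelectionDAG &DAG and SDLoc DL in scope (illustrative only):

  SDValue VScale = DAG.getVScale(DL, MVT::i64);             // VSCALE(1)
  // Both calls below now return VSCALE(16) directly:
  SDValue ByMul = DAG.getNode(ISD::MUL, DL, MVT::i64, VScale,
                              DAG.getConstant(16, DL, MVT::i64));
  SDValue ByShl = DAG.getNode(ISD::SHL, DL, MVT::i64, VScale,
                              DAG.getConstant(4, DL, MVT::i64));

This folding is what allows the AArch64 patterns later in the patch to match
a single VSCALE(MulImm) node rather than arbitrary multiply trees.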
Index: lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,23 @@
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  // Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
+  template <signed Low, signed High, signed Scale>
+  bool SelectRDVLImm(SDValue N, SDValue &Imm) {
+    if (!isa<ConstantSDNode>(N))
+      return false;
+
+    int64_t MulImm = cast<ConstantSDNode>(N)->getSExtValue();
+    if ((MulImm % std::abs(Scale)) == 0) {
+      int64_t RDVLImm = MulImm / Scale;
+      if ((RDVLImm >= Low) && (RDVLImm <= High)) {
+        Imm = CurDAG->getTargetConstant(RDVLImm, SDLoc(N), MVT::i32);
+        return true;
+      }
+    }
+
+    return false;
+  }
+
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
Index: lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.h
+++ lib/Target/AArch64/AArch64ISelLowering.h
@@ -681,6 +681,7 @@
   SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerATOMIC_LOAD_SUB(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const;
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -798,6 +798,9 @@
     }
   }
 
+  if (Subtarget->hasSVE())
+    setOperationAction(ISD::VSCALE, MVT::i32, Custom);
+
   setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
 }
 
@@ -3045,6 +3048,8 @@
     return LowerATOMIC_LOAD_AND(Op, DAG);
   case ISD::DYNAMIC_STACKALLOC:
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
+  case ISD::VSCALE:
+    return LowerVSCALE(Op, DAG);
   }
 }
 
@@ -8143,6 +8148,16 @@
   return DAG.getMergeValues(Ops, dl);
 }
 
+SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  assert(VT != MVT::i64 && "Expected illegal VSCALE node");
+
+  SDLoc DL(Op);
+  int64_t MulImm = cast<ConstantSDNode>(Op.getOperand(0))->getSExtValue();
+  return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm), DL, VT);
+}
+
 /// getTgtMemIntrinsic - Represent NEON load and store intrinsics as
 /// MemIntrinsicNodes. The associated MachineMemOperands record the alignment
 /// specified in the intrinsic calls.
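To make the immediate selection concrete, the standalone model below
(illustrative only; the names and the main() driver are not part of the
patch) reproduces the arithmetic of SelectRDVLImm for the instantiations
defined in AArch64SVEInstrInfo.td next: a multiplier that divides exactly by
Scale and whose quotient lies in [Low, High] is accepted, anything else falls
through to the general expansion.

  #include <cassert>
  #include <cstdint>
  #include <cstdlib>

  // Mirrors SelectRDVLImm<Low, High, Scale>.
  static bool modelRDVLImm(int64_t MulImm, int Low, int High, int Scale,
                           int64_t &Imm) {
    if (MulImm % std::abs(Scale) != 0)
      return false;
    Imm = MulImm / Scale;
    return Imm >= Low && Imm <= High;
  }

  int main() {
    int64_t Imm;
    assert(modelRDVLImm(-512, -32, 31, 16, Imm) && Imm == -32); // rdvl #-32
    assert(modelRDVLImm(120, 1, 16, 8, Imm) && Imm == 15);      // cnth mul #15
    assert(!modelRDVLImm(3, -32, 31, 16, Imm)); // not a multiple: fall back
    return 0;
  }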
Index: lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -10,6 +10,18 @@
 //
 //===----------------------------------------------------------------------===//
 
+// SVE CNT/INC/RDVL
+def sve_rdvl_imm : ComplexPattern<i32, 1, "SelectRDVLImm<-32, 31, 16>">;
+def sve_cnth_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 8>">;
+def sve_cntw_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 4>">;
+def sve_cntd_imm : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, 2>">;
+
+// SVE DEC
+def sve_cnth_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -8>">;
+def sve_cntw_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -4>">;
+def sve_cntd_imm_neg : ComplexPattern<i32, 1, "SelectRDVLImm<1, 16, -2>">;
+
+
 let Predicates = [HasSVE] in {
 
   def RDFFR_PPz  : sve_int_rdffr_pred<0b0, "rdffr">;
@@ -1021,6 +1033,20 @@
   def : InstAlias<"fcmlt $Zd, $Pg/z, $Zm, $Zn",
                   (FCMGT_PPzZZ_D PPR64:$Zd, PPR3bAny:$Pg, ZPR64:$Zn, ZPR64:$Zm), 0>;
 
+  // General case that we ideally never want to match.
+  def : Pat<(vscale GPR64:$scale), (MADDXrrr (UBFMXri (RDVLI_XI 1), 4, 63), $scale, XZR)>;
+
+  let AddedComplexity = 5 in {
+    def : Pat<(vscale (sve_rdvl_imm i32:$imm)), (RDVLI_XI $imm)>;
+    def : Pat<(vscale (sve_cnth_imm i32:$imm)), (CNTH_XPiI 31, $imm)>;
+    def : Pat<(vscale (sve_cntw_imm i32:$imm)), (CNTW_XPiI 31, $imm)>;
+    def : Pat<(vscale (sve_cntd_imm i32:$imm)), (CNTD_XPiI 31, $imm)>;
+
+    def : Pat<(vscale (sve_cnth_imm_neg i32:$imm)), (SUBXrs XZR, (CNTH_XPiI 31, $imm), 0)>;
+    def : Pat<(vscale (sve_cntw_imm_neg i32:$imm)), (SUBXrs XZR, (CNTW_XPiI 31, $imm), 0)>;
+    def : Pat<(vscale (sve_cntd_imm_neg i32:$imm)), (SUBXrs XZR, (CNTD_XPiI 31, $imm), 0)>;
+  }
+
   def : Pat<(nxv16i8 (bitconvert (nxv8i16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
   def : Pat<(nxv16i8 (bitconvert (nxv4i32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
   def : Pat<(nxv16i8 (bitconvert (nxv2i64 ZPR:$src))), (nxv16i8 ZPR:$src)>;
Index: test/CodeGen/AArch64/sve-vscale.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/sve-vscale.ll
@@ -0,0 +1,128 @@
+; RUN: llc -mtriple aarch64 -mattr=+sve -asm-verbose=0 < %s | FileCheck %s
+
+;
+; RDVL
+;
+
+; CHECK-LABEL: rdvl:
+; CHECK:      rdvl x0, #1
+; CHECK-NEXT: ret
+define i32 @rdvl() nounwind {
+  %1 = mul i32 vscale, 16
+  ret i32 %1
+}
+
+; CHECK-LABEL: rdvl_3:
+; CHECK:      rdvl [[VL_B:x[0-9]+]], #1
+; CHECK-NEXT: lsr  [[VL_Q:x[0-9]+]], [[VL_B]], #4
+; CHECK-NEXT: mov  w[[MUL:[0-9]+]], #3
+; CHECK-NEXT: mul  x0, [[VL_Q]], x[[MUL]]
+; CHECK-NEXT: ret
+define i32 @rdvl_3() nounwind {
+  %1 = mul i32 vscale, 3
+  ret i32 %1
+}
+
+; CHECK-LABEL: rdvl_min:
+; CHECK:      rdvl x0, #-32
+; CHECK-NEXT: ret
+define i32 @rdvl_min() nounwind {
+  %1 = mul i32 vscale, -512
+  ret i32 %1
+}
+
+; CHECK-LABEL: rdvl_max:
+; CHECK:      rdvl x0, #31
+; CHECK-NEXT: ret
+define i32 @rdvl_max() nounwind {
+  %1 = mul i32 vscale, 496
+  ret i32 %1
+}
+
+;
+; CNTH
+;
+
+; CHECK-LABEL: cnth:
+; CHECK:      cnth x0{{$}}
+; CHECK-NEXT: ret
+define i32 @cnth() nounwind {
+  %1 = mul i32 vscale, 8
+  ret i32 %1
+}
+
+; CHECK-LABEL: cnth_max:
+; CHECK:      cnth x0, all, mul #15
+; CHECK-NEXT: ret
+define i32 @cnth_max() nounwind {
+  %1 = mul i32 vscale, 120
+  ret i32 %1
+}
+
+; CHECK-LABEL: cnth_neg:
+; CHECK:      cnth [[CNT:x[0-9]+]]
+; CHECK:      neg  x0, [[CNT]]
+; CHECK-NEXT: ret
+define i32 @cnth_neg() nounwind {
+  %1 = mul i32 vscale, -8
+  ret i32 %1
+}
+
+;
+; CNTW
+;
+
+; CHECK-LABEL: cntw:
+; CHECK:      cntw x0{{$}}
+; CHECK-NEXT: ret
+define i32 @cntw() nounwind {
+  %1 = mul i32 vscale, 4
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntw_max:
+; CHECK:      cntw x0, all, mul #15
+; CHECK-NEXT: ret
define i32 @cntw_max() nounwind {
+  %1 = mul i32 vscale, 60
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntw_neg:
+; CHECK:      cntw [[CNT:x[0-9]+]]
+; CHECK:      neg  x0, [[CNT]]
+; CHECK-NEXT: ret
+define i32 @cntw_neg() nounwind {
+  %1 = mul i32 vscale, -4
+  ret i32 %1
+}
+
+;
+; CNTD
+;
+
+; CHECK-LABEL: cntd:
+; CHECK:      cntd x0{{$}}
+; CHECK-NEXT: ret
+define i32 @cntd() nounwind {
+  %1 = mul i32 vscale, 2
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntd_max:
+; CHECK:      cntd x0, all, mul #15
+; CHECK-NEXT: ret
+define i32 @cntd_max() nounwind {
+  %1 = mul i32 vscale, 30
+  ret i32 %1
+}
+
+; CHECK-LABEL: cntd_neg:
+; CHECK:      cntd [[CNT:x[0-9]+]]
+; CHECK:      neg  x0, [[CNT]]
+; CHECK-NEXT: ret
+define i32 @cntd_neg() nounwind {
+  %1 = mul i32 vscale, -2
+  ret i32 %1
+}
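The rdvl_3 test is the one case above that exercises the general-case MADD
pattern rather than a dedicated counting instruction: rdvl x, #1 yields
vscale * 16 (the vector length in bytes), the ubfm/lsr #4 recovers vscale,
and the multiply applies the remaining factor. A standalone sanity check of
that arithmetic, assuming for illustration a 512-bit SVE implementation where
vscale == 4 (this snippet is not part of the patch):

  #include <cassert>
  #include <cstdint>

  int main() {
    const int64_t VScale = 4;               // assumption: 512-bit vectors
    const int64_t VLBytes = VScale * 16;    // rdvl x, #1 -> 64
    const int64_t Recovered = VLBytes >> 4; // ubfm/lsr #4 -> vscale
    assert(Recovered * 3 == 12);            // mul -> vscale * 3, as in rdvl_3
    return 0;
  }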