diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -109,10 +109,11 @@ VLEFF, VLEFF_MASK, // Matches the semantics of vslideup/vslidedown. The first operand is the - // pass-thru operand, the second is the source vector, and the third is the - // XLenVT index (either constant or non-constant). - VSLIDEUP, - VSLIDEDOWN, + // pass-thru operand, the second is the source vector, the third is the + // XLenVT index (either constant or non-constant), the fourth is the mask + // and the fifth the VL. + VSLIDEUP_VL, + VSLIDEDOWN_VL, // Matches the semantics of the vid.v instruction, with a mask and VL // operand. VID_VL, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -816,6 +816,30 @@ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero); } +// Gets the two common "VL" operands: an all-ones mask and the vector length. +// VL is VecVT's element count if VecVT is fixed-length, else X0 (i.e. VLMAX). +// ContainerVT is the scalable type holding VecVT; it sizes the mask. +static std::pair<SDValue, SDValue> +getDefaultVLOps(EVT VecVT, EVT ContainerVT, const SDLoc &DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(ContainerVT.isScalableVector() && "Expecting scalable container type"); + MVT XLenVT = Subtarget.getXLenVT(); + SDValue VL = VecVT.isFixedLengthVector() + ? DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT) + : DAG.getRegister(RISCV::X0, XLenVT); + MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + return {Mask, VL}; +} + +// As above, for a scalable VecVT that acts as its own container type.
+static std::pair<SDValue, SDValue> +getDefaultScalableVLOps(EVT VecVT, const SDLoc &DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(VecVT.isScalableVector() && "Expecting a scalable vector"); + return getDefaultVLOps(VecVT, VecVT, DL, DAG, Subtarget); +} + static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); @@ -824,8 +848,8 @@ MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); SDLoc DL(Op); - SDValue VL = - DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) { unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL @@ -844,8 +868,6 @@ Op.getConstantOperandVal(i) == i); if (IsVID) { - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, ContainerVT, Mask, VL); return convertFromScalableVector(VT, VID, DAG, Subtarget); } @@ -1699,13 +1721,16 @@ if (Subtarget.is64Bit() || VecVT.getVectorElementType() != MVT::i64) { if (isNullConstant(Idx)) return Op; - SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT, - DAG.getUNDEF(VecVT), Vec, Idx); + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget); + SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, + DAG.getUNDEF(VecVT), Vec, Idx, Mask, VL); SDValue InsertElt0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, Slidedown, Val, DAG.getConstant(0, DL, Subtarget.getXLenVT())); - return DAG.getNode(RISCVISD::VSLIDEUP, DL, VecVT, Vec, InsertElt0, Idx); + return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VecVT, Vec, InsertElt0, Idx, + Mask, VL); } // Custom-legalize INSERT_VECTOR_ELT where XLEN; } - -def SDTRVVSlide : SDTypeProfile<1, 3, [ - SDTCisVec<0>, SDTCisSameAs<1,
0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT> -]>; - -def riscv_slideup : SDNode<"RISCVISD::VSLIDEUP", SDTRVVSlide, []>; -def riscv_slidedown : SDNode<"RISCVISD::VSLIDEDOWN", SDTRVVSlide, []>; - -let Predicates = [HasStdExtV] in { - -foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in { - def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - uimm5:$rs2)), - (!cast<Instruction>("PseudoVSLIDEUP_VI_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - GPR:$rs2)), - (!cast<Instruction>("PseudoVSLIDEUP_VX_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - uimm5:$rs2)), - (!cast<Instruction>("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - GPR:$rs2)), - (!cast<Instruction>("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, - vti.AVL, vti.SEW)>; -} -} // Predicates = [HasStdExtV] diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -219,6 +219,14 @@ [SDTCisVec<0>, SDTCisVec<1>, SDTCVecEltisVT<1, i1>, SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, XLenVT>]>, []>; +def SDTRVVSlide : SDTypeProfile<1, 5, [ + SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>, + SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<0, 4>, SDTCisVT<5, XLenVT> +]>; + +def riscv_slideup_vl : SDNode<"RISCVISD::VSLIDEUP_VL", SDTRVVSlide, []>; +def riscv_slidedown_vl : SDNode<"RISCVISD::VSLIDEDOWN_VL", SDTRVVSlide, []>;
+ let Predicates = [HasStdExtV] in { foreach vti = AllIntegerVectors in @@ -226,4 +234,40 @@ (XLenVT (VLOp GPR:$vl)))), (!cast<Instruction>("PseudoVID_V_"#vti.LMul.MX) GPR:$vl, vti.SEW)>; +// Slides whose mask is true_mask select the unmasked pseudos: the VI form for +// a uimm5 index and the VX form for a GPR index. The VL operand is passed on. +foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in { + def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + uimm5:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast<Instruction>("PseudoVSLIDEUP_VI_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + GPR:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast<Instruction>("PseudoVSLIDEUP_VX_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + uimm5:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast<Instruction>("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + GPR:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast<Instruction>("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, + GPR:$vl, vti.SEW)>; +} + } // Predicates = [HasStdExtV]