diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -109,10 +109,11 @@ VLEFF, VLEFF_MASK, // Matches the semantics of vslideup/vslidedown. The first operand is the - // pass-thru operand, the second is the source vector, and the third is the - // XLenVT index (either constant or non-constant). - VSLIDEUP, - VSLIDEDOWN, + // pass-thru operand, the second is the source vector, the third is the + // XLenVT index (either constant or non-constant), the fourth is the mask + // and the fifth is the VL. + VSLIDEUP_VL, + VSLIDEDOWN_VL, // Matches the semantics of the vid.v instruction, with a mask and VL // operand. VID_VL, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -832,6 +832,30 @@ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero); } +// Gets the two common "VL" operands: an all-ones mask and the vector length. +// VecVT is a vector type, either fixed-length or scalable, and ContainerVT is +// the vector type that it is contained in; VL is X0 when VecVT is scalable. +static std::pair +getDefaultVLOps(MVT VecVT, MVT ContainerVT, SDLoc DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(ContainerVT.isScalableVector() && "Expecting scalable container type"); + MVT XLenVT = Subtarget.getXLenVT(); + SDValue VL = VecVT.isFixedLengthVector() + ? DAG.getConstant(VecVT.getVectorNumElements(), DL, XLenVT) + : DAG.getRegister(RISCV::X0, XLenVT); + MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + return {Mask, VL}; +} + +// As above, for a scalable VecVT that serves as its own container type.
+static std::pair +getDefaultScalableVLOps(MVT VecVT, SDLoc DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(VecVT.isScalableVector() && "Expecting a scalable vector"); + return getDefaultVLOps(VecVT, VecVT, DL, DAG, Subtarget); +} + static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { MVT VT = Op.getSimpleValueType(); @@ -840,8 +864,8 @@ MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); SDLoc DL(Op); - SDValue VL = - DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT()); + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); if (VT.getVectorElementType() == MVT::i1) { if (ISD::isBuildVectorAllZeros(Op.getNode())) { @@ -874,8 +898,6 @@ Op.getConstantOperandVal(i) == i); if (IsVID) { - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); - SDValue Mask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, ContainerVT, Mask, VL); return convertFromScalableVector(VT, VID, DAG, Subtarget); } @@ -1757,7 +1779,7 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); - EVT VecVT = Op.getValueType(); + MVT VecVT = Op.getSimpleValueType(); SDValue Vec = Op.getOperand(0); SDValue Val = Op.getOperand(1); SDValue Idx = Op.getOperand(2); @@ -1769,13 +1791,16 @@ if (Subtarget.is64Bit() || VecVT.getVectorElementType() != MVT::i64) { if (isNullConstant(Idx)) return Op; - SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT, - DAG.getUNDEF(VecVT), Vec, Idx); + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget); + SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, + DAG.getUNDEF(VecVT), Vec, Idx, Mask, VL); SDValue InsertElt0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, Slidedown, Val, DAG.getConstant(0, DL, Subtarget.getXLenVT())); - return
DAG.getNode(RISCVISD::VSLIDEUP, DL, VecVT, Vec, InsertElt0, Idx); + return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VecVT, Vec, InsertElt0, Idx, + Mask, VL); } // Custom-legalize INSERT_VECTOR_ELT where XLENgetOperand(0); SDValue Idx = N->getOperand(1); - EVT VecVT = Vec.getValueType(); + MVT VecVT = Vec.getSimpleValueType(); assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 && VecVT.getVectorElementType() == MVT::i64 && "Unexpected EXTRACT_VECTOR_ELT legalization"); SDValue Slidedown = Vec; + MVT XLenVT = Subtarget.getXLenVT(); // Unless the index is known to be 0, we must slide the vector down to get // the desired element into index 0. - if (!isNullConstant(Idx)) - Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT, - DAG.getUNDEF(VecVT), Vec, Idx); + if (!isNullConstant(Idx)) { + SDValue Mask, VL; + std::tie(Mask, VL) = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget); + Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, VecVT, + DAG.getUNDEF(VecVT), Vec, Idx, Mask, VL); + } - MVT XLenVT = Subtarget.getXLenVT(); // Extract the lower XLEN bits of the correct vector element.
SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Slidedown, Idx); @@ -4677,8 +4704,8 @@ NODE_NAME_CASE(TRUNCATE_VECTOR) NODE_NAME_CASE(VLEFF) NODE_NAME_CASE(VLEFF_MASK) - NODE_NAME_CASE(VSLIDEUP) - NODE_NAME_CASE(VSLIDEDOWN) + NODE_NAME_CASE(VSLIDEUP_VL) + NODE_NAME_CASE(VSLIDEDOWN_VL) NODE_NAME_CASE(VID_VL) NODE_NAME_CASE(VFNCVT_ROD) NODE_NAME_CASE(VECREDUCE_ADD) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -823,43 +823,3 @@ (vti.Scalar vti.ScalarRegClass:$rs1), vti.AVL, vti.SEW)>; } - -def SDTRVVSlide : SDTypeProfile<1, 3, [ - SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT> -]>; - -def riscv_slideup : SDNode<"RISCVISD::VSLIDEUP", SDTRVVSlide, []>; -def riscv_slidedown : SDNode<"RISCVISD::VSLIDEDOWN", SDTRVVSlide, []>; - -let Predicates = [HasStdExtV] in { - -foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in { - def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - uimm5:$rs2)), - (!cast("PseudoVSLIDEUP_VI_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - GPR:$rs2)), - (!cast("PseudoVSLIDEUP_VX_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - uimm5:$rs2)), - (!cast("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, - vti.AVL, vti.SEW)>; - - def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3), - (vti.Vector vti.RegClass:$rs1), - GPR:$rs2)), - (!cast("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX) - vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
- vti.AVL, vti.SEW)>; -} -} // Predicates = [HasStdExtV] diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -343,6 +343,14 @@ [SDTCisVec<0>, SDTCisVec<1>, SDTCVecEltisVT<1, i1>, SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, XLenVT>]>, []>; +def SDTRVVSlide : SDTypeProfile<1, 5, [ // passthru, src, index, mask, VL + SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>, + SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<0, 4>, SDTCisVT<5, XLenVT> +]>; + +def riscv_slideup_vl : SDNode<"RISCVISD::VSLIDEUP_VL", SDTRVVSlide, []>; +def riscv_slidedown_vl : SDNode<"RISCVISD::VSLIDEDOWN_VL", SDTRVVSlide, []>; + let Predicates = [HasStdExtV] in { foreach vti = AllIntegerVectors in @@ -350,4 +358,38 @@ (XLenVT (VLOp GPR:$vl)))), (!cast("PseudoVID_V_"#vti.LMul.MX) GPR:$vl, vti.SEW)>; +foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in { + def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + uimm5:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast("PseudoVSLIDEUP_VI_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slideup_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + GPR:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast("PseudoVSLIDEUP_VX_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1), + uimm5:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2, + GPR:$vl, vti.SEW)>; + + def : Pat<(vti.Vector (riscv_slidedown_vl (vti.Vector vti.RegClass:$rs3), + (vti.Vector vti.RegClass:$rs1),
+ GPR:$rs2, (vti.Mask true_mask), + (XLenVT (VLOp GPR:$vl)))), + (!cast("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX) + vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2, + GPR:$vl, vti.SEW)>; +} + } // Predicates = [HasStdExtV]