diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1954,8 +1954,8 @@ SDValue Mask, VL; std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); - SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec, - Idx, Mask, VL); + SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), Vec, Idx, Mask, VL); if (!VT.isFixedLengthVector()) return Gather; @@ -2578,9 +2578,9 @@ V1 = convertToScalableVector(ContainerVT, V1, DAG, Subtarget); assert(Lane < (int)NumElts && "Unexpected lane!"); - SDValue Gather = - DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, V1, - DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL); + SDValue Gather = DAG.getNode( + RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT), + V1, DAG.getConstant(Lane, DL, XLenVT), TrueMask, VL); return convertFromScalableVector(VT, Gather, DAG, Subtarget); } } @@ -2790,16 +2790,17 @@ // that's beneficial. if (LHSIndexCounts.size() == 1) { int SplatIndex = LHSIndexCounts.begin()->getFirst(); - Gather = - DAG.getNode(GatherVXOpc, DL, ContainerVT, V1, - DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL); + Gather = DAG.getNode( + GatherVXOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), V1, + DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL); } else { SDValue LHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesLHS); LHSIndices = convertToScalableVector(IndexContainerVT, LHSIndices, DAG, Subtarget); - Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, V1, LHSIndices, - TrueMask, VL); + Gather = + DAG.getNode(GatherVVOpc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), + V1, LHSIndices, TrueMask, VL); } } @@ -2807,27 +2808,26 @@ // additional vrgather. if (!V2.isUndef()) { V2 = convertToScalableVector(ContainerVT, V2, DAG, Subtarget); + + MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1); + SelectMask = + convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget); + // If only one index is used, we can use a "splat" vrgather. // TODO: We can splat the most-common index and fix-up any stragglers, if // that's beneficial. if (RHSIndexCounts.size() == 1) { int SplatIndex = RHSIndexCounts.begin()->getFirst(); - V2 = DAG.getNode(GatherVXOpc, DL, ContainerVT, V2, - DAG.getConstant(SplatIndex, DL, XLenVT), TrueMask, VL); + Gather = + DAG.getNode(GatherVXOpc, DL, ContainerVT, Gather, V2, + DAG.getConstant(SplatIndex, DL, XLenVT), SelectMask, VL); } else { SDValue RHSIndices = DAG.getBuildVector(IndexVT, DL, GatherIndicesRHS); RHSIndices = convertToScalableVector(IndexContainerVT, RHSIndices, DAG, Subtarget); - V2 = DAG.getNode(GatherVVOpc, DL, ContainerVT, V2, RHSIndices, TrueMask, - VL); + Gather = DAG.getNode(GatherVVOpc, DL, ContainerVT, Gather, V2, RHSIndices, + SelectMask, VL); } - - MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1); - SelectMask = - convertToScalableVector(MaskContainerVT, SelectMask, DAG, Subtarget); - - Gather = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, SelectMask, V2, - Gather, VL); } return convertFromScalableVector(VT, Gather, DAG, Subtarget); @@ -5688,7 +5688,8 @@ SDValue Indices = DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL); - return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask, VL); + return DAG.getNode(GatherOpc, DL, VecVT, DAG.getUNDEF(VecVT), + Op.getOperand(0), Indices, Mask, VL); } SDValue RISCVTargetLowering::lowerVECTOR_SPLICE(SDValue Op, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -143,30 +143,33 @@ SDTCisVT<5, XLenVT>]>>; def riscv_vrgather_vx_vl : SDNode<"RISCVISD::VRGATHER_VX_VL", - SDTypeProfile<1, 4, [SDTCisVec<0>, + SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisSameAs<0, 1>, - SDTCisVT<2, XLenVT>, - SDTCVecEltisVT<3, i1>, - SDTCisSameNumEltsAs<0, 3>, - SDTCisVT<4, XLenVT>]>>; + SDTCisSameAs<0, 2>, + SDTCisVT<3, XLenVT>, + SDTCVecEltisVT<4, i1>, + SDTCisSameNumEltsAs<0, 4>, + SDTCisVT<5, XLenVT>]>>; def riscv_vrgather_vv_vl : SDNode<"RISCVISD::VRGATHER_VV_VL", - SDTypeProfile<1, 4, [SDTCisVec<0>, + SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisSameAs<0, 1>, - SDTCisInt<2>, - SDTCisSameNumEltsAs<0, 2>, - SDTCisSameSizeAs<0, 2>, - SDTCVecEltisVT<3, i1>, + SDTCisSameAs<0, 2>, + SDTCisInt<3>, SDTCisSameNumEltsAs<0, 3>, - SDTCisVT<4, XLenVT>]>>; + SDTCisSameSizeAs<0, 3>, + SDTCVecEltisVT<4, i1>, + SDTCisSameNumEltsAs<0, 4>, + SDTCisVT<5, XLenVT>]>>; def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL", - SDTypeProfile<1, 4, [SDTCisVec<0>, + SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisSameAs<0, 1>, - SDTCisInt<2>, - SDTCVecEltisVT<2, i16>, - SDTCisSameNumEltsAs<0, 2>, - SDTCVecEltisVT<3, i1>, + SDTCisSameAs<0, 2>, + SDTCisInt<3>, + SDTCVecEltisVT<3, i16>, SDTCisSameNumEltsAs<0, 3>, - SDTCisVT<4, XLenVT>]>>; + SDTCVecEltisVT<4, i1>, + SDTCisSameNumEltsAs<0, 4>, + SDTCisVT<5, XLenVT>]>>; def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>, @@ -1835,43 +1838,40 @@ (!cast("PseudoVMV_S_X_"#vti.LMul.MX) vti.RegClass:$merge, (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2, + def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VV_"# vti.LMul.MX) vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1, + def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, GPR:$rs1, (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VX_"# vti.LMul.MX) vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm, + def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, uimm5:$imm, (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VI_"# vti.LMul.MX) vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgather_vv_vl - vti.RegClass:$rs2, - vti.RegClass:$rs1, - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + vti.RegClass:$rs1, + (vti.Mask V0), + VLOpFrag)), (!cast("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgather_vx_vl - vti.RegClass:$rs2, - uimm5:$imm, - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + uimm5:$imm, + (vti.Mask V0), + VLOpFrag)), (!cast("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; @@ -1884,21 +1884,20 @@ defvar emul_str = octuple_to_str.ret; defvar ivti = !cast("VI16" # emul_str); defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str; - def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2, + def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, (ivti.Vector ivti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)), (!cast(inst) vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgatherei16_vv_vl - vti.RegClass:$rs2, - (ivti.Vector ivti.RegClass:$rs1), - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector + (riscv_vrgatherei16_vv_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + (ivti.Vector ivti.RegClass:$rs1), + (vti.Mask V0), + VLOpFrag)), (!cast(inst#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; @@ -1923,43 +1922,42 @@ vti.RegClass:$merge, (vti.Scalar vti.ScalarRegClass:$rs1), GPR:$vl, vti.Log2SEW)>; defvar ivti = GetIntVTypeInfo.Vti; - def : Pat<(vti.Vector (riscv_vrgather_vv_vl vti.RegClass:$rs2, + def : Pat<(vti.Vector (riscv_vrgather_vv_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, (ivti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VV_"# vti.LMul.MX) vti.RegClass:$rs2, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, GPR:$rs1, + def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, GPR:$rs1, (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VX_"# vti.LMul.MX) vti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vrgather_vx_vl vti.RegClass:$rs2, uimm5:$imm, + def : Pat<(vti.Vector (riscv_vrgather_vx_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, uimm5:$imm, (vti.Mask true_mask), VLOpFrag)), (!cast("PseudoVRGATHER_VI_"# vti.LMul.MX) vti.RegClass:$rs2, uimm5:$imm, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgather_vv_vl - vti.RegClass:$rs2, - (ivti.Vector vti.RegClass:$rs1), - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector + (riscv_vrgather_vv_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + (ivti.Vector vti.RegClass:$rs1), + (vti.Mask V0), + VLOpFrag)), (!cast("PseudoVRGATHER_VV_"# vti.LMul.MX#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgather_vx_vl - vti.RegClass:$rs2, - uimm5:$imm, - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector + (riscv_vrgather_vx_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + uimm5:$imm, + (vti.Mask V0), + VLOpFrag)), (!cast("PseudoVRGATHER_VI_"# vti.LMul.MX#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$imm, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; @@ -1971,21 +1969,20 @@ defvar emul_str = octuple_to_str.ret; defvar ivti = !cast("VI16" # emul_str); defvar inst = "PseudoVRGATHEREI16_VV_" # vti.LMul.MX # "_" # emul_str; - def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl vti.RegClass:$rs2, + def : Pat<(vti.Vector (riscv_vrgatherei16_vv_vl (vti.Vector srcvalue), + vti.RegClass:$rs2, (ivti.Vector ivti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)), (!cast(inst) vti.RegClass:$rs2, ivti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>; - def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0), - (riscv_vrgatherei16_vv_vl - vti.RegClass:$rs2, - (ivti.Vector ivti.RegClass:$rs1), - (vti.Mask true_mask), - VLOpFrag), - vti.RegClass:$merge, - VLOpFrag)), + def : Pat<(vti.Vector + (riscv_vrgatherei16_vv_vl vti.RegClass:$merge, + vti.RegClass:$rs2, + (ivti.Vector ivti.RegClass:$rs1), + (vti.Mask V0), + VLOpFrag)), (!cast(inst#"_MASK") vti.RegClass:$merge, vti.RegClass:$rs2, ivti.RegClass:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;