diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -213,8 +213,8 @@
   VECREDUCE_FMIN_VL,
   VECREDUCE_FMAX_VL,
 
-  // Vector binary and unary ops with a mask as a third operand, and VL as a
-  // fourth operand.
+  // Vector binary ops with a mask as a third operand, a merge value
+  // as a fourth operand, and VL as a fifth operand.
   // FIXME: Can we replace these with ISD::VP_*?
   ADD_VL,
   AND_VL,
@@ -229,32 +229,30 @@
   UDIV_VL,
   UREM_VL,
   XOR_VL,
+  SMIN_VL,
+  SMAX_VL,
+  UMIN_VL,
+  UMAX_VL,
 
   SADDSAT_VL,
   UADDSAT_VL,
   SSUBSAT_VL,
   USUBSAT_VL,
+  MULHS_VL,
+  MULHU_VL,
 
   FADD_VL,
   FSUB_VL,
   FMUL_VL,
   FDIV_VL,
+  FMINNUM_VL,
+  FMAXNUM_VL,
+
+  // Vector unary ops with a mask as a second operand and VL as a third operand.
   FNEG_VL,
   FABS_VL,
   FSQRT_VL,
-  VFMADD_VL,
-  VFNMADD_VL,
-  VFMSUB_VL,
-  VFNMSUB_VL,
   FCOPYSIGN_VL,
-  SMIN_VL,
-  SMAX_VL,
-  UMIN_VL,
-  UMAX_VL,
-  FMINNUM_VL,
-  FMAXNUM_VL,
-  MULHS_VL,
-  MULHU_VL,
   FP_TO_SINT_VL,
   FP_TO_UINT_VL,
   SINT_TO_FP_VL,
@@ -262,7 +260,14 @@
   FP_ROUND_VL,
   FP_EXTEND_VL,
 
-  // Widening instructions
+  // Vector FMA ops with a mask as a fourth operand and VL as a fifth operand.
+  VFMADD_VL,
+  VFNMADD_VL,
+  VFMSUB_VL,
+  VFNMSUB_VL,
+
+  // Widening instructions with a mask as a third operand, a merge value as a
+  // fourth operand, and VL as a fifth operand.
   VWMUL_VL,
   VWMULU_VL,
   VWMULSU_VL,
@@ -679,8 +684,9 @@
   SDValue lowerFixedLengthVectorSelectToRVV(SDValue Op,
                                             SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
-                            bool HasMask = true) const;
-  SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc) const;
+                            bool HasMask = true, bool HasMergeOp = false) const;
+  SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc,
+                    bool HasMergeOp = false) const;
   SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc,
                          unsigned VecOpc) const;
   SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1929,8 +1929,8 @@
                               DAG.getUNDEF(ContainerVT), SplatVal, VL);
 
   // Add the adjustment.
-  SDValue Adjust =
-      DAG.getNode(RISCVISD::FADD_VL, DL, ContainerVT, Abs, Splat, Mask, VL);
+  SDValue Adjust = DAG.getNode(RISCVISD::FADD_VL, DL, ContainerVT, Abs, Splat,
+                               Mask, DAG.getUNDEF(ContainerVT), VL);
 
   // Truncate to integer and convert back to fp.
   MVT IntVT = ContainerVT.changeVectorElementTypeToInteger();
@@ -2798,19 +2798,21 @@
     TrueMask = getAllOnesMask(HalfContainerVT, VL, DL, DAG);
 
     // Widen V1 and V2 with 0s and add one copy of V2 to V1.
-    SDValue Add = DAG.getNode(RISCVISD::VWADDU_VL, DL, WideIntContainerVT, V1,
-                              V2, TrueMask, VL);
+    SDValue Add =
+        DAG.getNode(RISCVISD::VWADDU_VL, DL, WideIntContainerVT, V1, V2,
+                    TrueMask, DAG.getUNDEF(WideIntContainerVT), VL);
     // Create 2^eltbits - 1 copies of V2 by multiplying by the largest integer.
     SDValue Multiplier = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntHalfVT,
                                      DAG.getUNDEF(IntHalfVT),
                                      DAG.getAllOnesConstant(DL, XLenVT));
-    SDValue WidenMul = DAG.getNode(RISCVISD::VWMULU_VL, DL, WideIntContainerVT,
-                                   V2, Multiplier, TrueMask, VL);
+    SDValue WidenMul =
+        DAG.getNode(RISCVISD::VWMULU_VL, DL, WideIntContainerVT, V2, Multiplier,
+                    TrueMask, DAG.getUNDEF(WideIntContainerVT), VL);
     // Add the new copies to our previous addition giving us 2^eltbits copies of
     // V2. This is equivalent to shifting V2 left by eltbits. This should
     // combine with the vwmulu.vv above to form vwmaccu.vv.
     Add = DAG.getNode(RISCVISD::ADD_VL, DL, WideIntContainerVT, Add, WidenMul,
-                      TrueMask, VL);
+                      TrueMask, DAG.getUNDEF(WideIntContainerVT), VL);
     // Cast back to ContainerVT. We need to re-create a new ContainerVT in case
     // WideIntContainerVT is a larger fractional LMUL than implied by the fixed
     // vector VT.
@@ -3534,15 +3536,15 @@
   case ISD::SETCC:
     return lowerFixedLengthVectorSetccToRVV(Op, DAG);
   case ISD::ADD:
-    return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::ADD_VL, true, true);
   case ISD::SUB:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SUB_VL, true, true);
   case ISD::MUL:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::MUL_VL, true, true);
   case ISD::MULHS:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::MULHS_VL, true, true);
   case ISD::MULHU:
-    return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::MULHU_VL, true, true);
   case ISD::AND:
     return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMAND_VL,
                                               RISCVISD::AND_VL);
@@ -3553,13 +3555,13 @@
     return lowerFixedLengthVectorLogicOpToRVV(Op, DAG, RISCVISD::VMXOR_VL,
                                               RISCVISD::XOR_VL);
   case ISD::SDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SDIV_VL, true, true);
   case ISD::SREM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SREM_VL, true, true);
   case ISD::UDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::UDIV_VL, true, true);
   case ISD::UREM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::UREM_VL, true, true);
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL:
@@ -3570,21 +3572,21 @@
            "Unexpected custom legalisation");
     return SDValue();
   case ISD::SADDSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SADDSAT_VL, true, true);
   case ISD::UADDSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::UADDSAT_VL, true, true);
   case ISD::SSUBSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SSUBSAT_VL, true, true);
   case ISD::USUBSAT:
-    return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::USUBSAT_VL, true, true);
   case ISD::FADD:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FADD_VL, true, true);
   case ISD::FSUB:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FSUB_VL, true, true);
   case ISD::FMUL:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FMUL_VL, true, true);
   case ISD::FDIV:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL, true, true);
   case ISD::FNEG:
     return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
   case ISD::FABS:
@@ -3594,17 +3596,17 @@
   case ISD::FMA:
     return lowerToScalableOp(Op, DAG, RISCVISD::VFMADD_VL);
   case ISD::SMIN:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SMIN_VL, true, true);
   case ISD::SMAX:
-    return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::SMAX_VL, true, true);
   case ISD::UMIN:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::UMIN_VL, true, true);
   case ISD::UMAX:
-    return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::UMAX_VL, true, true);
   case ISD::FMINNUM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FMINNUM_VL, true, true);
   case ISD::FMAXNUM:
-    return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL);
+    return lowerToScalableOp(Op, DAG, RISCVISD::FMAXNUM_VL, true, true);
   case ISD::ABS:
     return lowerABS(Op, DAG);
   case ISD::CTLZ_ZERO_UNDEF:
@@ -3631,19 +3633,19 @@
   case ISD::VP_MERGE:
     return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
   case ISD::VP_ADD:
-    return lowerVPOp(Op, DAG, RISCVISD::ADD_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::ADD_VL, true);
   case ISD::VP_SUB:
-    return lowerVPOp(Op, DAG, RISCVISD::SUB_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SUB_VL, true);
   case ISD::VP_MUL:
-    return lowerVPOp(Op, DAG, RISCVISD::MUL_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::MUL_VL, true);
   case ISD::VP_SDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL, true);
   case ISD::VP_UDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL, true);
   case ISD::VP_SREM:
-    return lowerVPOp(Op, DAG, RISCVISD::SREM_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SREM_VL, true);
   case ISD::VP_UREM:
-    return lowerVPOp(Op, DAG, RISCVISD::UREM_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::UREM_VL, true);
   case ISD::VP_AND:
     return lowerLogicVPOp(Op, DAG, RISCVISD::VMAND_VL, RISCVISD::AND_VL);
   case ISD::VP_OR:
@@ -3651,19 +3653,19 @@
   case ISD::VP_XOR:
     return lowerLogicVPOp(Op, DAG, RISCVISD::VMXOR_VL, RISCVISD::XOR_VL);
   case ISD::VP_ASHR:
-    return lowerVPOp(Op, DAG, RISCVISD::SRA_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SRA_VL, true);
   case ISD::VP_LSHR:
-    return lowerVPOp(Op, DAG, RISCVISD::SRL_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SRL_VL, true);
   case ISD::VP_SHL:
-    return lowerVPOp(Op, DAG, RISCVISD::SHL_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::SHL_VL, true);
   case ISD::VP_FADD:
-    return lowerVPOp(Op, DAG, RISCVISD::FADD_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::FADD_VL, true);
   case ISD::VP_FSUB:
-    return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::FSUB_VL, true);
   case ISD::VP_FMUL:
-    return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::FMUL_VL, true);
   case ISD::VP_FDIV:
-    return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL);
+    return lowerVPOp(Op, DAG, RISCVISD::FDIV_VL, true);
   case ISD::VP_FNEG:
     return lowerVPOp(Op, DAG, RISCVISD::FNEG_VL);
   case ISD::VP_FMA:
@@ -4343,8 +4345,8 @@
                           DAG.getUNDEF(ContainerVT), SplatZero, VL);
 
   MVT MaskContainerVT = ContainerVT.changeVectorElementType(MVT::i1);
-  SDValue Trunc =
-      DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Src, SplatOne, Mask, VL);
+  SDValue Trunc = DAG.getNode(RISCVISD::AND_VL, DL, ContainerVT, Src, SplatOne,
+                              Mask, DAG.getUNDEF(ContainerVT), VL);
   Trunc = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskContainerVT, Trunc, SplatZero,
                       DAG.getCondCode(ISD::SETNE), Mask, VL);
   if (MaskVT.isFixedLengthVector())
@@ -5807,8 +5809,8 @@
                                VLMinus1, DAG.getRegister(RISCV::X0, XLenVT));
 
   SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL);
-  SDValue Indices =
-      DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask, VL);
+  SDValue Indices = DAG.getNode(RISCVISD::SUB_VL, DL, IntVT, SplatVL, VID, Mask,
+                                DAG.getUNDEF(IntVT), VL);
 
   return DAG.getNode(GatherOpc, DL, VecVT, Op.getOperand(0), Indices, Mask,
                      DAG.getUNDEF(VecVT), VL);
@@ -6073,7 +6075,8 @@
   if (VT.getVectorElementType() == MVT::i1)
     return lowerToScalableOp(Op, DAG, MaskOpc, /*HasMask*/ false);
 
-  return lowerToScalableOp(Op, DAG, VecOpc, /*HasMask*/ true);
+  return lowerToScalableOp(Op, DAG, VecOpc, /*HasMask*/ true,
+                           /*HasMerge*/ true);
 }
 
 SDValue
@@ -6087,7 +6090,7 @@
   case ISD::SRL:
     Opc = RISCVISD::SRL_VL;
     break;
   }
 
-  return lowerToScalableOp(Op, DAG, Opc);
+  return lowerToScalableOp(Op, DAG, Opc, true, true);
 }
 
 // Lower vector ABS to smax(X, sub(0, X)).
@@ -6107,10 +6110,10 @@
   SDValue SplatZero = DAG.getNode(
       RISCVISD::VMV_V_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
      DAG.getConstant(0, DL, Subtarget.getXLenVT()));
-  SDValue NegX =
-      DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, SplatZero, X, Mask, VL);
-  SDValue Max =
-      DAG.getNode(RISCVISD::SMAX_VL, DL, ContainerVT, X, NegX, Mask, VL);
+  SDValue NegX = DAG.getNode(RISCVISD::SUB_VL, DL, ContainerVT, SplatZero, X,
+                             Mask, DAG.getUNDEF(ContainerVT), VL);
+  SDValue Max = DAG.getNode(RISCVISD::SMAX_VL, DL, ContainerVT, X, NegX, Mask,
+                            DAG.getUNDEF(ContainerVT), VL);
 
   return convertFromScalableVector(VT, Max, DAG, Subtarget);
 }
@@ -6163,8 +6166,8 @@
 SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
-                                               unsigned NewOpc,
-                                               bool HasMask) const {
+                                               unsigned NewOpc, bool HasMask,
+                                               bool HasMergeOp) const {
   MVT VT = Op.getSimpleValueType();
   MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
@@ -6190,6 +6193,8 @@
   std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
   if (HasMask)
     Ops.push_back(Mask);
+  if (HasMergeOp)
+    Ops.push_back(DAG.getUNDEF(ContainerVT));
   Ops.push_back(VL);
 
   SDValue ScalableRes = DAG.getNode(NewOpc, DL, ContainerVT, Ops);
@@ -6202,14 +6207,21 @@
 // * Fixed-length vectors are converted to their scalable-vector container
 //   types.
 SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG,
-                                       unsigned RISCVISDOpc) const {
+                                       unsigned RISCVISDOpc,
+                                       bool HasMergeOp) const {
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
 
   SmallVector<SDValue> Ops;
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector())
+    ContainerVT = getContainerForFixedLengthVector(VT);
+
   for (const auto &OpIdx : enumerate(Op->ops())) {
     SDValue V = OpIdx.value();
     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
+    if (HasMergeOp && V.getValueType().isScalarInteger())
+      Ops.push_back(DAG.getUNDEF(ContainerVT));
     // Pass through operands which aren't fixed-length vectors.
     if (!V.getValueType().isFixedLengthVector()) {
       Ops.push_back(V);
@@ -6226,8 +6238,6 @@
   if (!VT.isFixedLengthVector())
     return DAG.getNode(RISCVISDOpc, DL, VT, Ops, Op->getFlags());
 
-  MVT ContainerVT = getContainerForFixedLengthVector(VT);
-
   SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops, Op->getFlags());
 
   return convertFromScalableVector(VT, VPOp, DAG, Subtarget);
@@ -6483,7 +6493,7 @@
                                                unsigned VecOpc) const {
   MVT VT = Op.getSimpleValueType();
   if (VT.getVectorElementType() != MVT::i1)
-    return lowerVPOp(Op, DAG, VecOpc);
+    return lowerVPOp(Op, DAG, VecOpc, true);
 
   // It is safe to drop mask parameter as masked-off elements are undef.
   SDValue Op1 = Op->getOperand(0);
@@ -7281,8 +7291,9 @@
     SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
                                      DAG.getUNDEF(ContainerVT),
                                      DAG.getConstant(32, DL, XLenVT), VL);
-    SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Vec,
-                                 ThirtyTwoV, Mask, VL);
+    SDValue LShr32 =
+        DAG.getNode(RISCVISD::SRL_VL, DL, ContainerVT, Vec, ThirtyTwoV, Mask,
+                    DAG.getUNDEF(ContainerVT), VL);
 
     SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
@@ -7389,8 +7400,8 @@
       SDValue ThirtyTwoV = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT,
                                        DAG.getUNDEF(VecVT),
                                        DAG.getConstant(32, DL, XLenVT), VL);
-      SDValue LShr32 =
-          DAG.getNode(RISCVISD::SRL_VL, DL, VecVT, Vec, ThirtyTwoV, Mask, VL);
+      SDValue LShr32 = DAG.getNode(RISCVISD::SRL_VL, DL, VecVT, Vec, ThirtyTwoV,
+                                   Mask, DAG.getUNDEF(VecVT), VL);
       SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32);
 
       Results.push_back(
@@ -8299,7 +8310,8 @@
                                  VT.getVectorElementCount());
 
   SDValue Mask = N->getOperand(2);
-  SDValue VL = N->getOperand(3);
+  SDValue Merge = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
 
   SDLoc DL(N);
 
@@ -8319,7 +8331,7 @@
   else
     WOpc = IsAdd ? RISCVISD::VWADDU_W_VL : RISCVISD::VWSUBU_W_VL;
 
-  return DAG.getNode(WOpc, DL, VT, Op0, Op1, Mask, VL);
+  return DAG.getNode(WOpc, DL, VT, Op0, Op1, Mask, Merge, VL);
 }
 
 // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -8334,7 +8346,8 @@
   SDValue Op0 = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Mask = N->getOperand(2);
-  SDValue VL = N->getOperand(3);
+  SDValue Merge = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
 
   MVT VT = N->getSimpleValueType(0);
   MVT NarrowVT = Op1.getSimpleValueType();
@@ -8364,7 +8377,7 @@
     // Re-introduce narrower extends if needed.
     if (Op0.getValueType() != NarrowVT)
       Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
-    return DAG.getNode(VOpc, DL, VT, Op0, Op1, Mask, VL);
+    return DAG.getNode(VOpc, DL, VT, Op0, Op1, Mask, Merge, VL);
   }
 
   bool IsAdd = N->getOpcode() == RISCVISD::VWADD_W_VL ||
@@ -8396,7 +8409,7 @@
     Op0 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                       DAG.getUNDEF(NarrowVT), Op0, VL);
-    return DAG.getNode(VOpc, DL, VT, Op1, Op0, Mask, VL);
+    return DAG.getNode(VOpc, DL, VT, Op1, Op0, Mask, Merge, VL);
   }
 
   return SDValue();
 }
@@ -8419,7 +8432,8 @@
     return SDValue();
 
   SDValue Mask = N->getOperand(2);
-  SDValue VL = N->getOperand(3);
+  SDValue Merge = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
 
   // Make sure the mask and VL match.
   if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)
@@ -8497,7 +8511,7 @@
   unsigned WMulOpc = RISCVISD::VWMULSU_VL;
   if (!IsVWMULSU)
     WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
-  return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
+  return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, Merge, VL);
 }
 
 static RISCVFPRndMode::RoundingMode matchRoundingOp(SDValue Op) {
@@ -9230,7 +9244,7 @@
       ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
                           ShAmt.getOperand(1), VL);
       return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt,
-                         N->getOperand(2), N->getOperand(3));
+                         N->getOperand(2), N->getOperand(3), N->getOperand(4));
     }
     break;
   }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -21,24 +21,26 @@
 // Helpers to define the VL patterns.
 //===----------------------------------------------------------------------===//
 
-def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>,
+def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                 SDTCisSameAs<0, 2>,
                                                 SDTCisVec<0>, SDTCisInt<0>,
                                                 SDTCVecEltisVT<3, i1>,
                                                 SDTCisSameNumEltsAs<0, 3>,
-                                                SDTCisVT<4, XLenVT>]>;
+                                                SDTCisSameAs<0, 4>,
+                                                SDTCisVT<5, XLenVT>]>;
 
 def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
                                               SDTCisVec<0>, SDTCisFP<0>,
                                               SDTCVecEltisVT<2, i1>,
                                               SDTCisSameNumEltsAs<0, 2>,
                                               SDTCisVT<3, XLenVT>]>;
-def SDT_RISCVFPBinOp_VL : SDTypeProfile<1, 4, [SDTCisSameAs<0, 1>,
+def SDT_RISCVFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                SDTCisSameAs<0, 2>,
                                                SDTCisVec<0>, SDTCisFP<0>,
                                                SDTCVecEltisVT<3, i1>,
                                                SDTCisSameNumEltsAs<0, 3>,
-                                               SDTCisVT<4, XLenVT>]>;
+                                               SDTCisSameAs<0, 4>,
+                                               SDTCisVT<5, XLenVT>]>;
 
 def SDT_RISCVCopySign_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                 SDTCisSameAs<0, 2>,
@@ -230,12 +232,13 @@
                                            SDTCVecEltisVT<2, i1>,
                                            SDTCisVT<3, XLenVT>]>>;
 
-def SDT_RISCVVWBinOp_VL : SDTypeProfile<1, 4, [SDTCisVec<0>,
+def SDT_RISCVVWBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                SDTCisSameNumEltsAs<0, 1>,
                                                SDTCisSameAs<1, 2>,
                                                SDTCisSameNumEltsAs<1, 3>,
                                                SDTCVecEltisVT<3, i1>,
-                                               SDTCisVT<4, XLenVT>]>;
+                                               SDTCisSameAs<0, 4>,
+                                               SDTCisVT<5, XLenVT>]>;
 def riscv_vwmul_vl : SDNode<"RISCVISD::VWMUL_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
 def riscv_vwmulu_vl : SDNode<"RISCVISD::VWMULU_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
 def riscv_vwmulsu_vl : SDNode<"RISCVISD::VWMULSU_VL", SDT_RISCVVWBinOp_VL>;
@@ -244,13 +247,14 @@
 def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
 def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
 
-def SDT_RISCVVWBinOpW_VL : SDTypeProfile<1, 4, [SDTCisVec<0>,
+def SDT_RISCVVWBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
                                                 SDTCisSameAs<0, 1>,
                                                 SDTCisSameNumEltsAs<1, 2>,
                                                 SDTCisOpSmallerThanOp<2, 1>,
                                                 SDTCisSameNumEltsAs<1, 3>,
                                                 SDTCVecEltisVT<3, i1>,
-                                                SDTCisVT<4, XLenVT>]>;
+                                                SDTCisSameAs<0, 4>,
+                                                SDTCisVT<5, XLenVT>]>;
 def riscv_vwadd_w_vl : SDNode<"RISCVISD::VWADD_W_VL", SDT_RISCVVWBinOpW_VL>;
 def riscv_vwaddu_w_vl : SDNode<"RISCVISD::VWADDU_W_VL", SDT_RISCVVWBinOpW_VL>;
 def riscv_vwsub_w_vl : SDNode<"RISCVISD::VWSUB_W_VL", SDT_RISCVVWBinOpW_VL>;
@@ -261,27 +265,31 @@
   SDTCVecEltisVT<4, i1>, SDTCisSameNumEltsAs<2, 4>, SDTCisVT<5, XLenVT>
 ]>;
 
-def riscv_mul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
+def riscv_mul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
+                                   node:$E),
                                   (riscv_mul_vl node:$A, node:$B, node:$C,
-                                   node:$D), [{
+                                   node:$D, node:$E), [{
   return N->hasOneUse();
 }]>;
 
-def riscv_vwmul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
+def riscv_vwmul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
+                                     node:$E),
                                     (riscv_vwmul_vl node:$A, node:$B, node:$C,
-                                     node:$D), [{
+                                     node:$D, node:$E), [{
   return N->hasOneUse();
 }]>;
 
-def riscv_vwmulu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
+def riscv_vwmulu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
+                                      node:$E),
                                      (riscv_vwmulu_vl node:$A, node:$B, node:$C,
-                                      node:$D), [{
+                                      node:$D, node:$E), [{
   return N->hasOneUse();
 }]>;
 
-def riscv_vwmulsu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
+def riscv_vwmulsu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
+                                       node:$E),
                                       (riscv_vwmulsu_vl node:$A, node:$B, node:$C,
-                                       node:$D), [{
+                                       node:$D, node:$E), [{
   return N->hasOneUse();
 }]>;
 
@@ -331,15 +339,17 @@
                         ValueType mask_type,
                         int sew,
                         LMULInfo vlmul,
+                        VReg result_reg_class,
                         VReg op1_reg_class,
                         VReg op2_reg_class> {
  def : Pat<(result_type (vop
                         (op1_type op1_reg_class:$rs1),
                         (op2_type op2_reg_class:$rs2),
                         (mask_type V0),
+                        (result_type result_reg_class:$merge),
                         VLOpFrag)),
        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
-                     (result_type (IMPLICIT_DEF)),
+                     result_reg_class:$merge,
                      op1_reg_class:$rs1,
                      op2_reg_class:$rs2,
                      (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
@@ -354,6 +364,7 @@
                          ValueType mask_type,
                          int sew,
                          LMULInfo vlmul,
+                         VReg result_reg_class,
                          VReg vop_reg_class,
                          ComplexPattern SplatPatKind,
                          DAGOperand xop_kind> {
@@ -361,9 +372,10 @@
                      (vop1_type vop_reg_class:$rs1),
                      (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
                      (mask_type V0),
+                     (result_type result_reg_class:$merge),
                      VLOpFrag)),
        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
-                     (result_type (IMPLICIT_DEF)),
+                     result_reg_class:$merge,
                      vop_reg_class:$rs1,
                      xop_kind:$rs2,
                      (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
@@ -373,10 +385,12 @@
   foreach vti = AllIntegerVectors in {
     defm : VPatBinaryVL_V<vop, instruction_name, "VV",
                           vti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                          vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass>;
+                          vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
+                          vti.RegClass>;
     defm : VPatBinaryVL_XI<vop, instruction_name, "VX",
                            vti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                           vti.Log2SEW, vti.LMul, vti.RegClass, SplatPat, GPR>;
+                           vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
+                           SplatPat, GPR>;
   }
 }
 
@@ -386,7 +400,7 @@
   foreach vti = AllIntegerVectors in {
     defm : VPatBinaryVL_XI<vop, instruction_name, "VI",
                            vti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                           vti.Log2SEW, vti.LMul, vti.RegClass,
+                           vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
                            !cast<ComplexPattern>(SplatPat#_#ImmType), ImmType>;
   }
 }
@@ -398,10 +412,12 @@
     defvar wti = VtiToWti.Wti;
     defm : VPatBinaryVL_V<vop, instruction_name, "VV",
                           wti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                          vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass>;
+                          vti.Log2SEW, vti.LMul, wti.RegClass, vti.RegClass,
+                          vti.RegClass>;
     defm : VPatBinaryVL_XI<vop, instruction_name, "VX",
                            wti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                           vti.Log2SEW, vti.LMul, vti.RegClass, SplatPat, GPR>;
+                           vti.Log2SEW, vti.LMul, wti.RegClass, vti.RegClass,
+                           SplatPat, GPR>;
   }
 }
 multiclass VPatBinaryWVL_VV_VX_WV_WX<SDNode vop, SDNode vop_w,
@@ -412,10 +428,12 @@
     defvar wti = VtiToWti.Wti;
     defm : VPatBinaryVL_V<vop_w, instruction_name, "WV",
                           wti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                          vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass>;
+                          vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
+                          vti.RegClass>;
     defm : VPatBinaryVL_XI<vop_w, instruction_name, "WX",
                            wti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                           vti.Log2SEW, vti.LMul, wti.RegClass, SplatPat, GPR>;
+                           vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
+                           SplatPat, GPR>;
   }
 }
 
@@ -426,14 +444,16 @@
                          ValueType mask_type,
                          int sew,
                          LMULInfo vlmul,
+                         VReg result_reg_class,
                          VReg vop_reg_class,
                          RegisterClass scalar_reg_class> {
   def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
                               (vop_type (SplatFPOp scalar_reg_class:$rs2)),
                               (mask_type V0),
+                              (result_type result_reg_class:$merge),
                               VLOpFrag)),
         (!cast<Instruction>(instruction_name#"_"#vlmul.MX#"_MASK")
-                     (result_type (IMPLICIT_DEF)),
+                     result_reg_class:$merge,
                      vop_reg_class:$rs1,
                      scalar_reg_class:$rs2,
                      (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
@@ -443,10 +463,12 @@
   foreach vti = AllFloatVectors in {
     defm : VPatBinaryVL_V<vop, instruction_name, "VV",
                           vti.Vector, vti.Vector, vti.Vector, vti.Mask,
-                          vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass>;
+                          vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass,
+                          vti.RegClass>;
     defm : VPatBinaryVL_VF<vop, instruction_name#"_V"#vti.ScalarSuffix,
                            vti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
-                           vti.LMul, vti.RegClass, vti.ScalarRegClass>;
+                           vti.LMul, vti.RegClass, vti.RegClass,
+                           vti.ScalarRegClass>;
   }
 }
 
@@ -455,9 +477,10 @@
     def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
                                 fvti.RegClass:$rs1,
                                 (fvti.Mask V0),
+                                (fvti.Vector fvti.RegClass:$merge),
                                 VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX#"_MASK")
-                   (fvti.Vector (IMPLICIT_DEF)),
+                   fvti.RegClass:$merge,
                    fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
                    (fvti.Mask V0), GPR:$vl, fvti.Log2SEW, TAIL_AGNOSTIC)>;
 }
@@ -786,7 +809,7 @@
                                    (fvti.Mask true_mask), VLOpFrag)),
                              (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
                                                  (fvti.Mask true_mask), VLOpFrag)),
-                             (fwti.Mask true_mask), VLOpFrag)),
+                             (fwti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_VV_"#fvti.LMul.MX)
                 fvti.RegClass:$rs2, fvti.RegClass:$rs1,
                 GPR:$vl, fvti.Log2SEW)>;
@@ -794,7 +817,7 @@
                                    (fvti.Mask true_mask), VLOpFrag)),
                              (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
                                                  (fvti.Mask true_mask), VLOpFrag)),
-                             (fwti.Mask true_mask), VLOpFrag)),
+                             (fwti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
                 fvti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
                 GPR:$vl, fvti.Log2SEW)>;
@@ -808,14 +831,14 @@
     def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
                                (fwti.Vector (extop (fvti.Vector fvti.RegClass:$rs1),
                                                    (fvti.Mask true_mask), VLOpFrag)),
-                               (fwti.Mask true_mask), VLOpFrag)),
+                               (fwti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_WV_"#fvti.LMul.MX)
                 fwti.RegClass:$rs2, fvti.RegClass:$rs1,
                 GPR:$vl, fvti.Log2SEW)>;
     def : Pat<(fwti.Vector (op (fwti.Vector fwti.RegClass:$rs2),
                                (fwti.Vector (extop (fvti.Vector (SplatFPOp fvti.ScalarRegClass:$rs1)),
                                                    (fvti.Mask true_mask), VLOpFrag)),
-                               (fwti.Mask true_mask), VLOpFrag)),
+                               (fwti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_W"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
                 fwti.RegClass:$rs2, fvti.ScalarRegClass:$rs1,
                 GPR:$vl, fvti.Log2SEW)>;
@@ -837,7 +860,7 @@
                   (op (wti.Vector wti.RegClass:$rs2),
                       (wti.Vector (extop (vti.Vector (SplatPat GPR:$rs1)),
                                          (vti.Mask true_mask), VLOpFrag)),
-                      (wti.Mask true_mask), VLOpFrag),
+                      (wti.Mask true_mask), srcvalue, VLOpFrag),
                   (vti.Mask true_mask), VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)
                 wti.RegClass:$rs2, GPR:$rs1, GPR:$vl, vti.Log2SEW)>;
@@ -853,8 +876,8 @@
    def : Pat<(vti.Vector
               (op vti.RegClass:$rs2,
                   (riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rd,
-                                       (vti.Mask true_mask), VLOpFrag),
-                  (vti.Mask true_mask), VLOpFrag)),
+                                       (vti.Mask true_mask), srcvalue, VLOpFrag),
+                  (vti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_VV_"# suffix)
                   vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -864,8 +887,8 @@
    def : Pat<(vti.Vector
               (op vti.RegClass:$rs2,
                   (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rd,
-                                       (vti.Mask true_mask), VLOpFrag),
-                  (vti.Mask true_mask), VLOpFrag)),
+                                       (vti.Mask true_mask), srcvalue, VLOpFrag),
+                  (vti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_VX_" # suffix)
                   vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -880,8 +903,8 @@
               (riscv_add_vl wti.RegClass:$rd,
                             (op1 vti.RegClass:$rs1,
                                  (vti.Vector vti.RegClass:$rs2),
-                                 (vti.Mask true_mask), VLOpFrag),
-                            (vti.Mask true_mask), VLOpFrag)),
+                                 (vti.Mask true_mask), srcvalue, VLOpFrag),
+                            (vti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
                   wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -889,8 +912,8 @@
               (riscv_add_vl wti.RegClass:$rd,
                             (op1 (SplatPat XLenVT:$rs1),
                                  (vti.Vector vti.RegClass:$rs2),
-                                 (vti.Mask true_mask), VLOpFrag),
-                            (vti.Mask true_mask), VLOpFrag)),
+                                 (vti.Mask true_mask), srcvalue, VLOpFrag),
+                            (vti.Mask true_mask), srcvalue, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
                   wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -903,12 +926,12 @@
     defvar wti = vtiTowti.Wti;
     def : Pat<(vti.Vector (riscv_trunc_vector_vl
               (wti.Vector (op wti.RegClass:$rs1, (SplatPat XLenVT:$rs2),
-                               true_mask, VLOpFrag)), true_mask, VLOpFrag)),
+                               true_mask, srcvalue, VLOpFrag)), true_mask, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)
                 wti.RegClass:$rs1, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
     def : Pat<(vti.Vector (riscv_trunc_vector_vl
               (wti.Vector (op wti.RegClass:$rs1, (SplatPat_uimm5 uimm5:$rs2),
-                               true_mask, VLOpFrag)), true_mask, VLOpFrag)),
+                               true_mask, srcvalue, VLOpFrag)), true_mask, VLOpFrag)),
              (!cast<Instruction>(instruction_name#"_WI_"#vti.LMul.MX)
                 wti.RegClass:$rs1, uimm5:$rs2, GPR:$vl, vti.Log2SEW)>;
   }
@@ -992,15 +1015,15 @@
 foreach vti = AllIntegerVectors in {
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
-                          VLOpFrag),
+                          vti.RegClass:$merge, VLOpFrag),
             (!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX#"_MASK")
-                 (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
+                 vti.RegClass:$merge, vti.RegClass:$rs1, GPR:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
-                          VLOpFrag),
+                          vti.RegClass:$merge, VLOpFrag),
            (!cast<Instruction>("PseudoVRSUB_VI_"# vti.LMul.MX#"_MASK")
-                 (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, simm5:$rs2,
+                 vti.RegClass:$merge, vti.RegClass:$rs1, simm5:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 }
 
@@ -1038,8 +1061,7 @@
   // Emit shift by 1 as an add since it might be faster.
   def : Pat<(riscv_shl_vl (vti.Vector vti.RegClass:$rs1),
                           (riscv_vmv_v_x_vl (vti.Vector undef), 1, (XLenVT srcvalue)),
-                          (vti.Mask true_mask),
-                          VLOpFrag),
+                          (vti.Mask true_mask), srcvalue, VLOpFrag),
             (!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX)
                  vti.RegClass:$rs1, vti.RegClass:$rs1, GPR:$vl, vti.Log2SEW)>;
 }
@@ -1140,8 +1162,10 @@
               (riscv_add_vl wti.RegClass:$rd,
                             (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
                                                      (SplatPat XLenVT:$rs2),
-                                                     (vti.Mask true_mask), VLOpFrag),
-                            (vti.Mask true_mask), VLOpFrag)),
+                                                     (vti.Mask true_mask),
+                                                     srcvalue,
+                                                     VLOpFrag),
+                            (vti.Mask true_mask), srcvalue, VLOpFrag)),
             (!cast<Instruction>("PseudoVWMACCUS_VX_" # vti.LMul.MX)
                  wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
                  GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
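
Note for readers writing combines against these nodes (this note is commentary, not part of the patch): after this change, every binary *_VL node carries operands in the order (Op1, Op2, Mask, Merge, VL). A minimal sketch of building one such node, assuming DAG, DL, ContainerVT, A, B, Mask and VL are in scope as in the lowering functions above; passing undef as the merge reproduces the pre-patch "masked-off lanes are don't-care" behaviour:

  // Sketch only: RISCVISD::ADD_VL now takes five operands
  // (LHS, RHS, Mask, Merge, VL); an undef merge keeps the old semantics.
  SDValue Sum = DAG.getNode(RISCVISD::ADD_VL, DL, ContainerVT, A, B, Mask,
                            DAG.getUNDEF(ContainerVT), VL);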