Drop the IsVP flag from the SelectionDAG binary-op legalization helpers.

VP binary opcodes now share case labels with their unpredicated
counterparts in PromoteIntegerResult, SplitVectorResult and
WidenVectorResult, so the helpers no longer take a bool IsVP argument.
They tell the two forms apart themselves:
  * PromoteIntRes_SimpleIntBinOp / SExtIntBinOp / ZExtIntBinOp,
    SplitVecRes_BinOp and WidenVecRes_Binary check the operand count
    (2 = unpredicated; otherwise 4 operands and a VP opcode are asserted,
    using the new SDNode::isVPOpcode()).
  * PromoteIntRes_SHL / SRA / SRL compare the opcode against
    VP_SHL / VP_ASHR / VP_LSHR respectively.
SMIN/SMAX keep using PromoteIntRes_SExtIntBinOp.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Hunk line counts were re-verified, but context-line indentation was
reconstructed following LLVM style; if a hunk fails to match the tree,
apply with `git apply --ignore-whitespace` (or `patch -l`).

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -675,6 +675,9 @@
     }
   }
 
+  /// Test if this node is a vector predication operation.
+  bool isVPOpcode() const { return ISD::isVPOpcode(getOpcode()); }
+
   /// Test if this node has a post-isel opcode, directly
   /// corresponding to a MachineInstr opcode.
   bool isMachineOpcode() const { return NodeType < 0; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -85,23 +85,18 @@
   case ISD::STRICT_FSETCCS:
   case ISD::SETCC:       Res = PromoteIntRes_SETCC(N); break;
   case ISD::SMIN:
-  case ISD::SMAX:
-    Res = PromoteIntRes_SExtIntBinOp(N, /*IsVP*/ false);
-    break;
+  case ISD::SMAX:        Res = PromoteIntRes_SExtIntBinOp(N); break;
   case ISD::UMIN:
   case ISD::UMAX:        Res = PromoteIntRes_UMINUMAX(N); break;
 
   case ISD::SHL:
-    Res = PromoteIntRes_SHL(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_SHL:      Res = PromoteIntRes_SHL(N); break;
   case ISD::SIGN_EXTEND_INREG:
                          Res = PromoteIntRes_SIGN_EXTEND_INREG(N); break;
   case ISD::SRA:
-    Res = PromoteIntRes_SRA(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_ASHR:     Res = PromoteIntRes_SRA(N); break;
   case ISD::SRL:
-    Res = PromoteIntRes_SRL(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_LSHR:     Res = PromoteIntRes_SRL(N); break;
   case ISD::TRUNCATE:    Res = PromoteIntRes_TRUNCATE(N); break;
   case ISD::UNDEF:       Res = PromoteIntRes_UNDEF(N); break;
   case ISD::VAARG:       Res = PromoteIntRes_VAARG(N); break;
@@ -157,18 +152,22 @@
   case ISD::ADD:
   case ISD::SUB:
   case ISD::MUL:
-    Res = PromoteIntRes_SimpleIntBinOp(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_AND:
+  case ISD::VP_OR:
+  case ISD::VP_XOR:
+  case ISD::VP_ADD:
+  case ISD::VP_SUB:
+  case ISD::VP_MUL:      Res = PromoteIntRes_SimpleIntBinOp(N); break;
 
   case ISD::SDIV:
   case ISD::SREM:
-    Res = PromoteIntRes_SExtIntBinOp(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_SDIV:
+  case ISD::VP_SREM:     Res = PromoteIntRes_SExtIntBinOp(N); break;
 
   case ISD::UDIV:
   case ISD::UREM:
-    Res = PromoteIntRes_ZExtIntBinOp(N, /*IsVP*/ false);
-    break;
+  case ISD::VP_UDIV:
+  case ISD::VP_UREM:     Res = PromoteIntRes_ZExtIntBinOp(N); break;
 
   case ISD::SADDO:
   case ISD::SSUBO:       Res = PromoteIntRes_SADDSUBO(N, ResNo); break;
@@ -263,32 +262,6 @@
   case ISD::FSHR:
     Res = PromoteIntRes_FunnelShift(N);
     break;
-
-  case ISD::VP_AND:
-  case ISD::VP_OR:
-  case ISD::VP_XOR:
-  case ISD::VP_ADD:
-  case ISD::VP_SUB:
-  case ISD::VP_MUL:
-    Res = PromoteIntRes_SimpleIntBinOp(N, /*IsVP*/ true);
-    break;
-  case ISD::VP_SDIV:
-  case ISD::VP_SREM:
-    Res = PromoteIntRes_SExtIntBinOp(N, /*IsVP*/ true);
-    break;
-  case ISD::VP_UDIV:
-  case ISD::VP_UREM:
-    Res = PromoteIntRes_ZExtIntBinOp(N, /*IsVP*/ true);
-    break;
-  case ISD::VP_SHL:
-    Res = PromoteIntRes_SHL(N, /*IsVP*/ true);
-    break;
-  case ISD::VP_ASHR:
-    Res = PromoteIntRes_SRA(N, /*IsVP*/ true);
-    break;
-  case ISD::VP_LSHR:
-    Res = PromoteIntRes_SRL(N, /*IsVP*/ true);
-    break;
   }
 
   // If the result is null then the sub-method took care of registering it.
@@ -1194,12 +1167,12 @@
   return DAG.getSExtOrTrunc(SetCC, dl, NVT);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SHL(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_SHL(SDNode *N) {
   SDValue LHS = GetPromotedInteger(N->getOperand(0));
   SDValue RHS = N->getOperand(1);
   if (getTypeAction(RHS.getValueType()) == TargetLowering::TypePromoteInteger)
     RHS = ZExtPromotedInteger(RHS);
-  if (!IsVP)
+  if (N->getOpcode() != ISD::VP_SHL)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
@@ -1211,34 +1184,40 @@
                      Op.getValueType(), Op, N->getOperand(1));
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SimpleIntBinOp(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_SimpleIntBinOp(SDNode *N) {
   // The input may have strange things in the top bits of the registers, but
   // these operations don't care.  They may have weird bits going out, but
   // that too is okay if they are integer operations.
   SDValue LHS = GetPromotedInteger(N->getOperand(0));
   SDValue RHS = GetPromotedInteger(N->getOperand(1));
-  if (!IsVP)
+  if (N->getNumOperands() == 2)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
+  assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
+  assert(N->isVPOpcode() && "Expected VP opcode");
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SExtIntBinOp(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_SExtIntBinOp(SDNode *N) {
   // Sign extend the input.
   SDValue LHS = SExtPromotedInteger(N->getOperand(0));
   SDValue RHS = SExtPromotedInteger(N->getOperand(1));
-  if (!IsVP)
+  if (N->getNumOperands() == 2)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
+  assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
+  assert(N->isVPOpcode() && "Expected VP opcode");
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_ZExtIntBinOp(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_ZExtIntBinOp(SDNode *N) {
   // Zero extend the input.
   SDValue LHS = ZExtPromotedInteger(N->getOperand(0));
   SDValue RHS = ZExtPromotedInteger(N->getOperand(1));
-  if (!IsVP)
+  if (N->getNumOperands() == 2)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
+  assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
+  assert(N->isVPOpcode() && "Expected VP opcode");
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
 }
@@ -1252,25 +1231,25 @@
                      LHS.getValueType(), LHS, RHS);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SRA(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_SRA(SDNode *N) {
   // The input value must be properly sign extended.
   SDValue LHS = SExtPromotedInteger(N->getOperand(0));
   SDValue RHS = N->getOperand(1);
   if (getTypeAction(RHS.getValueType()) == TargetLowering::TypePromoteInteger)
     RHS = ZExtPromotedInteger(RHS);
-  if (!IsVP)
+  if (N->getOpcode() != ISD::VP_ASHR)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SRL(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::PromoteIntRes_SRL(SDNode *N) {
   // The input value must be properly zero extended.
   SDValue LHS = ZExtPromotedInteger(N->getOperand(0));
   SDValue RHS = N->getOperand(1);
   if (getTypeAction(RHS.getValueType()) == TargetLowering::TypePromoteInteger)
     RHS = ZExtPromotedInteger(RHS);
-  if (!IsVP)
+  if (N->getOpcode() != ISD::VP_LSHR)
     return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS);
   return DAG.getNode(N->getOpcode(), SDLoc(N), LHS.getValueType(), LHS, RHS,
                      N->getOperand(2), N->getOperand(3));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -337,14 +337,14 @@
   SDValue PromoteIntRes_Select(SDNode *N);
   SDValue PromoteIntRes_SELECT_CC(SDNode *N);
   SDValue PromoteIntRes_SETCC(SDNode *N);
-  SDValue PromoteIntRes_SHL(SDNode *N, bool IsVP);
-  SDValue PromoteIntRes_SimpleIntBinOp(SDNode *N, bool IsVP);
-  SDValue PromoteIntRes_ZExtIntBinOp(SDNode *N, bool IsVP);
-  SDValue PromoteIntRes_SExtIntBinOp(SDNode *N, bool IsVP);
+  SDValue PromoteIntRes_SHL(SDNode *N);
+  SDValue PromoteIntRes_SimpleIntBinOp(SDNode *N);
+  SDValue PromoteIntRes_ZExtIntBinOp(SDNode *N);
+  SDValue PromoteIntRes_SExtIntBinOp(SDNode *N);
   SDValue PromoteIntRes_UMINUMAX(SDNode *N);
   SDValue PromoteIntRes_SIGN_EXTEND_INREG(SDNode *N);
-  SDValue PromoteIntRes_SRA(SDNode *N, bool IsVP);
-  SDValue PromoteIntRes_SRL(SDNode *N, bool IsVP);
+  SDValue PromoteIntRes_SRA(SDNode *N);
+  SDValue PromoteIntRes_SRL(SDNode *N);
   SDValue PromoteIntRes_TRUNCATE(SDNode *N);
   SDValue PromoteIntRes_UADDSUBO(SDNode *N, unsigned ResNo);
   SDValue PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo);
@@ -825,7 +825,7 @@
 
   // Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>.
   void SplitVectorResult(SDNode *N, unsigned ResNo);
-  void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi, bool IsVP);
+  void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -922,7 +922,7 @@
   SDValue WidenVecRes_VECTOR_SHUFFLE(ShuffleVectorSDNode *N);
 
   SDValue WidenVecRes_Ternary(SDNode *N);
-  SDValue WidenVecRes_Binary(SDNode *N, bool IsVP);
+  SDValue WidenVecRes_Binary(SDNode *N);
   SDValue WidenVecRes_BinaryCanTrap(SDNode *N);
   SDValue WidenVecRes_BinaryWithExtraScalarOp(SDNode *N);
   SDValue WidenVecRes_StrictFP(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1008,31 +1008,31 @@
     SplitVecRes_ExtendOp(N, Lo, Hi);
     break;
 
-  case ISD::ADD:
-  case ISD::SUB:
-  case ISD::MUL:
+  case ISD::ADD: case ISD::VP_ADD:
+  case ISD::SUB: case ISD::VP_SUB:
+  case ISD::MUL: case ISD::VP_MUL:
   case ISD::MULHS:
   case ISD::MULHU:
-  case ISD::FADD:
-  case ISD::FSUB:
-  case ISD::FMUL:
+  case ISD::FADD: case ISD::VP_FADD:
+  case ISD::FSUB: case ISD::VP_FSUB:
+  case ISD::FMUL: case ISD::VP_FMUL:
   case ISD::FMINNUM:
   case ISD::FMAXNUM:
   case ISD::FMINIMUM:
   case ISD::FMAXIMUM:
-  case ISD::SDIV:
-  case ISD::UDIV:
-  case ISD::FDIV:
+  case ISD::SDIV: case ISD::VP_SDIV:
+  case ISD::UDIV: case ISD::VP_UDIV:
+  case ISD::FDIV: case ISD::VP_FDIV:
   case ISD::FPOW:
-  case ISD::AND:
-  case ISD::OR:
-  case ISD::XOR:
-  case ISD::SHL:
-  case ISD::SRA:
-  case ISD::SRL:
-  case ISD::UREM:
-  case ISD::SREM:
-  case ISD::FREM:
+  case ISD::AND: case ISD::VP_AND:
+  case ISD::OR: case ISD::VP_OR:
+  case ISD::XOR: case ISD::VP_XOR:
+  case ISD::SHL: case ISD::VP_SHL:
+  case ISD::SRA: case ISD::VP_ASHR:
+  case ISD::SRL: case ISD::VP_LSHR:
+  case ISD::UREM: case ISD::VP_UREM:
+  case ISD::SREM: case ISD::VP_SREM:
+  case ISD::FREM: case ISD::VP_FREM:
   case ISD::SMIN:
   case ISD::SMAX:
   case ISD::UMIN:
@@ -1045,7 +1045,7 @@
   case ISD::USHLSAT:
   case ISD::ROTL:
   case ISD::ROTR:
-    SplitVecRes_BinOp(N, Lo, Hi, /*IsVP*/ false);
+    SplitVecRes_BinOp(N, Lo, Hi);
     break;
   case ISD::FMA:
   case ISD::FSHL:
@@ -1082,26 +1082,6 @@
   case ISD::UDIVFIXSAT:
     SplitVecRes_FIX(N, Lo, Hi);
     break;
-  case ISD::VP_ADD:
-  case ISD::VP_AND:
-  case ISD::VP_MUL:
-  case ISD::VP_OR:
-  case ISD::VP_SUB:
-  case ISD::VP_XOR:
-  case ISD::VP_SHL:
-  case ISD::VP_LSHR:
-  case ISD::VP_ASHR:
-  case ISD::VP_SDIV:
-  case ISD::VP_UDIV:
-  case ISD::VP_SREM:
-  case ISD::VP_UREM:
-  case ISD::VP_FADD:
-  case ISD::VP_FSUB:
-  case ISD::VP_FMUL:
-  case ISD::VP_FDIV:
-  case ISD::VP_FREM:
-    SplitVecRes_BinOp(N, Lo, Hi, /*IsVP*/ true);
-    break;
   }
 
   // If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -1133,8 +1113,7 @@
   }
 }
 
-void DAGTypeLegalizer::SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi,
-                                         bool IsVP) {
+void DAGTypeLegalizer::SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi) {
   SDValue LHSLo, LHSHi;
   GetSplitVector(N->getOperand(0), LHSLo, LHSHi);
   SDValue RHSLo, RHSHi;
@@ -1143,12 +1122,15 @@
   const SDNodeFlags Flags = N->getFlags();
   unsigned Opcode = N->getOpcode();
 
-  if (!IsVP) {
+  if (N->getNumOperands() == 2) {
     Lo = DAG.getNode(Opcode, dl, LHSLo.getValueType(), LHSLo, RHSLo, Flags);
     Hi = DAG.getNode(Opcode, dl, LHSHi.getValueType(), LHSHi, RHSHi, Flags);
     return;
   }
 
+  assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
+  assert(N->isVPOpcode() && "Expected VP opcode");
+
   // Split the mask.
   SDValue MaskLo, MaskHi;
   SDValue Mask = N->getOperand(2);
@@ -3064,17 +3046,17 @@
     Res = WidenVecRes_MGATHER(cast<MaskedGatherSDNode>(N));
     break;
 
-  case ISD::ADD:
-  case ISD::AND:
-  case ISD::MUL:
+  case ISD::ADD: case ISD::VP_ADD:
+  case ISD::AND: case ISD::VP_AND:
+  case ISD::MUL: case ISD::VP_MUL:
   case ISD::MULHS:
   case ISD::MULHU:
-  case ISD::OR:
-  case ISD::SUB:
-  case ISD::XOR:
-  case ISD::SHL:
-  case ISD::SRA:
-  case ISD::SRL:
+  case ISD::OR: case ISD::VP_OR:
+  case ISD::SUB: case ISD::VP_SUB:
+  case ISD::XOR: case ISD::VP_XOR:
+  case ISD::SHL: case ISD::VP_SHL:
+  case ISD::SRA: case ISD::VP_ASHR:
+  case ISD::SRL: case ISD::VP_LSHR:
   case ISD::FMINNUM:
   case ISD::FMAXNUM:
   case ISD::FMINIMUM:
@@ -3091,7 +3073,21 @@
   case ISD::USHLSAT:
   case ISD::ROTL:
   case ISD::ROTR:
-    Res = WidenVecRes_Binary(N, /*IsVP*/ false);
+  // Vector-predicated binary op widening. Note that -- unlike the
+  // unpredicated versions -- we don't have to worry about trapping on
+  // operations like UDIV, FADD, etc., as we pass on the original vector
+  // length parameter. This means the widened elements containing garbage
+  // aren't active.
+  case ISD::VP_SDIV:
+  case ISD::VP_UDIV:
+  case ISD::VP_SREM:
+  case ISD::VP_UREM:
+  case ISD::VP_FADD:
+  case ISD::VP_FSUB:
+  case ISD::VP_FMUL:
+  case ISD::VP_FDIV:
+  case ISD::VP_FREM:
+    Res = WidenVecRes_Binary(N);
     break;
 
   case ISD::FADD:
@@ -3215,31 +3211,6 @@
   case ISD::FSHR:
     Res = WidenVecRes_Ternary(N);
     break;
-  case ISD::VP_ADD:
-  case ISD::VP_AND:
-  case ISD::VP_MUL:
-  case ISD::VP_OR:
-  case ISD::VP_SUB:
-  case ISD::VP_XOR:
-  case ISD::VP_SHL:
-  case ISD::VP_LSHR:
-  case ISD::VP_ASHR:
-  case ISD::VP_SDIV:
-  case ISD::VP_UDIV:
-  case ISD::VP_SREM:
-  case ISD::VP_UREM:
-  case ISD::VP_FADD:
-  case ISD::VP_FSUB:
-  case ISD::VP_FMUL:
-  case ISD::VP_FDIV:
-  case ISD::VP_FREM:
-    // Vector-predicated binary op widening. Note that -- unlike the
-    // unpredicated versions -- we don't have to worry about trapping on
-    // operations like UDIV, FADD, etc., as we pass on the original vector
-    // length parameter. This means the widened elements containing garbage
-    // aren't active.
-    Res = WidenVecRes_Binary(N, /*IsVP*/ true);
-    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -3257,15 +3228,19 @@
   return DAG.getNode(N->getOpcode(), dl, WidenVT, InOp1, InOp2, InOp3);
 }
 
-SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N, bool IsVP) {
+SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N) {
   // Binary op widening.
   SDLoc dl(N);
   EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue InOp1 = GetWidenedVector(N->getOperand(0));
   SDValue InOp2 = GetWidenedVector(N->getOperand(1));
-  if (!IsVP)
+  if (N->getNumOperands() == 2)
     return DAG.getNode(N->getOpcode(), dl, WidenVT, InOp1, InOp2,
                        N->getFlags());
+
+  assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
+  assert(N->isVPOpcode() && "Expected VP opcode");
+
   // For VP operations, we must also widen the mask. Note that the mask type
   // may not actually need widening, leading it be split along with the VP
   // operation.