diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10142,7 +10142,7 @@ if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { SDValue Op = N0.getOperand(0); - Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); + Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); AddToWorklist(Op.getNode()); SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT); // Transfer the debug info; the new node is equivalent to N0. @@ -10154,7 +10154,7 @@ if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); AddToWorklist(Op.getNode()); - SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); + SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT); // We may safely transfer the debug info describing the truncate node over // to the equivalent and operation. DAG.transferDbgValues(N0, And); @@ -10283,7 +10283,7 @@ // zext(setcc) -> zext_in_reg(vsetcc) for vectors. SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0), N0.getOperand(1), N0.getOperand(2)); - return DAG.getZeroExtendInReg(VSetCC, DL, MVT::i1); + return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType()); } // If the desired elements are smaller or larger than the source @@ -10293,8 +10293,8 @@ SDValue VsetCC = DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0), N0.getOperand(1), N0.getOperand(2)); - return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), - DL, MVT::i1); + return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL, + N0.getValueType()); } // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc @@ -10812,7 +10812,7 @@ // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero. 
if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1))) - return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType()); + return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT); // fold operands of sext_in_reg based on knowledge that the top bits are not // demanded. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -933,7 +933,7 @@ Result.getValueType(), Result, DAG.getValueType(SrcVT)); else - ValRes = DAG.getZeroExtendInReg(Result, dl, SrcVT.getScalarType()); + ValRes = DAG.getZeroExtendInReg(Result, dl, SrcVT); Value = ValRes; Chain = Result.getValue(1); break; @@ -3531,8 +3531,9 @@ SDValue Overflow = DAG.getSetCC(dl, SetCCType, Sum, LHS, CC); // Add of the sum and the carry. + SDValue One = DAG.getConstant(1, dl, VT); SDValue CarryExt = - DAG.getZeroExtendInReg(DAG.getZExtOrTrunc(Carry, dl, VT), dl, MVT::i1); + DAG.getNode(ISD::AND, dl, VT, DAG.getZExtOrTrunc(Carry, dl, VT), One); SDValue Sum2 = DAG.getNode(Op, dl, VT, Sum, CarryExt); // Second check for overflow. If we are adding, we can only overflow if the diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -615,8 +615,7 @@ return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, NVT, Res, DAG.getValueType(N->getOperand(0).getValueType())); if (N->getOpcode() == ISD::ZERO_EXTEND) - return DAG.getZeroExtendInReg(Res, dl, - N->getOperand(0).getValueType().getScalarType()); + return DAG.getZeroExtendInReg(Res, dl, N->getOperand(0).getValueType()); assert(N->getOpcode() == ISD::ANY_EXTEND && "Unknown integer extension!"); return Res; } @@ -1169,7 +1168,7 @@ // Calculate the overflow flag: zero extend the arithmetic result from // the original type. 
- SDValue Ofl = DAG.getZeroExtendInReg(Res, dl, OVT.getScalarType()); + SDValue Ofl = DAG.getZeroExtendInReg(Res, dl, OVT); // Overflowed if and only if this is not equal to Res. Ofl = DAG.getSetCC(dl, N->getValueType(1), Ofl, Res, ISD::SETNE); @@ -1784,8 +1783,7 @@ SDLoc dl(N); SDValue Op = GetPromotedInteger(N->getOperand(0)); Op = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Op); - return DAG.getZeroExtendInReg(Op, dl, - N->getOperand(0).getValueType().getScalarType()); + return DAG.getZeroExtendInReg(Op, dl, N->getOperand(0).getValueType()); } SDValue DAGTypeLegalizer::PromoteIntOp_ADDSUBCARRY(SDNode *N, unsigned OpNo) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -265,7 +265,7 @@ EVT OldVT = Op.getValueType(); SDLoc dl(Op); Op = GetPromotedInteger(Op); - return DAG.getZeroExtendInReg(Op, dl, OldVT.getScalarType()); + return DAG.getZeroExtendInReg(Op, dl, OldVT); } // Get a promoted operand and sign or zero extend it to the final size @@ -279,7 +279,7 @@ if (TLI.isSExtCheaperThanZExt(OldVT, Op.getValueType())) return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Op.getValueType(), Op, DAG.getValueType(OldVT)); - return DAG.getZeroExtendInReg(Op, DL, OldVT.getScalarType()); + return DAG.getZeroExtendInReg(Op, DL, OldVT); } // Integer Result Promotion. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1167,15 +1167,21 @@ } SDValue SelectionDAG::getZeroExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) { - assert(!VT.isVector() && - "getZeroExtendInReg should use the vector element type instead of " - "the vector type!"); - if (Op.getValueType().getScalarType() == VT) return Op; - unsigned BitWidth = Op.getScalarValueSizeInBits(); - APInt Imm = APInt::getLowBitsSet(BitWidth, - VT.getSizeInBits()); - return getNode(ISD::AND, DL, Op.getValueType(), Op, - getConstant(Imm, DL, Op.getValueType())); + EVT OpVT = Op.getValueType(); + assert(VT.isInteger() && OpVT.isInteger() && + "Cannot getZeroExtendInReg FP types"); + assert(VT.isVector() == OpVT.isVector() && + "getZeroExtendInReg type should be vector iff the operand " + "type is vector!"); + assert((!VT.isVector() || + VT.getVectorElementCount() == OpVT.getVectorElementCount()) && + "Vector element counts must match in getZeroExtendInReg"); + assert(VT.bitsLE(OpVT) && "Not extending!"); + if (OpVT == VT) + return Op; + APInt Imm = APInt::getLowBitsSet(OpVT.getScalarSizeInBits(), + VT.getScalarSizeInBits()); + return getNode(ISD::AND, DL, OpVT, Op, getConstant(Imm, DL, OpVT)); } SDValue SelectionDAG::getPtrExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) { diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1727,8 +1727,7 @@ // If the input sign bit is known zero, convert this into a zero extension.
if (Known.Zero[ExVTBits - 1]) - return TLO.CombineTo( - Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT.getScalarType())); + return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT)); APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits); if (Known.One[ExVTBits - 1]) { // Input sign bit known set diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -41259,7 +41259,7 @@ case ISD::ANY_EXTEND: return Op; case ISD::ZERO_EXTEND: - return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType()); + return DAG.getZeroExtendInReg(Op, DL, NarrowVT); case ISD::SIGN_EXTEND: return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, Op, DAG.getValueType(NarrowVT)); @@ -44870,7 +44870,7 @@ SDValue Res = DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC); if (N->getOpcode() == ISD::ZERO_EXTEND) - Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType().getScalarType()); + Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType()); return Res; }