diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4399,11 +4399,11 @@
   return C->getZExtValue();
 }
 
-static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
+static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
                                    bool isSigned) {
-  EVT VT = N->getValueType(0);
+  EVT VT = N.getValueType();
 
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
+  if (N.getOpcode() != ISD::BUILD_VECTOR)
     return false;
 
   for (const SDValue &Elt : N->op_values()) {
@@ -4425,58 +4425,65 @@
   return true;
 }
 
-static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
-  if (ISD::isExtOpcode(N->getOpcode()))
-    return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG,
-                                             N->getOperand(0)->getValueType(0),
-                                             N->getValueType(0),
-                                             N->getOpcode());
+static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
+  unsigned Opc = N.getOpcode();
+  EVT VT = N.getValueType();
+  unsigned NumElts = VT.getVectorNumElements();
+  unsigned OrigEltSize = VT.getScalarSizeInBits();
+  unsigned EltSize = OrigEltSize / 2;
+  MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
 
-  assert(N->getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
-  EVT VT = N->getValueType(0);
+  if (VT.is128BitVector()) {
+    APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
+    if (DAG.MaskedValueIsZero(N, HiBits))
+      return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
+  }
+
+  if (ISD::isExtOpcode(Opc))
+    return addRequiredExtensionForVectorMULL(
+        N.getOperand(0), DAG, N.getOperand(0).getValueType(), VT, Opc);
+
+  assert(Opc == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
   SDLoc dl(N);
-  unsigned EltSize = VT.getScalarSizeInBits() / 2;
-  unsigned NumElts = VT.getVectorNumElements();
-  MVT TruncVT = MVT::getIntegerVT(EltSize);
   SmallVector<SDValue, 8> Ops;
   for (unsigned i = 0; i != NumElts; ++i) {
-    ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(i));
+    ConstantSDNode *C = cast<ConstantSDNode>(N.getOperand(i));
     const APInt &CInt = C->getAPIntValue();
     // Element types smaller than 32 bits are not legal, so use i32 elements.
     // The values are implicitly truncated so sext vs. zext doesn't matter.
     Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
   }
-  return DAG.getBuildVector(MVT::getVectorVT(TruncVT, NumElts), dl, Ops);
+  return DAG.getBuildVector(TruncVT, dl, Ops);
 }
 
-static bool isSignExtended(SDNode *N, SelectionDAG &DAG) {
-  return N->getOpcode() == ISD::SIGN_EXTEND ||
-         N->getOpcode() == ISD::ANY_EXTEND ||
+static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
+  return N.getOpcode() == ISD::SIGN_EXTEND ||
+         N.getOpcode() == ISD::ANY_EXTEND ||
          isExtendedBUILD_VECTOR(N, DAG, true);
 }
 
-static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) {
-  return N->getOpcode() == ISD::ZERO_EXTEND ||
-         N->getOpcode() == ISD::ANY_EXTEND ||
+static bool isZeroExtended(SDValue N, SelectionDAG &DAG) {
+  return N.getOpcode() == ISD::ZERO_EXTEND ||
+         N.getOpcode() == ISD::ANY_EXTEND ||
          isExtendedBUILD_VECTOR(N, DAG, false);
 }
 
-static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) {
+static bool isAddSubSExt(SDValue N, SelectionDAG &DAG) {
   unsigned Opcode = N->getOpcode();
   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
-    SDNode *N0 = N->getOperand(0).getNode();
-    SDNode *N1 = N->getOperand(1).getNode();
+    SDValue N0 = N.getOperand(0);
+    SDValue N1 = N.getOperand(1);
     return N0->hasOneUse() && N1->hasOneUse() &&
            isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
   }
   return false;
 }
 
-static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) {
+static bool isAddSubZExt(SDValue N, SelectionDAG &DAG) {
   unsigned Opcode = N->getOpcode();
   if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
-    SDNode *N0 = N->getOperand(0).getNode();
-    SDNode *N1 = N->getOperand(1).getNode();
+    SDValue N0 = N.getOperand(0);
+    SDValue N1 = N.getOperand(1);
     return N0->hasOneUse() && N1->hasOneUse() &&
            isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
   }
@@ -4550,7 +4557,7 @@
   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
 }
 
-static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
+static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
                                  SDLoc DL, bool &IsMLA) {
   bool IsN0SExt = isSignExtended(N0, DAG);
   bool IsN1SExt = isSignExtended(N1, DAG);
@@ -4569,12 +4576,12 @@
       !isExtendedBUILD_VECTOR(N1, DAG, false)) {
     SDValue ZextOperand;
     if (IsN0ZExt)
-      ZextOperand = N0->getOperand(0);
+      ZextOperand = N0.getOperand(0);
     else
-      ZextOperand = N1->getOperand(0);
+      ZextOperand = N1.getOperand(0);
     if (DAG.SignBitIsZero(ZextOperand)) {
-      SDNode *NewSext =
-          DAG.getSExtOrTrunc(ZextOperand, DL, N0->getValueType(0)).getNode();
+      SDValue NewSext =
+          DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
       if (IsN0ZExt)
         N0 = NewSext;
       else
@@ -4588,31 +4595,8 @@
     EVT VT = N0->getValueType(0);
     APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
                                        VT.getScalarSizeInBits() / 2);
-    if (DAG.MaskedValueIsZero(SDValue(IsN0ZExt ? N1 : N0, 0), Mask)) {
-      EVT HalfVT;
-      switch (VT.getSimpleVT().SimpleTy) {
-      case MVT::v2i64:
-        HalfVT = MVT::v2i32;
-        break;
-      case MVT::v4i32:
-        HalfVT = MVT::v4i16;
-        break;
-      case MVT::v8i16:
-        HalfVT = MVT::v8i8;
-        break;
-      default:
-        return 0;
-      }
-      // Truncate and then extend the result.
-      SDValue NewExt = DAG.getNode(ISD::TRUNCATE, DL, HalfVT,
-                                   SDValue(IsN0ZExt ? N1 : N0, 0));
-      NewExt = DAG.getZExtOrTrunc(NewExt, DL, VT);
-      if (IsN0ZExt)
-        N1 = NewExt.getNode();
-      else
-        N0 = NewExt.getNode();
+    if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
       return AArch64ISD::UMULL;
-    }
   }
 
   if (!IsN1SExt && !IsN1ZExt)
@@ -4647,18 +4631,18 @@
   // that VMULL can be detected. Otherwise v2i64 multiplications are not legal.
   assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
          "unexpected type for custom-lowering ISD::MUL");
-  SDNode *N0 = Op.getOperand(0).getNode();
-  SDNode *N1 = Op.getOperand(1).getNode();
+  SDValue N0 = Op.getOperand(0);
+  SDValue N1 = Op.getOperand(1);
   bool isMLA = false;
   EVT OVT = VT;
   if (VT.is64BitVector()) {
-    if (N0->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        isNullConstant(N0->getOperand(1)) &&
-        N1->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        isNullConstant(N1->getOperand(1))) {
-      N0 = N0->getOperand(0).getNode();
-      N1 = N1->getOperand(0).getNode();
-      VT = N0->getValueType(0);
+    if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N0.getOperand(1)) &&
+        N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N1.getOperand(1))) {
+      N0 = N0.getOperand(0);
+      N1 = N1.getOperand(0);
+      VT = N0.getValueType();
     } else {
       if (VT == MVT::v1i64) {
        if (Subtarget->hasSVE())
@@ -4702,12 +4686,12 @@
   // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during
   // isel lowering to take advantage of no-stall back to back s/umul + s/umla.
   // This is true for CPUs with accumulate forwarding such as Cortex-A53/A57
-  SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG);
-  SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG);
+  SDValue N00 = skipExtensionForVectorMULL(N0.getOperand(0), DAG);
+  SDValue N01 = skipExtensionForVectorMULL(N0.getOperand(1), DAG);
   EVT Op1VT = Op1.getValueType();
   return DAG.getNode(
       ISD::EXTRACT_SUBVECTOR, DL, OVT,
-      DAG.getNode(N0->getOpcode(), DL, VT,
+      DAG.getNode(N0.getOpcode(), DL, VT,
                   DAG.getNode(NewOpc, DL, VT,
                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
                   DAG.getNode(NewOpc, DL, VT,
@@ -16476,8 +16460,8 @@
   if (TrailingZeroes) {
     // Conservatively do not lower to shift+add+shift if the mul might be
     // folded into smul or umul.
-    if (N0->hasOneUse() && (isSignExtended(N0.getNode(), DAG) ||
-                            isZeroExtended(N0.getNode(), DAG)))
+    if (N0->hasOneUse() && (isSignExtended(N0, DAG) ||
+                            isZeroExtended(N0, DAG)))
       return SDValue();
     // Conservatively do not lower to shift+add+shift if the mul might be
     // folded into madd or msub.
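For context only (not part of the patch): a hand-written LLVM IR sketch of the shape the UMULL path in selectUmullSmull and the new MaskedValueIsZero check in skipExtensionForVectorMULL are aimed at, where one multiplicand is zero-extended and the other has its high half known to be zero. Function and value names below are illustrative, not taken from the LLVM test suite.

define <8 x i16> @umull_high_half_zero(<8 x i8> %a, <8 x i16> %b) {
  ; %zb keeps only the low 8 bits of each i16 lane, so the high half of every
  ; lane is known zero; together with the zext of %a the multiply can be
  ; selected as a single UMULL after truncating %zb back to v8i8.
  %za = zext <8 x i8> %a to <8 x i16>
  %zb = and <8 x i16> %b, <i16 255, i16 255, i16 255, i16 255,
                           i16 255, i16 255, i16 255, i16 255>
  %mul = mul <8 x i16> %za, %zb
  ret <8 x i16> %mul
}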