Index: include/llvm/Target/TargetLowering.h =================================================================== --- include/llvm/Target/TargetLowering.h +++ include/llvm/Target/TargetLowering.h @@ -3003,9 +3003,28 @@ // Legalization utility functions // + /// Expand a MUL or [US]MUL_LOHI of n-bit values into two or four nodes, + /// respectively, each computing an n/2-bit part of the result. + /// \param Result A vector that will be filled with the parts of the result + /// in little-endian order. + /// \param HalfVT The value type to use for the result nodes. + /// \param OnlyLegalOrCustom Only legal or custom instructions are used. + /// \param LL Low bits of the LHS of the MUL. You can use this parameter + /// if you want to control how low bits are extracted from the LHS. + /// \param LH High bits of the LHS of the MUL. See LL for meaning. + /// \param RL Low bits of the RHS of the MUL. See LL for meaning + /// \param RH High bits of the RHS of the MUL. See LL for meaning. + /// \returns true if the node has been expanded, false if it has not + bool expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, SDValue LHS, + SDValue RHS, SmallVectorImpl &Result, EVT HalfVT, + SelectionDAG &DAG, bool OnlyLegalOrCustom, + SDValue LL = SDValue(), SDValue LH = SDValue(), + SDValue RL = SDValue(), SDValue RH = SDValue()) const; + /// Expand a MUL into two nodes. One that computes the high bits of /// the result and one that computes the low bits. /// \param HiLoVT The value type to use for the Lo and Hi nodes. + /// \param OnlyLegalOrCustom Only legal or custom instructions are used. /// \param LL Low bits of the LHS of the MUL. You can use this parameter /// if you want to control how low bits are extracted from the LHS. /// \param LH High bits of the LHS of the MUL. See LL for meaning. @@ -3013,9 +3032,9 @@ /// \param RH High bits of the RHS of the MUL. See LL for meaning. /// \returns true if the node has been expanded. false if it has not bool expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, - SelectionDAG &DAG, SDValue LL = SDValue(), - SDValue LH = SDValue(), SDValue RL = SDValue(), - SDValue RH = SDValue()) const; + SelectionDAG &DAG, bool OnlyLegalOrCustom, + SDValue LL = SDValue(), SDValue LH = SDValue(), + SDValue RL = SDValue(), SDValue RH = SDValue()) const; /// Expand float(f32) to SINT(i64) conversion /// \param N Node to expand Index: lib/CodeGen/SelectionDAG/LegalizeDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -3312,17 +3312,55 @@ } case ISD::MULHU: case ISD::MULHS: { - unsigned ExpandOpcode = Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI : - ISD::SMUL_LOHI; + unsigned ExpandOpcode = + Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI : ISD::SMUL_LOHI; EVT VT = Node->getValueType(0); SDVTList VTs = DAG.getVTList(VT, VT); - assert(TLI.isOperationLegalOrCustom(ExpandOpcode, VT) && - "If this wasn't legal, it shouldn't have been created!"); + Tmp1 = DAG.getNode(ExpandOpcode, dl, VTs, Node->getOperand(0), Node->getOperand(1)); Results.push_back(Tmp1.getValue(1)); break; } + case ISD::UMUL_LOHI: + case ISD::SMUL_LOHI: { + SDValue LHS = Node->getOperand(0); + SDValue RHS = Node->getOperand(1); + MVT VT = LHS.getSimpleValueType(); + unsigned MULHOpcode = + Node->getOpcode() == ISD::UMUL_LOHI ? ISD::MULHU : ISD::MULHS; + + if (TLI.isOperationLegalOrCustom(MULHOpcode, VT)) { + Results.push_back(DAG.getNode(ISD::MUL, dl, VT, LHS, RHS)); + Results.push_back(DAG.getNode(MULHOpcode, dl, VT, LHS, RHS)); + break; + } + + SmallVector Halves; + EVT HalfType; + if (!VT.isVector()) { + HalfType = EVT(VT).getHalfSizedIntegerVT(*DAG.getContext()); + } else { + HalfType = EVT::getVectorVT( + *DAG.getContext(), + EVT(VT.getScalarType()).getHalfSizedIntegerVT(*DAG.getContext()), + VT.getVectorNumElements()); + } + if (TLI.expandMUL_LOHI(Node->getOpcode(), VT, Node, LHS, RHS, Halves, + HalfType, DAG, false)) { + for (unsigned i = 0; i < 2; ++i) { + SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Halves[2 * i]); + SDValue Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Halves[2 * i + 1]); + SDValue Shift = DAG.getConstant( + HalfType.getScalarSizeInBits(), dl, + TLI.getShiftAmountTy(HalfType, DAG.getDataLayout())); + Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift); + Results.push_back(DAG.getNode(ISD::OR, dl, VT, Lo, Hi)); + } + break; + } + break; + } case ISD::MUL: { EVT VT = Node->getValueType(0); SDVTList VTs = DAG.getVTList(VT, VT); @@ -3357,7 +3395,7 @@ TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND, VT) && TLI.isOperationLegalOrCustom(ISD::SHL, VT) && TLI.isOperationLegalOrCustom(ISD::OR, VT) && - TLI.expandMUL(Node, Lo, Hi, HalfType, DAG)) { + TLI.expandMUL(Node, Lo, Hi, HalfType, DAG, true)) { Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo); Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Hi); SDValue Shift = @@ -4168,6 +4206,24 @@ Results.push_back(DAG.getNode(TruncOp, dl, OVT, Tmp1)); break; } + case ISD::UMUL_LOHI: + case ISD::SMUL_LOHI: { + // Promote to a multiply in a wider integer type. + unsigned ExtOp = Node->getOpcode() == ISD::UMUL_LOHI ? ISD::ZERO_EXTEND + : ISD::SIGN_EXTEND; + Tmp1 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(0)); + Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1)); + Tmp1 = DAG.getNode(ISD::MUL, dl, NVT, Tmp1, Tmp2); + + auto &DL = DAG.getDataLayout(); + unsigned OriginalSize = OVT.getScalarSizeInBits(); + Tmp2 = DAG.getNode( + ISD::SRL, dl, NVT, Tmp1, + DAG.getConstant(OriginalSize, dl, TLI.getScalarShiftAmountTy(DL, NVT))); + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2)); + break; + } case ISD::SELECT: { unsigned ExtOp, TruncOp; if (Node->getValueType(0).isVector() || Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -2189,7 +2189,7 @@ GetExpandedInteger(N->getOperand(0), LL, LH); GetExpandedInteger(N->getOperand(1), RL, RH); - if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, LL, LH, RL, RH)) + if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, true, LL, LH, RL, RH)) return; // If nothing else, we can make a libcall. Index: lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -332,6 +332,8 @@ case ISD::SMAX: case ISD::UMIN: case ISD::UMAX: + case ISD::SMUL_LOHI: + case ISD::UMUL_LOHI: QueryType = Node->getValueType(0); break; case ISD::FP_ROUND_INREG: Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -3068,108 +3068,197 @@ // Legalization Utilities //===----------------------------------------------------------------------===// -bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, - SelectionDAG &DAG, SDValue LL, SDValue LH, - SDValue RL, SDValue RH) const { - EVT VT = N->getValueType(0); - SDLoc dl(N); +bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, + SDValue LHS, SDValue RHS, + SmallVectorImpl &Result, + EVT HalfVT, SelectionDAG &DAG, + bool OnlyLegalOrCustom, SDValue LL, + SDValue LH, SDValue RL, SDValue RH) const { + assert(Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI || + Opcode == ISD::SMUL_LOHI); + + bool HasMULHS = + !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::MULHS, HalfVT); + bool HasMULHU = + !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::MULHU, HalfVT); + bool HasSMUL_LOHI = + !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::SMUL_LOHI, HalfVT); + bool HasUMUL_LOHI = + !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::UMUL_LOHI, HalfVT); + unsigned OuterBitSize = VT.getScalarSizeInBits(); + unsigned InnerBitSize = HalfVT.getScalarSizeInBits(); + unsigned LHSSB = DAG.ComputeNumSignBits(LHS); + unsigned RHSSB = DAG.ComputeNumSignBits(RHS); + + // LL, LH, RL, and RH must be either all NULL or all set to a value. + assert((LL.getNode() && LH.getNode() && RL.getNode() && RH.getNode()) || + (!LL.getNode() && !LH.getNode() && !RL.getNode() && !RH.getNode())); + + if (!HasMULHS && !HasMULHU && !HasSMUL_LOHI && !HasUMUL_LOHI) + return false; - bool HasMULHS = isOperationLegalOrCustom(ISD::MULHS, HiLoVT); - bool HasMULHU = isOperationLegalOrCustom(ISD::MULHU, HiLoVT); - bool HasSMUL_LOHI = isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT); - bool HasUMUL_LOHI = isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT); - if (HasMULHU || HasMULHS || HasUMUL_LOHI || HasSMUL_LOHI) { - unsigned OuterBitSize = VT.getSizeInBits(); - unsigned InnerBitSize = HiLoVT.getSizeInBits(); - unsigned LHSSB = DAG.ComputeNumSignBits(N->getOperand(0)); - unsigned RHSSB = DAG.ComputeNumSignBits(N->getOperand(1)); - - // LL, LH, RL, and RH must be either all NULL or all set to a value. - assert((LL.getNode() && LH.getNode() && RL.getNode() && RH.getNode()) || - (!LL.getNode() && !LH.getNode() && !RL.getNode() && !RH.getNode())); - - if (!LL.getNode() && !RL.getNode() && - isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) { - LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, N->getOperand(0)); - RL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, N->getOperand(1)); + SDVTList VTs = DAG.getVTList(HalfVT, HalfVT); + auto MakeUMUL_LOHI = [&](SDValue L, SDValue R, SDValue &Lo, + SDValue &Hi) -> bool { + if (HasUMUL_LOHI) { + Lo = DAG.getNode(ISD::UMUL_LOHI, dl, VTs, L, R); + Hi = SDValue(Lo.getNode(), 1); + return true; } - - if (!LL.getNode()) - return false; - - APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize); - if (DAG.MaskedValueIsZero(N->getOperand(0), HighMask) && - DAG.MaskedValueIsZero(N->getOperand(1), HighMask)) { - // The inputs are both zero-extended. - if (HasUMUL_LOHI) { - // We can emit a umul_lohi. - Lo = DAG.getNode(ISD::UMUL_LOHI, dl, DAG.getVTList(HiLoVT, HiLoVT), LL, - RL); - Hi = SDValue(Lo.getNode(), 1); - return true; - } - if (HasMULHU) { - // We can emit a mulhu+mul. - Lo = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RL); - Hi = DAG.getNode(ISD::MULHU, dl, HiLoVT, LL, RL); - return true; - } + if (HasMULHU) { + Lo = DAG.getNode(ISD::MUL, dl, HalfVT, L, R); + Hi = DAG.getNode(ISD::MULHU, dl, HalfVT, L, R); + return true; } - if (LHSSB > InnerBitSize && RHSSB > InnerBitSize) { - // The input values are both sign-extended. - if (HasSMUL_LOHI) { - // We can emit a smul_lohi. - Lo = DAG.getNode(ISD::SMUL_LOHI, dl, DAG.getVTList(HiLoVT, HiLoVT), LL, - RL); - Hi = SDValue(Lo.getNode(), 1); - return true; - } - if (HasMULHS) { - // We can emit a mulhs+mul. - Lo = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RL); - Hi = DAG.getNode(ISD::MULHS, dl, HiLoVT, LL, RL); - return true; - } + return false; + }; + auto MakeSMUL_LOHI = [&](SDValue L, SDValue R, SDValue &Lo, + SDValue &Hi) -> bool { + if (HasSMUL_LOHI) { + Lo = DAG.getNode(ISD::SMUL_LOHI, dl, VTs, L, R); + Hi = SDValue(Lo.getNode(), 1); + return true; } - - if (!LH.getNode() && !RH.getNode() && - isOperationLegalOrCustom(ISD::SRL, VT) && - isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) { - auto &DL = DAG.getDataLayout(); - unsigned ShiftAmt = VT.getSizeInBits() - HiLoVT.getSizeInBits(); - SDValue Shift = DAG.getConstant(ShiftAmt, dl, getShiftAmountTy(VT, DL)); - LH = DAG.getNode(ISD::SRL, dl, VT, N->getOperand(0), Shift); - LH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LH); - RH = DAG.getNode(ISD::SRL, dl, VT, N->getOperand(1), Shift); - RH = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, RH); + if (HasMULHS) { + Lo = DAG.getNode(ISD::MUL, dl, HalfVT, L, R); + Hi = DAG.getNode(ISD::MULHS, dl, HalfVT, L, R); + return true; } + return false; + }; - if (!LH.getNode()) - return false; + SDValue Lo, Hi; - if (HasUMUL_LOHI) { - // Lo,Hi = umul LHS, RHS. - SDValue UMulLOHI = DAG.getNode(ISD::UMUL_LOHI, dl, - DAG.getVTList(HiLoVT, HiLoVT), LL, RL); - Lo = UMulLOHI; - Hi = UMulLOHI.getValue(1); - RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH); - LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL); - Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH); - Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH); + if (!LL.getNode() && !RL.getNode() && + isOperationLegalOrCustom(ISD::TRUNCATE, HalfVT)) { + LL = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, LHS); + RL = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, RHS); + } + + if (!LL.getNode()) + return false; + + APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize); + if (!VT.isVector() && DAG.MaskedValueIsZero(LHS, HighMask) && + DAG.MaskedValueIsZero(RHS, HighMask)) { + // The inputs are both zero-extended. + if (MakeUMUL_LOHI(LL, RL, Lo, Hi)) { + Result.push_back(Lo); + Result.push_back(Hi); + if (Opcode != ISD::MUL) { + SDValue Zero = DAG.getConstant(0, dl, HalfVT); + Result.push_back(Zero); + Result.push_back(Zero); + } return true; } - if (HasMULHU) { - Lo = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RL); - Hi = DAG.getNode(ISD::MULHU, dl, HiLoVT, LL, RL); - RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH); - LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL); - Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH); - Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH); + } + + if (!VT.isVector() && Opcode == ISD::MUL && LHSSB > InnerBitSize && + RHSSB > InnerBitSize) { + // The input values are both sign-extended. + // TODO non-MUL case? + if (MakeSMUL_LOHI(LL, RL, Lo, Hi)) { + Result.push_back(Lo); + Result.push_back(Hi); return true; } } - return false; + + auto &DL = DAG.getDataLayout(); + SDValue Shift = DAG.getConstant(OuterBitSize - InnerBitSize, dl, + getShiftAmountTy(VT, DL)); + + if (!LH.getNode() && !RH.getNode() && + isOperationLegalOrCustom(ISD::SRL, VT) && + isOperationLegalOrCustom(ISD::TRUNCATE, HalfVT)) { + LH = DAG.getNode(ISD::SRL, dl, VT, LHS, Shift); + LH = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, LH); + RH = DAG.getNode(ISD::SRL, dl, VT, RHS, Shift); + RH = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, RH); + } + + if (!LH.getNode()) + return false; + + if (!MakeUMUL_LOHI(LL, RL, Lo, Hi)) + return false; + + Result.push_back(Lo); + + if (Opcode == ISD::MUL) { + RH = DAG.getNode(ISD::MUL, dl, HalfVT, LL, RH); + LH = DAG.getNode(ISD::MUL, dl, HalfVT, LH, RL); + Hi = DAG.getNode(ISD::ADD, dl, HalfVT, Hi, RH); + Hi = DAG.getNode(ISD::ADD, dl, HalfVT, Hi, LH); + Result.push_back(Hi); + return true; + } + + // Compute the full width result. + auto Merge = [&](SDValue Lo, SDValue Hi) -> SDValue { + Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo); + Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi); + Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift); + return DAG.getNode(ISD::OR, dl, VT, Lo, Hi); + }; + + SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi); + if (!MakeUMUL_LOHI(LL, RH, Lo, Hi)) + return false; + + // This is effectively the add part of a multiply-add of half-sized operands, + // so it cannot overflow. + Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi)); + + if (!MakeUMUL_LOHI(LH, RL, Lo, Hi)) + return false; + + Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next, + Merge(Lo, Hi)); + + SDValue Carry = Next.getValue(1); + Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Next)); + Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift); + + if (!(Opcode == ISD::UMUL_LOHI ? MakeUMUL_LOHI(LH, RH, Lo, Hi) + : MakeSMUL_LOHI(LH, RH, Lo, Hi))) + return false; + + SDValue Zero = DAG.getConstant(0, dl, HalfVT); + Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HalfVT, MVT::Glue), Hi, Zero, + Carry); + Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi)); + + if (Opcode == ISD::SMUL_LOHI) { + SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next, + DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL)); + Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT); + + NextSub = DAG.getNode(ISD::SUB, dl, VT, Next, + DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL)); + Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT); + } + + Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Next)); + Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift); + Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Next)); + return true; +} + +bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, + SelectionDAG &DAG, bool OnlyLegalOrCustom, + SDValue LL, SDValue LH, SDValue RL, + SDValue RH) const { + SmallVector Result; + bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), N, + N->getOperand(0), N->getOperand(1), Result, HiLoVT, + DAG, OnlyLegalOrCustom, LL, LH, RL, RH); + if (Result.size() >= 2) { + Lo = Result[0]; + Hi = Result[1]; + } + return Ok; } bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,