Index: include/llvm/Target/TargetLowering.h
===================================================================
--- include/llvm/Target/TargetLowering.h
+++ include/llvm/Target/TargetLowering.h
@@ -3036,9 +3036,38 @@
   // Legalization utility functions
   //
 
+  /// Expand a MUL or [US]MUL_LOHI of n-bit values into two or four nodes,
+  /// respectively, each computing an n/2-bit part of the result.
+  /// \param Opcode The operation to expand: ISD::MUL, ISD::UMUL_LOHI or
+  ///        ISD::SMUL_LOHI.
+  /// \param VT The n-bit value type of the operands.
+  /// \param dl The debug location to use for the new nodes.
+  /// \param LHS The left-hand operand of the multiplication.
+  /// \param RHS The right-hand operand of the multiplication.
+  /// \param Result A vector that will be filled with the parts of the result
+  ///        in little-endian order.
+  /// \param HiLoVT The n/2-bit value type to use for the result nodes.
+  /// \param DAG The SelectionDAG in which the new nodes are created.
+  /// \param OnlyLegalOrCustom If true, only use half-width multiply
+  ///        operations (MULH[SU], [SU]MUL_LOHI) that are legal or custom for
+  ///        HiLoVT; otherwise use them regardless of their legality.
+  /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
+  ///        if you want to control how low bits are extracted from the LHS.
+  /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
+  /// \param RL Low bits of the RHS of the MUL.  See LL for meaning.
+  /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
+  /// \returns true if the operation has been expanded, false if it has not.
+  bool expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, SDValue LHS,
+                      SDValue RHS, SmallVectorImpl<SDValue> &Result, EVT HiLoVT,
+                      SelectionDAG &DAG, bool OnlyLegalOrCustom,
+                      SDValue LL = SDValue(), SDValue LH = SDValue(),
+                      SDValue RL = SDValue(), SDValue RH = SDValue()) const;
+
   /// Expand a MUL into two nodes.  One that computes the high bits of
   /// the result and one that computes the low bits.
   /// \param HiLoVT The value type to use for the Lo and Hi nodes.
+  /// \param OnlyLegalOrCustom Only use half-width multiplies that are legal
+  ///        or custom for HiLoVT.
   /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
   ///        if you want to control how low bits are extracted from the LHS.
   /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
@@ -3046,9 +3075,9 @@
   /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
-  /// \returns true if the node has been expanded. false if it has not
+  /// \returns true if the node has been expanded, false if it has not.
   bool expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                 SelectionDAG &DAG, SDValue LL = SDValue(),
-                 SDValue LH = SDValue(), SDValue RL = SDValue(),
-                 SDValue RH = SDValue()) const;
+                 SelectionDAG &DAG, bool OnlyLegalOrCustom,
+                 SDValue LL = SDValue(), SDValue LH = SDValue(),
+                 SDValue RL = SDValue(), SDValue RH = SDValue()) const;
 
   /// Expand float(f32) to SINT(i64) conversion
   /// \param N Node to expand
Index: lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3312,17 +3312,53 @@
   }
   case ISD::MULHU:
   case ISD::MULHS: {
-    unsigned ExpandOpcode = Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI :
-                                                              ISD::SMUL_LOHI;
+    unsigned ExpandOpcode =
+        Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI : ISD::SMUL_LOHI;
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
-    assert(TLI.isOperationLegalOrCustom(ExpandOpcode, VT) &&
-           "If this wasn't legal, it shouldn't have been created!");
+    // If [US]MUL_LOHI is not legal for VT either, it is legalized in turn.
     Tmp1 = DAG.getNode(ExpandOpcode, dl, VTs, Node->getOperand(0),
                        Node->getOperand(1));
     Results.push_back(Tmp1.getValue(1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    SDValue LHS = Node->getOperand(0);
+    SDValue RHS = Node->getOperand(1);
+    MVT VT = LHS.getSimpleValueType();
+    unsigned MULHOpcode =
+        Node->getOpcode() == ISD::UMUL_LOHI ? ISD::MULHU : ISD::MULHS;
+
+    // Prefer a plain MUL for the low half and MULH[US] for the high half.
+    if (TLI.isOperationLegalOrCustom(MULHOpcode, VT)) {
+      Results.push_back(DAG.getNode(ISD::MUL, dl, VT, LHS, RHS));
+      Results.push_back(DAG.getNode(MULHOpcode, dl, VT, LHS, RHS));
+      break;
+    }
+
+    // Otherwise compute the 2n-bit product as four n/2-bit parts and glue
+    // them back together pairwise into the n-bit low and high results.
+    SmallVector<SDValue, 4> Halves;
+    EVT HalfType = EVT(VT).getHalfSizedIntegerVT(*DAG.getContext());
+    assert(TLI.isTypeLegal(HalfType) &&
+           "Cannot expand [US]MUL_LOHI: half-sized type is not legal");
+    if (!TLI.expandMUL_LOHI(Node->getOpcode(), VT, dl, LHS, RHS, Halves,
+                            HalfType, DAG, false))
+      break;
+
+    // The shift amount type is queried for VT, the type being shifted.
+    SDValue Shift = DAG.getConstant(
+        HalfType.getScalarSizeInBits(), dl,
+        TLI.getShiftAmountTy(VT, DAG.getDataLayout()));
+    for (unsigned i = 0; i < 2; ++i) {
+      SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Halves[2 * i]);
+      SDValue Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Halves[2 * i + 1]);
+      Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
+      Results.push_back(DAG.getNode(ISD::OR, dl, VT, Lo, Hi));
+    }
+    break;
+  }
   case ISD::MUL: {
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
@@ -3357,7 +3393,7 @@
         TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND, VT) &&
         TLI.isOperationLegalOrCustom(ISD::SHL, VT) &&
         TLI.isOperationLegalOrCustom(ISD::OR, VT) &&
-        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG)) {
+        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG, true)) {
       Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
       Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Hi);
       SDValue Shift =
@@ -4185,6 +4221,29 @@
     Results.push_back(DAG.getNode(TruncOp, dl, OVT, Tmp1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    // Promote to a multiply in a wider integer type. The low and high results
+    // are bits [0, n) and [n, 2n) of the product of the extended operands.
+    unsigned ExtOp = Node->getOpcode() == ISD::UMUL_LOHI ? ISD::ZERO_EXTEND
+                                                         : ISD::SIGN_EXTEND;
+    unsigned OriginalSize = OVT.getScalarSizeInBits();
+    // The wide type must hold the full 2n-bit product; then both truncations
+    // below are exact, and a logical shift suffices even for SMUL_LOHI.
+    assert(NVT.getScalarSizeInBits() >= 2 * OriginalSize &&
+           "Promoted type is too narrow to hold the full product");
+    Tmp1 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(0));
+    Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
+    Tmp1 = DAG.getNode(ISD::MUL, dl, NVT, Tmp1, Tmp2);
+
+    auto &DL = DAG.getDataLayout();
+    Tmp2 = DAG.getNode(
+        ISD::SRL, dl, NVT, Tmp1,
+        DAG.getConstant(OriginalSize, dl, TLI.getShiftAmountTy(NVT, DL)));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
+    break;
+  }
   case ISD::SELECT: {
     unsigned ExtOp, TruncOp;
     if (Node->getValueType(0).isVector() ||
Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2189,7 +2189,7 @@
   GetExpandedInteger(N->getOperand(0), LL, LH);
   GetExpandedInteger(N->getOperand(1), RL, RH);
 
-  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, LL, LH, RL, RH))
+  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, true, LL, LH, RL, RH))
     return;
 
   // If nothing else, we can make a libcall.
Index: lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -333,6 +333,8 @@
   case ISD::SMAX:
   case ISD::UMIN:
   case ISD::UMAX:
+  case ISD::SMUL_LOHI:
+  case ISD::UMUL_LOHI:
     QueryType = Node->getValueType(0);
     break;
   case ISD::FP_ROUND_INREG:
Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -3079,24 +3079,30 @@
 // Legalization Utilities
 //===----------------------------------------------------------------------===//
 
-bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                               SelectionDAG &DAG, SDValue LL, SDValue LH,
-                               SDValue RL, SDValue RH) const {
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
-
-  bool HasMULHS = isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
-  bool HasMULHU = isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
-  bool HasSMUL_LOHI = isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
-  bool HasUMUL_LOHI = isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
+bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl,
+                                    SDValue LHS, SDValue RHS,
+                                    SmallVectorImpl<SDValue> &Result,
+                                    EVT HiLoVT, SelectionDAG &DAG,
+                                    bool OnlyLegalOrCustom, SDValue LL,
+                                    SDValue LH, SDValue RL, SDValue RH) const {
+  assert((Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI ||
+          Opcode == ISD::SMUL_LOHI) &&
+         "Unexpected opcode for expandMUL_LOHI");
+
+  bool HasMULHS =
+      !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
+  bool HasMULHU =
+      !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
+  bool HasSMUL_LOHI =
+      !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
+  bool HasUMUL_LOHI =
+      !OnlyLegalOrCustom || isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
 
   if (!HasMULHU && !HasMULHS && !HasUMUL_LOHI && !HasSMUL_LOHI)
     return false;
 
   unsigned OuterBitSize = VT.getScalarSizeInBits();
   unsigned InnerBitSize = HiLoVT.getScalarSizeInBits();
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
   unsigned LHSSB = DAG.ComputeNumSignBits(LHS);
   unsigned RHSSB = DAG.ComputeNumSignBits(RHS);
 
@@ -3120,6 +3126,8 @@
     return false;
   };
 
+  SDValue Lo, Hi;
+
   if (!LL.getNode() && !RL.getNode() &&
       isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
     LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LHS);
@@ -3130,16 +3138,32 @@
     return false;
 
   APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize);
-  if (DAG.MaskedValueIsZero(N->getOperand(0), HighMask) &&
-      DAG.MaskedValueIsZero(N->getOperand(1), HighMask)) {
+  if (!VT.isVector() && DAG.MaskedValueIsZero(LHS, HighMask) &&
+      DAG.MaskedValueIsZero(RHS, HighMask)) {
     // The inputs are both zero-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
+      if (Opcode != ISD::MUL) {
+        // The full product fits in n bits, so its upper n bits are zero.
+        SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+        Result.push_back(Zero);
+        Result.push_back(Zero);
+      }
       return true;
+    }
   }
-  if (LHSSB > InnerBitSize && RHSSB > InnerBitSize) {
+
+  if (!VT.isVector() && Opcode == ISD::MUL && LHSSB > InnerBitSize &&
+      RHSSB > InnerBitSize) {
     // The input values are both sign-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true))
+    // TODO: SMUL_LOHI could be handled here too; the upper half of its
+    // result would be the sign extension of Hi.
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
       return true;
+    }
   }
 
   auto &DL = DAG.getDataLayout();
@@ -3158,15 +3182,94 @@
   if (!LH.getNode())
     return false;
 
-  if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+  if (!MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    return false;
+
+  Result.push_back(Lo);
+
+  if (Opcode == ISD::MUL) {
     RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH);
     LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH);
+    Result.push_back(Hi);
     return true;
   }
 
-  return false;
+  // Compute the full width result. With k = InnerBitSize, write
+  //   LHS = LH * 2^k + LL  and  RHS = RH * 2^k + RL,
+  // so that the full product is
+  //   LL*RL + (LL*RH + LH*RL) * 2^k + LH*RH * 2^(2k).
+  // Next holds the not yet emitted upper bits of this sum; each time a k-bit
+  // part is pushed to Result, Next is shifted right by Shift (= k bits).
+  auto Merge = [&](SDValue LoPart, SDValue HiPart) -> SDValue {
+    LoPart = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LoPart);
+    HiPart = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, HiPart);
+    HiPart = DAG.getNode(ISD::SHL, dl, VT, HiPart, Shift);
+    return DAG.getNode(ISD::OR, dl, VT, LoPart, HiPart);
+  };
+
+  SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
+  if (!MakeMUL_LOHI(LL, RH, Lo, Hi, false))
+    return false;
+
+  // This is effectively the add part of a multiply-add of half-sized operands,
+  // so it cannot overflow.
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false))
+    return false;
+
+  // This addition can overflow. Its carry-out has weight 2^(3k) in the full
+  // product and is folded into the high half of LH*RH below.
+  Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next,
+                     Merge(Lo, Hi));
+
+  SDValue Carry = Next.getValue(1);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+
+  if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI))
+    return false;
+
+  SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+  Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero,
+                   Carry);
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (Opcode == ISD::SMUL_LOHI) {
+    // The cross products LL*RH and LH*RL above treat LH and RH as unsigned.
+    // A negative LH therefore over-counts the product by RL * 2^(2k), i.e. by
+    // RL at Next's current position (likewise RH with LL); subtract it back.
+    SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                                  DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL));
+    Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT);
+
+    NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                          DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL));
+    Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT);
+  }
+
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  return true;
+}
+
+bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
+                               SelectionDAG &DAG, bool OnlyLegalOrCustom,
+                               SDValue LL, SDValue LH, SDValue RL,
+                               SDValue RH) const {
+  SmallVector<SDValue, 2> Result;
+  bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), SDLoc(N),
+                           N->getOperand(0), N->getOperand(1), Result, HiLoVT,
+                           DAG, OnlyLegalOrCustom, LL, LH, RL, RH);
+  if (Ok) {
+    assert(Result.size() == 2 && "MUL expansion must yield exactly two parts");
+    Lo = Result[0];
+    Hi = Result[1];
+  }
+  return Ok;
 }
 
 bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,