diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4377,6 +4377,53 @@
   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
 }
 
+/// Select the widening multiply (AArch64ISD::SMULL or AArch64ISD::UMULL)
+/// that can implement N0 * N1, or return 0 if no widening form applies.
+///
+/// \p IsMLA is set to true only when \p N0 is an add/sub of extends, i.e. for
+/// (s/zext A +/- s/zext B) * (s/zext C), meaning the widening multiply should
+/// be distributed over the add/sub; otherwise it is cleared. \p N0 and \p N1
+/// may be swapped so that the add/sub operand ends up in \p N0.
+static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
+                                 bool &IsMLA) {
+  IsMLA = false;
+
+  bool IsN0SExt = isSignExtended(N0, DAG);
+  bool IsN1SExt = isSignExtended(N1, DAG);
+  if (IsN0SExt && IsN1SExt)
+    return AArch64ISD::SMULL;
+
+  bool IsN0ZExt = isZeroExtended(N0, DAG);
+  bool IsN1ZExt = isZeroExtended(N1, DAG);
+
+  if (IsN0ZExt && IsN1ZExt)
+    return AArch64ISD::UMULL;
+
+  if (!IsN1SExt && !IsN1ZExt)
+    return 0;
+
+  // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
+  // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
+  if (IsN1SExt && isAddSubSExt(N0, DAG)) {
+    IsMLA = true;
+    return AArch64ISD::SMULL;
+  }
+  if (IsN1ZExt && isAddSubZExt(N0, DAG)) {
+    IsMLA = true;
+    return AArch64ISD::UMULL;
+  }
+  // NOTE(review): N1 passed isSignExtended/isZeroExtended to get here, so if
+  // isAddSubZExt only accepts ISD::ADD/ISD::SUB nodes this case never fires.
+  // It is kept to match the LowerMUL logic it was factored out of; TODO
+  // confirm, then drop it or move it above the early return.
+  if (IsN0ZExt && isAddSubZExt(N1, DAG)) {
+    std::swap(N0, N1);
+    IsMLA = true;
+    return AArch64ISD::UMULL;
+  }
+  return 0;
+}
+
 SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
 
@@ -4392,41 +4439,15 @@
          "unexpected type for custom-lowering ISD::MUL");
   SDNode *N0 = Op.getOperand(0).getNode();
   SDNode *N1 = Op.getOperand(1).getNode();
-  unsigned NewOpc = 0;
   bool isMLA = false;
-  bool isN0SExt = isSignExtended(N0, DAG);
-  bool isN1SExt = isSignExtended(N1, DAG);
-  if (isN0SExt && isN1SExt)
-    NewOpc = AArch64ISD::SMULL;
-  else {
-    bool isN0ZExt = isZeroExtended(N0, DAG);
-    bool isN1ZExt = isZeroExtended(N1, DAG);
-    if (isN0ZExt && isN1ZExt)
-      NewOpc = AArch64ISD::UMULL;
-    else if (isN1SExt || isN1ZExt) {
-      // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
-      // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
-      if (isN1SExt && isAddSubSExt(N0, DAG)) {
-        NewOpc = AArch64ISD::SMULL;
-        isMLA = true;
-      } else if (isN1ZExt && isAddSubZExt(N0, DAG)) {
-        NewOpc = AArch64ISD::UMULL;
-        isMLA = true;
-      } else if (isN0ZExt && isAddSubZExt(N1, DAG)) {
-        std::swap(N0, N1);
-        NewOpc = AArch64ISD::UMULL;
-        isMLA = true;
-      }
-    }
+  unsigned NewOpc = selectUmullSmull(N0, N1, DAG, isMLA);
 
-    if (!NewOpc) {
-      if (VT == MVT::v2i64)
-        // Fall through to expand this.  It is not legal.
-        return SDValue();
-      else
-        // Other vector multiplications are legal.
-        return Op;
-    }
+  if (!NewOpc) {
+    if (VT == MVT::v2i64)
+      // Fall through to expand this.  It is not legal.
+      return SDValue();
+    // Other vector multiplications are legal.
+    return Op;
   }
 
   // Legalize to a S/UMULL instruction