diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4368,6 +4368,71 @@
   return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
 }
 
+/// Select a widening multiply (SMULL or UMULL) for the vector MUL N0 * N1.
+///
+/// Returns AArch64ISD::SMULL or AArch64ISD::UMULL, or 0 if neither applies.
+/// N0 and N1 may be updated in place: a zero-extended operand can be replaced
+/// by an equivalent sign extension, or the operands can be swapped so that N0
+/// is the add/sub of extends. isMLA is set to true when the caller should
+/// distribute the multiply over the add/sub in N0.
+static unsigned selectUmullSmull(SDNode *&N0, SDNode *&N1, SelectionDAG &DAG,
+                                 const SDLoc &DL, bool &isMLA) {
+  bool isN0SExt = isSignExtended(N0, DAG);
+  bool isN1SExt = isSignExtended(N1, DAG);
+  if (isN0SExt && isN1SExt)
+    return AArch64ISD::SMULL;
+
+  bool isN0ZExt = isZeroExtended(N0, DAG);
+  bool isN1ZExt = isZeroExtended(N1, DAG);
+  if (isN0ZExt && isN1ZExt)
+    return AArch64ISD::UMULL;
+
+  // Select SMULL if one operand is sign-extended and the other is
+  // zero-extended from a value whose sign bit is known to be zero: that zext
+  // can be rewritten as a sext. Constant BUILD_VECTORs are excluded because
+  // isSignExtended/isZeroExtended also accept them, yet operand 0 of a
+  // BUILD_VECTOR is a scalar element, not the narrow vector being extended.
+  if (((isN0SExt && isN1ZExt) || (isN0ZExt && isN1SExt)) &&
+      !isExtendedBUILD_VECTOR(N0, DAG, /*isSigned=*/false) &&
+      !isExtendedBUILD_VECTOR(N1, DAG, /*isSigned=*/false)) {
+    SDValue ZextOperand;
+    if (isN0ZExt)
+      ZextOperand = N0->getOperand(0);
+    else
+      ZextOperand = N1->getOperand(0);
+    if (DAG.SignBitIsZero(ZextOperand)) {
+      SDNode *NewSext =
+          DAG.getSExtOrTrunc(ZextOperand, DL, N0->getValueType(0)).getNode();
+      if (isN0ZExt)
+        N0 = NewSext;
+      else
+        N1 = NewSext;
+      return AArch64ISD::SMULL;
+    }
+  }
+
+  if (!isN1SExt && !isN1ZExt)
+    return 0;
+
+  // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
+  // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
+  if (isN1SExt && isAddSubSExt(N0, DAG)) {
+    isMLA = true;
+    return AArch64ISD::SMULL;
+  }
+  if (isN1ZExt && isAddSubZExt(N0, DAG)) {
+    isMLA = true;
+    return AArch64ISD::UMULL;
+  }
+  // N1 is the add/sub here; swap so that N0 is always the add/sub operand.
+  if (isN0ZExt && isAddSubZExt(N1, DAG)) {
+    std::swap(N0, N1);
+    isMLA = true;
+    return AArch64ISD::UMULL;
+  }
+  return 0;
+}
+
 SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
 
@@ -4383,45 +4448,20 @@
          "unexpected type for custom-lowering ISD::MUL");
   SDNode *N0 = Op.getOperand(0).getNode();
   SDNode *N1 = Op.getOperand(1).getNode();
-  unsigned NewOpc = 0;
+  SDLoc DL(Op);
   bool isMLA = false;
-  bool isN0SExt = isSignExtended(N0, DAG);
-  bool isN1SExt = isSignExtended(N1, DAG);
-  if (isN0SExt && isN1SExt)
-    NewOpc = AArch64ISD::SMULL;
-  else {
-    bool isN0ZExt = isZeroExtended(N0, DAG);
-    bool isN1ZExt = isZeroExtended(N1, DAG);
-    if (isN0ZExt && isN1ZExt)
-      NewOpc = AArch64ISD::UMULL;
-    else if (isN1SExt || isN1ZExt) {
-      // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
-      // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
-      if (isN1SExt && isAddSubSExt(N0, DAG)) {
-        NewOpc = AArch64ISD::SMULL;
-        isMLA = true;
-      } else if (isN1ZExt && isAddSubZExt(N0, DAG)) {
-        NewOpc = AArch64ISD::UMULL;
-        isMLA = true;
-      } else if (isN0ZExt && isAddSubZExt(N1, DAG)) {
-        std::swap(N0, N1);
-        NewOpc = AArch64ISD::UMULL;
-        isMLA = true;
-      }
-    }
+  unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA);
 
-    if (!NewOpc) {
-      if (VT == MVT::v2i64)
-        // Fall through to expand this. It is not legal.
-        return SDValue();
-      else
-        // Other vector multiplications are legal.
-        return Op;
-    }
+  if (!NewOpc) {
+    if (VT == MVT::v2i64)
+      // Fall through to expand this. It is not legal.
+      return SDValue();
+    else
+      // Other vector multiplications are legal.
+ return Op; } // Legalize to a S/UMULL instruction - SDLoc DL(Op); SDValue Op0; SDValue Op1 = skipExtensionForVectorMULL(N1, DAG); if (!isMLA) { diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll --- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll +++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll @@ -50,14 +50,10 @@ ; CHECK-LABEL: smull_zext_v8i8_v8i32: ; CHECK: // %bb.0: ; CHECK-NEXT: ldr d0, [x0] -; CHECK-NEXT: ldr q1, [x1] +; CHECK-NEXT: ldr q2, [x1] ; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v2.4s, v1.4h, #0 -; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0 -; CHECK-NEXT: ushll2 v3.4s, v0.8h, #0 -; CHECK-NEXT: ushll v0.4s, v0.4h, #0 -; CHECK-NEXT: mul v1.4s, v3.4s, v1.4s -; CHECK-NEXT: mul v0.4s, v0.4s, v2.4s +; CHECK-NEXT: smull2 v1.4s, v0.8h, v2.8h +; CHECK-NEXT: smull v0.4s, v0.4h, v2.4h ; CHECK-NEXT: ret %load.A = load <8 x i8>, <8 x i8>* %A %load.B = load <8 x i16>, <8 x i16>* %B @@ -70,15 +66,11 @@ define <8 x i32> @smull_zext_v8i8_v8i32_sext_first_operand(<8 x i16>* %A, <8 x i8>* %B) nounwind { ; CHECK-LABEL: smull_zext_v8i8_v8i32_sext_first_operand: ; CHECK: // %bb.0: -; CHECK-NEXT: ldr d1, [x1] -; CHECK-NEXT: ldr q0, [x0] -; CHECK-NEXT: ushll v1.8h, v1.8b, #0 -; CHECK-NEXT: sshll v2.4s, v0.4h, #0 -; CHECK-NEXT: sshll2 v0.4s, v0.8h, #0 -; CHECK-NEXT: ushll2 v3.4s, v1.8h, #0 -; CHECK-NEXT: ushll v4.4s, v1.4h, #0 -; CHECK-NEXT: mul v1.4s, v0.4s, v3.4s -; CHECK-NEXT: mul v0.4s, v2.4s, v4.4s +; CHECK-NEXT: ldr d0, [x1] +; CHECK-NEXT: ldr q2, [x0] +; CHECK-NEXT: ushll v0.8h, v0.8b, #0 +; CHECK-NEXT: smull2 v1.4s, v2.8h, v0.8h +; CHECK-NEXT: smull v0.4s, v2.4h, v0.4h ; CHECK-NEXT: ret %load.A = load <8 x i16>, <8 x i16>* %A %load.B = load <8 x i8>, <8 x i8>* %B @@ -116,9 +108,7 @@ ; CHECK-NEXT: ldr s0, [x0] ; CHECK-NEXT: ldr d1, [x1] ; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v1.4s, v1.4h, #0 -; CHECK-NEXT: ushll v0.4s, v0.4h, #0 -; CHECK-NEXT: mul v0.4s, v0.4s, v1.4s +; CHECK-NEXT: smull 
v0.4s, v0.4h, v1.4h ; CHECK-NEXT: ret %load.A = load <4 x i8>, <4 x i8>* %A %load.B = load <4 x i16>, <4 x i16>* %B @@ -156,16 +146,7 @@ ; CHECK-NEXT: ldr d0, [x0] ; CHECK-NEXT: ldr d1, [x1] ; CHECK-NEXT: bic v0.2s, #128, lsl #24 -; CHECK-NEXT: sshll v1.2d, v1.2s, #0 -; CHECK-NEXT: ushll v0.2d, v0.2s, #0 -; CHECK-NEXT: fmov x9, d1 -; CHECK-NEXT: fmov x10, d0 -; CHECK-NEXT: mov x8, v1.d[1] -; CHECK-NEXT: mov x11, v0.d[1] -; CHECK-NEXT: mul x9, x10, x9 -; CHECK-NEXT: mul x8, x11, x8 -; CHECK-NEXT: fmov d0, x9 -; CHECK-NEXT: mov v0.d[1], x8 +; CHECK-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-NEXT: ret %load.A = load <2 x i32>, <2 x i32>* %A %and.A = and <2 x i32> %load.A,