diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4392,7 +4392,28 @@
   else {
     bool isN0ZExt = isZeroExtended(N0, DAG);
     bool isN1ZExt = isZeroExtended(N1, DAG);
-    if (isN0ZExt && isN1ZExt)
+    // Replace the zext with a sext if the top bit of the zext operand is
+    // known zero, so that SMULL can be selected.
+    if ((isN0SExt && isN1ZExt) || (isN0ZExt && isN1SExt)) {
+      SDLoc DL(Op);
+      SDValue ZextOperand;
+      if (isN0ZExt)
+        ZextOperand = N0->getOperand(0);
+      else
+        ZextOperand = N1->getOperand(0);
+      unsigned PreZextSizeInBits = ZextOperand.getScalarValueSizeInBits();
+      KnownBits Bits = DAG.computeKnownBits(ZextOperand, 4);
+      KnownBits LastBitValue = Bits.extractBits(1, PreZextSizeInBits - 1);
+      if (LastBitValue.isZero()) {
+        SDNode *NewSext =
+            DAG.getSExtOrTrunc(ZextOperand, DL, N0->getValueType(0)).getNode();
+        if (isN0ZExt)
+          N0 = NewSext;
+        else
+          N1 = NewSext;
+        NewOpc = AArch64ISD::SMULL;
+      }
+    } else if (isN0ZExt && isN1ZExt)
       NewOpc = AArch64ISD::UMULL;
     else if (isN1SExt || isN1ZExt) {
       // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
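The rewrite in the hunk above is sound because a value whose top bit is known to be zero produces identical wide bits under zero- and sign-extension, so a zext operand can be relabelled as sign-extended and fed to SMULL. A minimal standalone sketch of that equivalence, in plain C++ rather than the DAG API (the variable names are illustrative, not from the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Top bit (bit 7) clear: zext and sext to 16 bits agree.
  uint8_t topClear = 0x7F;
  assert((uint16_t)topClear == (uint16_t)(int16_t)(int8_t)topClear);

  // Top bit set: the two extensions diverge (0x0080 vs. 0xFF80), so the
  // rewrite must not fire; this is the case the KnownBits check rejects.
  uint8_t topSet = 0x80;
  assert((uint16_t)topSet != (uint16_t)(int16_t)(int8_t)topSet);
  return 0;
}

The computeKnownBits/extractBits pair in the hunk performs the same test at bit position PreZextSizeInBits - 1, i.e. on the top bit of each element of the pre-extension vector operand.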
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -50,14 +50,10 @@
 ; CHECK-LABEL: smull_zext_v8i8_v8i32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ldr d0, [x0]
-; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    ldr q2, [x1]
 ; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-NEXT:    sshll2 v1.4s, v1.8h, #0
-; CHECK-NEXT:    ushll2 v3.4s, v0.8h, #0
-; CHECK-NEXT:    ushll v0.4s, v0.4h, #0
-; CHECK-NEXT:    mul v1.4s, v3.4s, v1.4s
-; CHECK-NEXT:    mul v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    smull2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    smull v0.4s, v0.4h, v2.4h
 ; CHECK-NEXT:    ret
   %load.A = load <8 x i8>, <8 x i8>* %A
   %load.B = load <8 x i16>, <8 x i16>* %B
@@ -74,9 +70,7 @@
 ; CHECK-NEXT:    ldr s0, [x0]
 ; CHECK-NEXT:    ldr d1, [x1]
 ; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ushll v0.4s, v0.4h, #0
-; CHECK-NEXT:    mul v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    smull v0.4s, v0.4h, v1.4h
 ; CHECK-NEXT:    ret
   %load.A = load <4 x i8>, <4 x i8>* %A
   %load.B = load <4 x i16>, <4 x i16>* %B
@@ -114,16 +108,7 @@
 ; CHECK-NEXT:    ldr d0, [x0]
 ; CHECK-NEXT:    ldr d1, [x1]
 ; CHECK-NEXT:    bic v0.2s, #128, lsl #24
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ushll v0.2d, v0.2s, #0
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
   %load.A = load <2 x i32>, <2 x i32>* %A
   %and.A = and <2 x i32> %load.A, <i32 2147483647, i32 2147483647>
@@ -624,8 +609,9 @@
 define <8 x i16> @umull_extvec_v8i8_v8i16(<8 x i8> %arg) nounwind {
 ; CHECK-LABEL: umull_extvec_v8i8_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.8b, #12
-; CHECK-NEXT:    umull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT:    movi v1.8h, #12
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    ret
   %tmp3 = zext <8 x i8> %arg to <8 x i16>
   %tmp4 = mul <8 x i16> %tmp3, <i16 12, i16 12, i16 12, i16 12, i16 12, i16 12, i16 12, i16 12>
@@ -650,8 +636,9 @@
 ; CHECK-LABEL: umull_extvec_v4i16_v4i32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w8, #1234
-; CHECK-NEXT:    dup v1.4h, w8
-; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT:    ushll v0.4s, v0.4h, #0
+; CHECK-NEXT:    dup v1.4s, w8
+; CHECK-NEXT:    mul v0.4s, v0.4s, v1.4s
 ; CHECK-NEXT:    ret
   %tmp3 = zext <4 x i16> %arg to <4 x i32>
   %tmp4 = mul <4 x i32> %tmp3, <i32 1234, i32 1234, i32 1234, i32 1234>
@@ -661,9 +648,14 @@
 define <2 x i64> @umull_extvec_v2i32_v2i64(<2 x i32> %arg) nounwind {
 ; CHECK-LABEL: umull_extvec_v2i32_v2i64:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v0.2d, v0.2s, #0
 ; CHECK-NEXT:    mov w8, #1234
-; CHECK-NEXT:    dup v1.2s, w8
-; CHECK-NEXT:    umull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT:    fmov x9, d0
+; CHECK-NEXT:    mov x10, v0.d[1]
+; CHECK-NEXT:    mul x9, x9, x8
+; CHECK-NEXT:    mul x8, x10, x8
+; CHECK-NEXT:    fmov d0, x9
+; CHECK-NEXT:    mov v0.d[1], x8
 ; CHECK-NEXT:    ret
   %tmp3 = zext <2 x i32> %arg to <2 x i64>
   %tmp4 = mul <2 x i64> %tmp3, <i64 1234, i64 1234>
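For an end-to-end view, the masked test above (the bic v0.2s, #128, lsl #24 case) corresponds roughly to source of the following shape. This is a hedged sketch: the function name is invented, and whether a given compiler auto-vectorizes the loop into the 2-lane smull form is an assumption about the surrounding pipeline, not something the patch itself guarantees.

#include <cstdint>

// Multiply a masked unsigned operand by a signed one at double width.
void mul_masked_by_signed(const uint32_t *a, const int32_t *b, int64_t *out) {
  for (int i = 0; i < 2; ++i) {
    uint32_t masked = a[i] & 0x7FFFFFFFu; // clears bit 31: top bit known zero
    // zext(masked) * sext(b[i]): with the patch, the zext side can be
    // relabelled as sign-extended, making the multiply eligible for smull.
    out[i] = (int64_t)(uint64_t)masked * (int64_t)b[i];
  }
}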