Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22377,6 +22377,171 @@
   return DAG.getNode(ISD::BITCAST, DL, VT, NewDuplane128);
 }
 
+// Try to combine a long multiply N (AArch64ISD::SMULL/UMULL/PMULL) whose
+// operands are an extract_high and a truncate with the sibling long multiply
+// that consumes the matching extract_low of the same source vector.
+//
+// Both truncates are replaced by halves of a single UZP1 of the truncates'
+// operands, so the pair can share one uzp1 (e.g. umull + umull2) instead of
+// needing one xtn per truncate. Only runs after operations are legalized.
+// Returns SDValue(N, 0) once the truncates' uses have been rewritten, or an
+// empty SDValue if the pattern does not match.
+static SDValue tryCombineOpWithUZP1(SDNode *N,
+                                    TargetLowering::DAGCombinerInfo &DCI,
+                                    SelectionDAG &DAG) {
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  SDValue EXTRACTHIGH;
+  SDValue EXTRACTLOW;
+  SDValue TRUNCHIGH;
+  SDValue TRUNCLOW;
+  SDLoc DL(N);
+
+  // Check the operands are trunc and extract_high. The extract_high may be
+  // wrapped in a BITCAST, in which case look through it.
+  if (isEssentiallyExtractHighSubvector(LHS) &&
+      RHS.getOpcode() == ISD::TRUNCATE) {
+    TRUNCHIGH = RHS;
+    if (LHS.getOpcode() == ISD::BITCAST)
+      EXTRACTHIGH = LHS.getOperand(0);
+    else
+      EXTRACTHIGH = LHS;
+  } else if (isEssentiallyExtractHighSubvector(RHS) &&
+             LHS.getOpcode() == ISD::TRUNCATE) {
+    TRUNCHIGH = LHS;
+    // The extract_high is RHS here (LHS is the TRUNCATE), so test RHS.
+    if (RHS.getOpcode() == ISD::BITCAST)
+      EXTRACTHIGH = RHS.getOperand(0);
+    else
+      EXTRACTHIGH = RHS;
+  } else
+    return SDValue();
+
+  // If the truncate's operand is a DUP or a splat, do not combine the op with
+  // uzp1: doing so causes regressions, see
+  // test/CodeGen/AArch64/aarch64-smull.ll.
+  SDValue TRUNCHIGHOP = TRUNCHIGH.getOperand(0);
+  EVT TRUNCHIGHOPVT = TRUNCHIGHOP.getValueType();
+  if (TRUNCHIGHOP.getOpcode() == AArch64ISD::DUP ||
+      DAG.isSplatValue(TRUNCHIGHOP, false))
+    return SDValue();
+
+  // Check there is an extract_low of the same source vector whose user is a
+  // matching long multiply. For example,
+  //
+  //    t18: v4i16 = extract_subvector t2, Constant:i64<0>
+  //    t12: v4i16 = truncate t11
+  //  t31: v4i32 = AArch64ISD::SMULL t18, t12
+  //    t23: v4i16 = extract_subvector t2, Constant:i64<4>
+  //    t16: v4i16 = truncate t15
+  //  t30: v4i32 = AArch64ISD::SMULL t23, t16
+  //
+  // This dagcombine only pairs N with a multiply that uses an extract_low of
+  // the same source vector as N's extract_high; pairs whose extracts come
+  // from different source vectors are not detected.
+  SDValue EXTRACTHIGHSrcVec = EXTRACTHIGH.getOperand(0);
+  if (EXTRACTHIGHSrcVec->use_size() != 2)
+    return SDValue();
+
+  // Find EXTRACTLOW: the only other user of the source vector, which must
+  // be an extract_subvector at index 0.
+  for (SDNode::use_iterator UI = EXTRACTHIGHSrcVec.getNode()->use_begin(),
+                            UE = EXTRACTHIGHSrcVec.getNode()->use_end();
+       UI != UE; ++UI) {
+    SDNode *User = *UI;
+    if (User == EXTRACTHIGH.getNode())
+      continue;
+
+    if (User->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
+
+    if (ConstantSDNode *IdxCst =
+            dyn_cast<ConstantSDNode>(User->getOperand(1))) {
+      if (!IdxCst->isZero())
+        return SDValue();
+    } else
+      return SDValue();
+
+    EXTRACTLOW.setNode(User);
+  }
+
+  // Check EXTRACTLOW was found and has exactly one user.
+  if (!EXTRACTLOW || !EXTRACTLOW->hasOneUse())
+    return SDValue();
+
+  SDNode::use_iterator UI = EXTRACTLOW.getNode()->use_begin();
+  SDNode *EXTRACTLOWUser = *UI;
+  if (EXTRACTLOWUser->getOpcode() != N->getOpcode())
+    return SDValue();
+
+  // The low multiply's other operand must be a truncate.
+  if (EXTRACTLOWUser->getOperand(0).getNode() == EXTRACTLOW.getNode()) {
+    if (EXTRACTLOWUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
+      TRUNCLOW = EXTRACTLOWUser->getOperand(1);
+    else
+      return SDValue();
+  } else {
+    if (EXTRACTLOWUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
+      TRUNCLOW = EXTRACTLOWUser->getOperand(0);
+    else
+      return SDValue();
+  }
+
+  // As for the high truncate, leave DUP/splat operands alone (see
+  // test/CodeGen/AArch64/aarch64-smull.ll).
+  SDValue TRUNCLOWOP = TRUNCLOW.getOperand(0);
+  EVT TRUNCLOWOPVT = TRUNCLOWOP.getValueType();
+  if (TRUNCLOWOP.getOpcode() == AArch64ISD::DUP ||
+      DAG.isSplatValue(TRUNCLOWOP, false))
+    return SDValue();
+
+  // Create uzp1, extract_high and extract_low.
+  EVT TRUNCHIGHVT = TRUNCHIGH.getValueType();
+  EVT TRUNCLOWVT = TRUNCLOW.getValueType();
+  EVT UZP1VT = TRUNCHIGHVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+
+  if (TRUNCHIGHOPVT != UZP1VT)
+    TRUNCHIGHOP = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TRUNCHIGHOP);
+  if (TRUNCLOWOPVT != UZP1VT)
+    TRUNCLOWOP = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TRUNCLOWOP);
+
+  SDValue UZP1 =
+      DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TRUNCLOWOP, TRUNCHIGHOP);
+  // The high half of UZP1 starts right after TRUNCHIGHVT's elements. Do not
+  // reuse EXTRACTHIGH's index: when EXTRACTHIGH was reached through a
+  // BITCAST, that index counts elements of a different vector type.
+  SDValue HighIdx =
+      DAG.getVectorIdxConstant(TRUNCHIGHVT.getVectorNumElements(), DL);
+  SDValue NewTRUNCHIGH = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TRUNCHIGHVT,
+                                     UZP1, HighIdx);
+  SDValue NewTRUNCLOW = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TRUNCLOWVT,
+                                    UZP1, EXTRACTLOW.getOperand(1));
+
+  DAG.ReplaceAllUsesWith(TRUNCHIGH, NewTRUNCHIGH);
+  DAG.ReplaceAllUsesWith(TRUNCLOW, NewTRUNCLOW);
+
+  return SDValue(N, 0);
+}
+
+// Combines for AArch64ISD::SMULL/UMULL/PMULL nodes: try
+// tryCombineLongOpWithDup first, then tryCombineOpWithUZP1.
+static SDValue performMULLCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG) {
+  if (SDValue Val =
+          tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG))
+    return Val;
+
+  if (SDValue Val = tryCombineOpWithUZP1(N, DCI, DAG))
+    return Val;
+
+  return SDValue();
+}
+
 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
                                                  DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -22521,7 +22686,7 @@
   case AArch64ISD::SMULL:
   case AArch64ISD::UMULL:
   case AArch64ISD::PMULL:
-    return tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG);
+    return performMULLCombine(N, DCI, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
Index: llvm/test/CodeGen/AArch64/aarch64-smull.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -1033,13 +1033,11 @@
 ; CHECK-LABEL: umull_and_v8i32:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v3.2d, #0x0000ff000000ff
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    and v2.16b, v2.16b, v3.16b
 ; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
-; CHECK-NEXT:    xtn v1.4h, v1.4s
-; CHECK-NEXT:    xtn v2.4h, v2.4s
-; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    umull v1.4s, v4.4h, v2.4h
+; CHECK-NEXT:    uzp1 v2.8h, v1.8h, v2.8h
+; CHECK-NEXT:    umull2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    umull v0.4s, v0.4h, v2.4h
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <8 x i16> %src1 to <8 x i32>
@@ -1084,13 +1082,11 @@
 ; CHECK-LABEL: umull_and_v4i64:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v3.2d, #0x000000000000ff
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    and v2.16b, v2.16b, v3.16b
 ; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
-; CHECK-NEXT:    xtn v1.2s, v1.2d
-; CHECK-NEXT:    xtn v2.2s, v2.2d
-; CHECK-NEXT:    umull v0.2d, v0.2s, v1.2s
-; CHECK-NEXT:    umull v1.2d, v4.2s, v2.2s
+; CHECK-NEXT:    uzp1 v2.4s, v1.4s, v2.4s
+; CHECK-NEXT:    umull2 v1.2d, v0.4s, v2.4s
+; CHECK-NEXT:    umull v0.2d, v0.2s, v2.2s
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <4 x i32> %src1 to <4 x i64>
@@ -1115,3 +1111,137 @@
   %out = mul nsw <4 x i64> %in1, %broadcast.splat
   ret <4 x i64> %out
 }
+
+; The truncated operands of each low/high long-multiply pair share one uzp1.
+define void @pmlsl_pmlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: pmlsl_pmlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    pmull v3.8h, v0.8b, v2.8b
+; CHECK-NEXT:    pmull2 v0.8h, v0.16b, v2.16b
+; CHECK-NEXT:    add v0.8h, v3.8h, v0.8h
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    smlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    smlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v2.16b, v3.16b
+; CHECK-NEXT:    umlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    umlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v2.8h, v3.8h
+; CHECK-NEXT:    smlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v2.8h, v3.8h
+; CHECK-NEXT:    umlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+declare <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8>, <8 x i8>)
+declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>)
+declare <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16>, <4 x i16>)