Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18343,6 +18343,230 @@
                          DAG.getConstant(0, DL, MVT::i64));
 }
 
+static bool CheckMULWithTRUNCAndEXTRACTVEC(SDValue &MUL, SDValue &EXTRACT,
+                                           SDValue &TRUNC, bool HighPart) {
+  if (MUL.getOpcode() != AArch64ISD::SMULL &&
+      MUL.getOpcode() != AArch64ISD::UMULL &&
+      MUL.getOpcode() != AArch64ISD::PMULL)
+    return false;
+
+  auto CheckMULOps = [](SDValue EXTRACTOp, SDValue TRUNCOp, bool HighPart) {
+    if (EXTRACTOp.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return false;
+
+    if (HighPart) {
+      unsigned HighIdx =
+          EXTRACTOp.getOperand(0).getValueType().getVectorNumElements() / 2;
+      if (ConstantSDNode *HighCst =
+              dyn_cast<ConstantSDNode>(EXTRACTOp.getOperand(1)))
+        if (HighCst->getZExtValue() != HighIdx)
+          return false;
+    }
+
+    if (TRUNCOp.getOpcode() != ISD::TRUNCATE)
+      return false;
+
+    return true;
+  };
+
+  // Check MUL EXTRACTVEC, TRUNC.
+  if (CheckMULOps(MUL.getOperand(0), MUL.getOperand(1), HighPart)) {
+    EXTRACT = MUL.getOperand(0);
+    TRUNC = MUL.getOperand(1);
+    return true;
+  }
+
+  // Check MUL TRUNC, EXTRACTVEC.
+  if (CheckMULOps(MUL.getOperand(1), MUL.getOperand(0), HighPart)) {
+    EXTRACT = MUL.getOperand(1);
+    TRUNC = MUL.getOperand(0);
+    return true;
+  }
+
+  return false;
+}
+
+// Try to combine nodes into uzp1 and smlsl/smlsl2 (or their umlsl/pmlsl
+// equivalents).
+static SDValue
+performSubMULExtractSubVecCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (!VT.isFixedLengthVector())
+    return SDValue();
+
+  // Check for the code pattern below.
+  //
+  // t2: v8i16,ch = CopyFromReg t0, Register:v8i16 %0
+  // t8: i64,ch = CopyFromReg t0, Register:i64 %3
+  // t11: v4i32,ch = load<(load (s128) from %ir.3, align 4)> t0, t8, undef:i64
+  // t14: i64 = add nuw t8, Constant:i64<16>
+  // t15: v4i32,ch = load<(load (s128) from %ir.7, align 4)> t0, t14, undef:i64
+  // t27: ch = TokenFactor t11:1, t15:1
+  // t4: v4i32,ch = CopyFromReg t0, Register:v4i32 %1
+  // t18: v4i16 = extract_subvector t2, Constant:i64<0>
+  // t12: v4i16 = truncate t11
+  // t31: v4i32 = AArch64ISD::SMULL t18, t12
+  // t32: v4i32 = sub t4, t31
+  // t23: v4i16 = extract_subvector t2, Constant:i64<4>
+  // t16: v4i16 = truncate t15
+  // t30: v4i32 = AArch64ISD::SMULL t23, t16
+  // t33: v4i32 = sub t32, t30
+  if (N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  // The DAG can have the two code patterns below.
+  //
+  // first:
+  // t31: v4i32 = AArch64ISD::SMULL t18, t12
+  // t32: v4i32 = sub t4, t31
+  // t30: v4i32 = AArch64ISD::SMULL t23, t16
+  // t33: v4i32 = sub t32, t30
+  //
+  // second:
+  // t33: v8i16 = AArch64ISD::PMULL t20, t14
+  // t32: v8i16 = AArch64ISD::PMULL t25, t18
+  // t27: v8i16 = add t33, t32
+  // t28: v8i16 = sub t4, t27
+  //
+  // Let's check both cases.
+  if (N->getOperand(0).getOpcode() != ISD::SUB &&
+      N->getOperand(1).getOpcode() != ISD::ADD)
+    return SDValue();
+
+  bool HasSUBSUB = false;
+  SDValue MULLow;
+  SDValue EXTRACTLow;
+  SDValue TRUNCLow;
+  SDValue MULHigh;
+  SDValue EXTRACTHigh;
+  SDValue TRUNCHigh;
+  if (N->getOperand(0).getOpcode() == ISD::SUB) {
+    // Let's check the case with sub and sub.
+    //
+    // t31: v4i32 = AArch64ISD::SMULL t18, t12
+    // t32: v4i32 = sub t4, t31
+    // t30: v4i32 = AArch64ISD::SMULL t23, t16
+    // t33: v4i32 = sub t32, t30
+
+    // Check low part MUL.
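+    // In the sub/sub form, the inner sub's RHS is the low-half MULL and the
+    // outer sub's RHS is the high-half MULL.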
+    MULLow = N->getOperand(0).getOperand(1);
+    if (!CheckMULWithTRUNCAndEXTRACTVEC(MULLow, EXTRACTLow, TRUNCLow, false))
+      return SDValue();
+
+    // Check high part MUL.
+    MULHigh = N->getOperand(1);
+    if (!CheckMULWithTRUNCAndEXTRACTVEC(MULHigh, EXTRACTHigh, TRUNCHigh, true))
+      return SDValue();
+
+    HasSUBSUB = true;
+  } else {
+    // Let's check the case with sub and add.
+    //
+    // t33: v8i16 = AArch64ISD::PMULL t20, t14
+    // t32: v8i16 = AArch64ISD::PMULL t25, t18
+    // t27: v8i16 = add t33, t32
+    // t28: v8i16 = sub t4, t27
+    SDValue ADD = N->getOperand(1);
+
+    // Check low part MUL.
+    bool IsLHSMULLow = false;
+    bool IsRHSMULLow = false;
+    MULLow = ADD.getOperand(0);
+    if (CheckMULWithTRUNCAndEXTRACTVEC(MULLow, EXTRACTLow, TRUNCLow, false))
+      IsLHSMULLow = true;
+    else {
+      MULLow = ADD.getOperand(1);
+      if (CheckMULWithTRUNCAndEXTRACTVEC(MULLow, EXTRACTLow, TRUNCLow, false)) {
+        IsRHSMULLow = true;
+      } else
+        return SDValue();
+    }
+
+    // Check high part MUL.
+    if (IsLHSMULLow)
+      MULHigh = ADD.getOperand(1);
+    else if (IsRHSMULLow)
+      MULHigh = ADD.getOperand(0);
+
+    if (!CheckMULWithTRUNCAndEXTRACTVEC(MULHigh, EXTRACTHigh, TRUNCHigh, true))
+      return SDValue();
+  }
+
+  // We have found the code pattern. Let's build the code sequence below.
+  //
+  // t2: v8i16,ch = CopyFromReg t0, Register:v8i16 %0
+  // t8: i64,ch = CopyFromReg t0, Register:i64 %3
+  // t36: v8i16 = AArch64ISD::UZP1 t44, t43
+  // t14: i64 = add nuw t8, Constant:i64<16>
+  // t43: v8i16,ch = load<(load (s128) from %ir.7, align 4)> t0, t14, undef:i64
+  // t44: v8i16,ch = load<(load (s128) from %ir.3, align 4)> t0, t8, undef:i64
+  // t27: ch = TokenFactor t44:1, t43:1
+  // t4: v4i32,ch = CopyFromReg t0, Register:v4i32 %1
+  // t18: v4i16 = extract_subvector t2, Constant:i64<0>
+  // t37: v4i16 = extract_subvector t36, Constant:i64<0>
+  // t39: v4i32 = AArch64ISD::SMULL t18, t37
+  // t41: v4i32 = sub t4, t39
+  // t23: v4i16 = extract_subvector t2, Constant:i64<4>
+  // t38: v4i16 = extract_subvector t36, Constant:i64<4>
+  // t40: v4i32 = AArch64ISD::SMULL t23, t38
+  // t42: v4i32 = sub t41, t40
+
+  // Create UZP1.
+  SDLoc DL(N);
+  EVT TRUNCVT = TRUNCLow.getValueType();
+  EVT UZP1VT = TRUNCVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+  SDValue TRUNCHighOP0 = TRUNCHigh.getOperand(0);
+  SDValue TRUNCLowOP0 = TRUNCLow.getOperand(0);
+  EVT TRUNCLowOP0VT = TRUNCLowOP0.getValueType();
+  if (TRUNCLowOP0VT != UZP1VT) {
+    TRUNCHighOP0 = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TRUNCHighOP0);
+    TRUNCLowOP0 = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TRUNCLowOP0);
+  }
+  SDValue UZP1 =
+      DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TRUNCLowOP0, TRUNCHighOP0);
+
+  // Create EXTRACT_SUBVECTOR low.
+  SDValue NewTRUNCLow = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TRUNCVT, UZP1,
+                                    EXTRACTLow.getOperand(1));
+
+  // Create EXTRACT_SUBVECTOR high.
+  SDValue NewTRUNCHigh = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TRUNCVT, UZP1,
+                                     EXTRACTHigh.getOperand(1));
+
+  // Create new MUL low.
+  SDValue NewMULLow = DAG.getNode(
+      MULLow->getOpcode(), DL, MULLow.getValueType(), EXTRACTLow, NewTRUNCLow);
+
+  // Create new MUL high.
+  SDValue NewMULHigh =
+      DAG.getNode(MULHigh->getOpcode(), DL, MULHigh.getValueType(),
+                  EXTRACTHigh, NewTRUNCHigh);
+
+  SDValue Res;
+  if (HasSUBSUB) {
+    // Create new SUB.
+    SDValue NewSUB = DAG.getNode(N->getOperand(0)->getOpcode(), DL,
+                                 N->getOperand(0).getValueType(),
+                                 N->getOperand(0).getOperand(0), NewMULLow);
+    Res = DAG.getNode(N->getOpcode(), DL, N->getValueType(0), NewSUB,
+                      NewMULHigh);
+  } else {
+    // Create new ADD.
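+    // In the sub/add form, rebuild the add of the two new MULLs and subtract
+    // it from the original LHS.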
+    SDValue NewADD =
+        DAG.getNode(N->getOperand(1)->getOpcode(), DL,
+                    N->getOperand(1).getValueType(), NewMULLow, NewMULHigh);
+    Res = DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
+                      NewADD);
+  }
+
+  return Res;
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   // Try to change sum of two reductions.
@@ -18364,6 +18588,8 @@
     return Val;
   if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
     return Val;
+  if (SDValue Val = performSubMULExtractSubVecCombine(N, DCI))
+    return Val;
   return performAddSubLongCombine(N, DCI);
 }
 
Index: llvm/test/CodeGen/AArch64/aarch64-smull.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -1115,3 +1115,136 @@
   %out = mul nsw <4 x i64> %in1, %broadcast.splat
   ret <4 x i64> %out
 }
+
+define void @pmlsl_pmlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: pmlsl_pmlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q3, q2, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v3.16b, v2.16b
+; CHECK-NEXT:    pmull2 v3.8h, v0.16b, v2.16b
+; CHECK-NEXT:    pmull v0.8h, v0.8b, v2.8b
+; CHECK-NEXT:    add v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    sub v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q3, q2, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v3.16b, v2.16b
+; CHECK-NEXT:    smlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    smlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v8i16_uzp1(<16 x i8> %0, <8 x i16> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v8i16_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q3, q2, [x1]
+; CHECK-NEXT:    uzp1 v2.16b, v3.16b, v2.16b
+; CHECK-NEXT:    umlsl v1.8h, v0.8b, v2.8b
+; CHECK-NEXT:    umlsl2 v1.8h, v0.16b, v2.16b
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <8 x i16>, ptr %3, align 4
+  %6 = trunc <8 x i16> %5 to <8 x i8>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <8 x i16>, ptr %7, align 4
+  %9 = trunc <8 x i16> %8 to <8 x i8>
+  %10 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %11 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %10, <8 x i8> %6)
+  %12 = shufflevector <16 x i8> %0, <16 x i8> poison, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %13 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %12, <8 x i8> %9)
+  %14 = add <8 x i16> %11, %13
+  %15 = sub <8 x i16> %1, %14
+  store <8 x i16> %15, ptr %2, align 16
+  ret void
+}
+
+define void @smlsl_smlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q3, q2, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v3.8h, v2.8h
+; CHECK-NEXT:    smlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+define void @umlsl_umlsl2_v4i32_uzp1(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: umlsl_umlsl2_v4i32_uzp1:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q3, q2, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v3.8h, v2.8h
+; CHECK-NEXT:    umlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+declare <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)
+declare <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8>, <8 x i8>)
+declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>)
+declare <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16>, <4 x i16>)