diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47524,71 +47524,100 @@
   return Res;
 }
-// Try to combine the following nodes
-// t21: v16f32 = X86ISD::VFMULC/VFCMULC t7, t8
-// t15: v32f16 = bitcast t21
-// t16: v32f16 = fadd nnan ninf nsz arcp contract afn reassoc t15, t2
-// into X86ISD::VFMADDC/VFCMADDC if possible:
-// t22: v16f32 = bitcast t2
-// t23: v16f32 = nnan ninf nsz arcp contract afn reassoc
-// X86ISD::VFMADDC/VFCMADDC t7, t8, t22
-// t24: v32f16 = bitcast t23
+// This function transforms complex FADD(A, FMA(B, C, 0)) and
+// FADD(A, FMUL(B, C)) to FMA(B, C, A).
 static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
-  auto AllowContract = [&DAG](SDNode *N) {
-    return DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
-           N->getFlags().hasAllowContract();
-  };
-  if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() || !AllowContract(N))
-    return SDValue();
+  // Check the -ffp-contract option or the fast-math flags of the SDNode.
+  auto AllowContract = [&DAG](const SDNodeFlags &Flags) {
+    return DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
+           (Flags.hasNoNaNs() && Flags.hasNoInfs() && Flags.hasNoSignedZeros() &&
+            Flags.hasAllowReciprocal() && Flags.hasAllowContract() &&
+            Flags.hasApproximateFuncs() && Flags.hasAllowReassociation());
+  };
-  EVT VT = N->getValueType(0);
-  if (VT != MVT::v8f16 && VT != MVT::v16f16 && VT != MVT::v32f16)
-    return SDValue();
+  auto IsVectorAllNegativeZero = [](const SDNode *N) {
+    if (N->getOpcode() != X86ISD::VBROADCAST_LOAD)
+      return false;
+    assert(N->getSimpleValueType(0).getScalarType() == MVT::f32 &&
+           "Unexpected vector type!");
+    if (ConstantPoolSDNode *CP =
+            dyn_cast<ConstantPoolSDNode>(N->getOperand(1)->getOperand(0))) {
+      APInt AI = APInt(32, 0x80008000, true);
+      if (const auto *CI = dyn_cast<ConstantInt>(CP->getConstVal()))
+        return CI->getValue() == AI;
+      if (const auto *CF = dyn_cast<ConstantFP>(CP->getConstVal()))
+        return CF->getValue() == APFloat(APFloat::IEEEsingle(), AI);
+    }
+    return false;
+  };
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
-  SDValue CFmul, FAddOp1;
-  auto GetCFmulFrom = [&CFmul, &AllowContract](SDValue N) -> bool {
-    if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
-      return false;
-    SDValue Op0 = N.getOperand(0);
-    unsigned Opcode = Op0.getOpcode();
-    if (Op0.hasOneUse() && AllowContract(Op0.getNode()) &&
-        (Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC))
-      CFmul = Op0;
-    return !!CFmul;
-  };
+  if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() ||
+      !AllowContract(N->getFlags()))
+    return SDValue();
-  if (GetCFmulFrom(LHS))
-    FAddOp1 = RHS;
-  else if (GetCFmulFrom(RHS))
-    FAddOp1 = LHS;
-  else
-    return SDValue();
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v8f16 && VT != MVT::v16f16 && VT != MVT::v32f16)
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  bool IsConj;
+  SDValue FAddOp1, MulOp0, MulOp1;
+  auto GetCFmulFrom = [&MulOp0, &MulOp1, &IsConj, &AllowContract,
+                       &IsVectorAllNegativeZero](SDValue N) -> bool {
+    if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
+      return false;
+    SDValue Op0 = N.getOperand(0);
+    unsigned Opcode = Op0.getOpcode();
+    if (Op0.hasOneUse()) {
+      if ((Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC) &&
+          AllowContract(Op0->getFlags())) {
+        MulOp0 = Op0.getOperand(0);
+        MulOp1 = Op0.getOperand(1);
+        IsConj = Opcode == X86ISD::VFCMULC;
+        return
true; + } + if ((Opcode == X86ISD::VFMADDC || Opcode == X86ISD::VFCMADDC) && + ((ISD::isBuildVectorAllZeros(Op0->getOperand(0).getNode()) && + AllowContract(Op0->getFlags())) || + IsVectorAllNegativeZero(Op0->getOperand(0).getNode()))) { + MulOp0 = Op0.getOperand(1); + MulOp1 = Op0.getOperand(2); + IsConj = Opcode == X86ISD::VFCMADDC; + return true; + } + } + return false; +}; + +if (GetCFmulFrom(LHS)) + FAddOp1 = RHS; +else if (GetCFmulFrom(RHS)) + FAddOp1 = LHS; +else + return SDValue(); - MVT CVT = MVT::getVectorVT(MVT::f32, VT.getVectorNumElements() / 2); - assert(CFmul->getValueType(0) == CVT && "Complex type mismatch"); - FAddOp1 = DAG.getBitcast(CVT, FAddOp1); - unsigned newOp = CFmul.getOpcode() == X86ISD::VFMULC ? X86ISD::VFMADDC - : X86ISD::VFCMADDC; - // FIXME: How do we handle when fast math flags of FADD are different from - // CFMUL's? - CFmul = DAG.getNode(newOp, SDLoc(N), CVT, FAddOp1, CFmul.getOperand(0), - CFmul.getOperand(1), N->getFlags()); - return DAG.getBitcast(VT, CFmul); +MVT CVT = MVT::getVectorVT(MVT::f32, VT.getVectorNumElements() / 2); +FAddOp1 = DAG.getBitcast(CVT, FAddOp1); +unsigned NewOp = IsConj ? X86ISD::VFCMADDC : X86ISD::VFMADDC; +// FIXME: How do we handle when fast math flags of FADD are different from +// CFMUL's? +SDValue CFmul = + DAG.getNode(NewOp, SDLoc(N), CVT, FAddOp1, MulOp0, MulOp1, N->getFlags()); +return DAG.getBitcast(VT, CFmul); } /// Do target-specific dag combines on floating-point adds/subs. static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - if (SDValue HOp = combineToHorizontalAddSub(N, DAG, Subtarget)) - return HOp; + const X86Subtarget &Subtarget) { +if (SDValue HOp = combineToHorizontalAddSub(N, DAG, Subtarget)) + return HOp; - if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget)) - return COp; +if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget)) + return COp; - return SDValue(); +return SDValue(); } /// Attempt to pre-truncate inputs to arithmetic ops if it will simplify @@ -47597,77 +47626,77 @@ /// TODO: This overlaps with the generic combiner's visitTRUNCATE. Remove /// anything that is guaranteed to be transformed by DAGCombiner. static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget, - const SDLoc &DL) { - assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode"); - SDValue Src = N->getOperand(0); - unsigned SrcOpcode = Src.getOpcode(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - - EVT VT = N->getValueType(0); - EVT SrcVT = Src.getValueType(); - - auto IsFreeTruncation = [VT](SDValue Op) { - unsigned TruncSizeInBits = VT.getScalarSizeInBits(); + const X86Subtarget &Subtarget, + const SDLoc &DL) { +assert(N->getOpcode() == ISD::TRUNCATE && "Wrong opcode"); +SDValue Src = N->getOperand(0); +unsigned SrcOpcode = Src.getOpcode(); +const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - // See if this has been extended from a smaller/equal size to - // the truncation size, allowing a truncation to combine with the extend. - unsigned Opcode = Op.getOpcode(); - if ((Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND || - Opcode == ISD::ZERO_EXTEND) && - Op.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits) - return true; +EVT VT = N->getValueType(0); +EVT SrcVT = Src.getValueType(); - // See if this is a single use constant which can be constant folded. - // NOTE: We don't peek throught bitcasts here because there is currently - // no support for constant folding truncate+bitcast+vector_of_constants. 
So - // we'll just send up with a truncate on both operands which will - // get turned back into (truncate (binop)) causing an infinite loop. - return ISD::isBuildVectorOfConstantSDNodes(Op.getNode()); - }; +auto IsFreeTruncation = [VT](SDValue Op) { + unsigned TruncSizeInBits = VT.getScalarSizeInBits(); - auto TruncateArithmetic = [&](SDValue N0, SDValue N1) { - SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, VT, N0); - SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, VT, N1); - return DAG.getNode(SrcOpcode, DL, VT, Trunc0, Trunc1); - }; + // See if this has been extended from a smaller/equal size to + // the truncation size, allowing a truncation to combine with the extend. + unsigned Opcode = Op.getOpcode(); + if ((Opcode == ISD::ANY_EXTEND || Opcode == ISD::SIGN_EXTEND || + Opcode == ISD::ZERO_EXTEND) && + Op.getOperand(0).getScalarValueSizeInBits() <= TruncSizeInBits) + return true; - // Don't combine if the operation has other uses. - if (!Src.hasOneUse()) - return SDValue(); + // See if this is a single use constant which can be constant folded. + // NOTE: We don't peek throught bitcasts here because there is currently + // no support for constant folding truncate+bitcast+vector_of_constants. So + // we'll just send up with a truncate on both operands which will + // get turned back into (truncate (binop)) causing an infinite loop. + return ISD::isBuildVectorOfConstantSDNodes(Op.getNode()); +}; - // Only support vector truncation for now. - // TODO: i64 scalar math would benefit as well. - if (!VT.isVector()) - return SDValue(); +auto TruncateArithmetic = [&](SDValue N0, SDValue N1) { + SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, VT, N0); + SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, VT, N1); + return DAG.getNode(SrcOpcode, DL, VT, Trunc0, Trunc1); +}; - // In most cases its only worth pre-truncating if we're only facing the cost - // of one truncation. - // i.e. if one of the inputs will constant fold or the input is repeated. - switch (SrcOpcode) { - case ISD::MUL: - // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its - // better to truncate if we have the chance. - if (SrcVT.getScalarType() == MVT::i64 && - TLI.isOperationLegal(SrcOpcode, VT) && - !TLI.isOperationLegal(SrcOpcode, SrcVT)) - return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1)); - LLVM_FALLTHROUGH; - case ISD::AND: - case ISD::XOR: - case ISD::OR: - case ISD::ADD: - case ISD::SUB: { - SDValue Op0 = Src.getOperand(0); - SDValue Op1 = Src.getOperand(1); - if (TLI.isOperationLegal(SrcOpcode, VT) && - (Op0 == Op1 || IsFreeTruncation(Op0) || IsFreeTruncation(Op1))) - return TruncateArithmetic(Op0, Op1); - break; - } - } +// Don't combine if the operation has other uses. +if (!Src.hasOneUse()) + return SDValue(); +// Only support vector truncation for now. +// TODO: i64 scalar math would benefit as well. +if (!VT.isVector()) return SDValue(); + +// In most cases its only worth pre-truncating if we're only facing the cost +// of one truncation. +// i.e. if one of the inputs will constant fold or the input is repeated. +switch (SrcOpcode) { +case ISD::MUL: + // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its + // better to truncate if we have the chance. 
+ if (SrcVT.getScalarType() == MVT::i64 && + TLI.isOperationLegal(SrcOpcode, VT) && + !TLI.isOperationLegal(SrcOpcode, SrcVT)) + return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1)); + LLVM_FALLTHROUGH; +case ISD::AND: +case ISD::XOR: +case ISD::OR: +case ISD::ADD: +case ISD::SUB: { + SDValue Op0 = Src.getOperand(0); + SDValue Op1 = Src.getOperand(1); + if (TLI.isOperationLegal(SrcOpcode, VT) && + (Op0 == Op1 || IsFreeTruncation(Op0) || IsFreeTruncation(Op1))) + return TruncateArithmetic(Op0, Op1); + break; +} +} + +return SDValue(); } /// Truncate using ISD::AND mask and X86ISD::PACKUS. @@ -47675,28 +47704,28 @@ /// MaskX = X & 0xffff (clear high bits to prevent saturation) /// packus (extract_subv MaskX, 0), (extract_subv MaskX, 1) static SDValue combineVectorTruncationWithPACKUS(SDNode *N, const SDLoc &DL, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDValue In = N->getOperand(0); - EVT InVT = In.getValueType(); - EVT OutVT = N->getValueType(0); + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { +SDValue In = N->getOperand(0); +EVT InVT = In.getValueType(); +EVT OutVT = N->getValueType(0); - APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(), - OutVT.getScalarSizeInBits()); - In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT)); - return truncateVectorWithPACK(X86ISD::PACKUS, OutVT, In, DL, DAG, Subtarget); +APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(), + OutVT.getScalarSizeInBits()); +In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT)); +return truncateVectorWithPACK(X86ISD::PACKUS, OutVT, In, DL, DAG, Subtarget); } /// Truncate a group of v4i32 into v8i16 using X86ISD::PACKSS. static SDValue combineVectorTruncationWithPACKSS(SDNode *N, const SDLoc &DL, - const X86Subtarget &Subtarget, - SelectionDAG &DAG) { - SDValue In = N->getOperand(0); - EVT InVT = In.getValueType(); - EVT OutVT = N->getValueType(0); - In = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, InVT, In, - DAG.getValueType(OutVT)); - return truncateVectorWithPACK(X86ISD::PACKSS, OutVT, In, DL, DAG, Subtarget); + const X86Subtarget &Subtarget, + SelectionDAG &DAG) { +SDValue In = N->getOperand(0); +EVT InVT = In.getValueType(); +EVT OutVT = N->getValueType(0); +In = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, InVT, In, + DAG.getValueType(OutVT)); +return truncateVectorWithPACK(X86ISD::PACKSS, OutVT, In, DL, DAG, Subtarget); } /// This function transforms truncation from vXi32/vXi64 to vXi8/vXi16 into @@ -47705,130 +47734,130 @@ /// element that is extracted from a vector and then truncated, and it is /// difficult to do this optimization based on them. static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - EVT OutVT = N->getValueType(0); - if (!OutVT.isVector()) - return SDValue(); + const X86Subtarget &Subtarget) { +EVT OutVT = N->getValueType(0); +if (!OutVT.isVector()) + return SDValue(); - SDValue In = N->getOperand(0); - if (!In.getValueType().isSimple()) - return SDValue(); +SDValue In = N->getOperand(0); +if (!In.getValueType().isSimple()) + return SDValue(); - EVT InVT = In.getValueType(); - unsigned NumElems = OutVT.getVectorNumElements(); +EVT InVT = In.getValueType(); +unsigned NumElems = OutVT.getVectorNumElements(); - // AVX512 provides fast truncate ops. - if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) - return SDValue(); +// AVX512 provides fast truncate ops. 
+if (!Subtarget.hasSSE2() || Subtarget.hasAVX512()) + return SDValue(); - EVT OutSVT = OutVT.getVectorElementType(); - EVT InSVT = InVT.getVectorElementType(); - if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) && - (OutSVT == MVT::i8 || OutSVT == MVT::i16) && isPowerOf2_32(NumElems) && - NumElems >= 8)) - return SDValue(); +EVT OutSVT = OutVT.getVectorElementType(); +EVT InSVT = InVT.getVectorElementType(); +if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) && + (OutSVT == MVT::i8 || OutSVT == MVT::i16) && isPowerOf2_32(NumElems) && + NumElems >= 8)) + return SDValue(); - // SSSE3's pshufb results in less instructions in the cases below. - if (Subtarget.hasSSSE3() && NumElems == 8 && InSVT != MVT::i64) - return SDValue(); +// SSSE3's pshufb results in less instructions in the cases below. +if (Subtarget.hasSSSE3() && NumElems == 8 && InSVT != MVT::i64) + return SDValue(); - SDLoc DL(N); - // SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS - // for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to - // truncate 2 x v4i32 to v8i16. - if (Subtarget.hasSSE41() || OutSVT == MVT::i8) - return combineVectorTruncationWithPACKUS(N, DL, Subtarget, DAG); - if (InSVT == MVT::i32) - return combineVectorTruncationWithPACKSS(N, DL, Subtarget, DAG); +SDLoc DL(N); +// SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS +// for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to +// truncate 2 x v4i32 to v8i16. +if (Subtarget.hasSSE41() || OutSVT == MVT::i8) + return combineVectorTruncationWithPACKUS(N, DL, Subtarget, DAG); +if (InSVT == MVT::i32) + return combineVectorTruncationWithPACKSS(N, DL, Subtarget, DAG); - return SDValue(); +return SDValue(); } /// This function transforms vector truncation of 'extended sign-bits' or /// 'extended zero-bits' values. /// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations. static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL, - SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - // Requires SSE2. - if (!Subtarget.hasSSE2()) - return SDValue(); + SelectionDAG &DAG, + const X86Subtarget &Subtarget) { +// Requires SSE2. +if (!Subtarget.hasSSE2()) + return SDValue(); - if (!N->getValueType(0).isVector() || !N->getValueType(0).isSimple()) - return SDValue(); +if (!N->getValueType(0).isVector() || !N->getValueType(0).isSimple()) + return SDValue(); - SDValue In = N->getOperand(0); - if (!In.getValueType().isSimple()) - return SDValue(); +SDValue In = N->getOperand(0); +if (!In.getValueType().isSimple()) + return SDValue(); - MVT VT = N->getValueType(0).getSimpleVT(); - MVT SVT = VT.getScalarType(); +MVT VT = N->getValueType(0).getSimpleVT(); +MVT SVT = VT.getScalarType(); - MVT InVT = In.getValueType().getSimpleVT(); - MVT InSVT = InVT.getScalarType(); +MVT InVT = In.getValueType().getSimpleVT(); +MVT InSVT = InVT.getScalarType(); - // Check we have a truncation suited for PACKSS/PACKUS. - if (!isPowerOf2_32(VT.getVectorNumElements())) - return SDValue(); - if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32) - return SDValue(); - if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64) - return SDValue(); +// Check we have a truncation suited for PACKSS/PACKUS. 
+if (!isPowerOf2_32(VT.getVectorNumElements())) + return SDValue(); +if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32) + return SDValue(); +if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64) + return SDValue(); - // Truncation to sub-128bit vXi32 can be better handled with shuffles. - if (SVT == MVT::i32 && VT.getSizeInBits() < 128) - return SDValue(); +// Truncation to sub-128bit vXi32 can be better handled with shuffles. +if (SVT == MVT::i32 && VT.getSizeInBits() < 128) + return SDValue(); - // AVX512 has fast truncate, but if the input is already going to be split, - // there's no harm in trying pack. - if (Subtarget.hasAVX512() && - !(!Subtarget.useAVX512Regs() && VT.is256BitVector() && - InVT.is512BitVector())) { - // PACK should still be worth it for 128-bit vectors if the sources were - // originally concatenated from subvectors. - SmallVector ConcatOps; - if (VT.getSizeInBits() > 128 || !collectConcatOps(In.getNode(), ConcatOps)) - return SDValue(); - } +// AVX512 has fast truncate, but if the input is already going to be split, +// there's no harm in trying pack. +if (Subtarget.hasAVX512() && + !(!Subtarget.useAVX512Regs() && VT.is256BitVector() && + InVT.is512BitVector())) { + // PACK should still be worth it for 128-bit vectors if the sources were + // originally concatenated from subvectors. + SmallVector ConcatOps; + if (VT.getSizeInBits() > 128 || !collectConcatOps(In.getNode(), ConcatOps)) + return SDValue(); +} - unsigned NumPackedSignBits = std::min(SVT.getSizeInBits(), 16); - unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8; +unsigned NumPackedSignBits = std::min(SVT.getSizeInBits(), 16); +unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8; - // Use PACKUS if the input has zero-bits that extend all the way to the - // packed/truncated value. e.g. masks, zext_in_reg, etc. - KnownBits Known = DAG.computeKnownBits(In); - unsigned NumLeadingZeroBits = Known.countMinLeadingZeros(); - if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits)) - return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget); - - // Use PACKSS if the input has sign-bits that extend all the way to the - // packed/truncated value. e.g. Comparison result, sext_in_reg, etc. - unsigned NumSignBits = DAG.ComputeNumSignBits(In); - - // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with - // a sign splat. ComputeNumSignBits struggles to see through BITCASTs later - // on and combines/simplifications can't then use it. - if (SVT == MVT::i32 && NumSignBits != InSVT.getSizeInBits()) - return SDValue(); - - unsigned MinSignBits = InSVT.getSizeInBits() - NumPackedSignBits; - if (NumSignBits > MinSignBits) - return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget); - - // If we have a srl that only generates signbits that we will discard in - // the truncation then we can use PACKSS by converting the srl to a sra. - // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it. - if (In.getOpcode() == ISD::SRL && N->isOnlyUserOf(In.getNode())) - if (const APInt *ShAmt = DAG.getValidShiftAmountConstant( - In, APInt::getAllOnes(VT.getVectorNumElements()))) { - if (*ShAmt == MinSignBits) { - SDValue NewIn = DAG.getNode(ISD::SRA, DL, InVT, In->ops()); - return truncateVectorWithPACK(X86ISD::PACKSS, VT, NewIn, DL, DAG, - Subtarget); - } - } +// Use PACKUS if the input has zero-bits that extend all the way to the +// packed/truncated value. e.g. masks, zext_in_reg, etc. 
+KnownBits Known = DAG.computeKnownBits(In); +unsigned NumLeadingZeroBits = Known.countMinLeadingZeros(); +if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits)) + return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget); + +// Use PACKSS if the input has sign-bits that extend all the way to the +// packed/truncated value. e.g. Comparison result, sext_in_reg, etc. +unsigned NumSignBits = DAG.ComputeNumSignBits(In); +// Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with +// a sign splat. ComputeNumSignBits struggles to see through BITCASTs later +// on and combines/simplifications can't then use it. +if (SVT == MVT::i32 && NumSignBits != InSVT.getSizeInBits()) return SDValue(); + +unsigned MinSignBits = InSVT.getSizeInBits() - NumPackedSignBits; +if (NumSignBits > MinSignBits) + return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget); + +// If we have a srl that only generates signbits that we will discard in +// the truncation then we can use PACKSS by converting the srl to a sra. +// SimplifyDemandedBits often relaxes sra to srl so we need to reverse it. +if (In.getOpcode() == ISD::SRL && N->isOnlyUserOf(In.getNode())) + if (const APInt *ShAmt = DAG.getValidShiftAmountConstant( + In, APInt::getAllOnes(VT.getVectorNumElements()))) { + if (*ShAmt == MinSignBits) { + SDValue NewIn = DAG.getNode(ISD::SRA, DL, InVT, In->ops()); + return truncateVectorWithPACK(X86ISD::PACKSS, VT, NewIn, DL, DAG, + Subtarget); + } + } + +return SDValue(); } // Try to form a MULHU or MULHS node by looking for @@ -47839,49 +47868,49 @@ // of a MULHU/MULHS. There isn't a way to convey this to the generic DAG // combiner. static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, - SelectionDAG &DAG, const X86Subtarget &Subtarget) { - // First instruction should be a right shift of a multiply. - if (Src.getOpcode() != ISD::SRL || - Src.getOperand(0).getOpcode() != ISD::MUL) - return SDValue(); + SelectionDAG &DAG, const X86Subtarget &Subtarget) { +// First instruction should be a right shift of a multiply. +if (Src.getOpcode() != ISD::SRL || + Src.getOperand(0).getOpcode() != ISD::MUL) + return SDValue(); - if (!Subtarget.hasSSE2()) - return SDValue(); +if (!Subtarget.hasSSE2()) + return SDValue(); - // Only handle vXi16 types that are at least 128-bits unless they will be - // widened. - if (!VT.isVector() || VT.getVectorElementType() != MVT::i16) - return SDValue(); +// Only handle vXi16 types that are at least 128-bits unless they will be +// widened. +if (!VT.isVector() || VT.getVectorElementType() != MVT::i16) + return SDValue(); - // Input type should be at least vXi32. - EVT InVT = Src.getValueType(); - if (InVT.getVectorElementType().getSizeInBits() < 32) - return SDValue(); +// Input type should be at least vXi32. +EVT InVT = Src.getValueType(); +if (InVT.getVectorElementType().getSizeInBits() < 32) + return SDValue(); - // Need a shift by 16. - APInt ShiftAmt; - if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) || - ShiftAmt != 16) - return SDValue(); +// Need a shift by 16. 
+APInt ShiftAmt; +if (!ISD::isConstantSplatVector(Src.getOperand(1).getNode(), ShiftAmt) || + ShiftAmt != 16) + return SDValue(); - SDValue LHS = Src.getOperand(0).getOperand(0); - SDValue RHS = Src.getOperand(0).getOperand(1); +SDValue LHS = Src.getOperand(0).getOperand(0); +SDValue RHS = Src.getOperand(0).getOperand(1); - unsigned ExtOpc = LHS.getOpcode(); - if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) || - RHS.getOpcode() != ExtOpc) - return SDValue(); +unsigned ExtOpc = LHS.getOpcode(); +if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) || + RHS.getOpcode() != ExtOpc) + return SDValue(); - // Peek through the extends. - LHS = LHS.getOperand(0); - RHS = RHS.getOperand(0); +// Peek through the extends. +LHS = LHS.getOperand(0); +RHS = RHS.getOperand(0); - // Ensure the input types match. - if (LHS.getValueType() != VT || RHS.getValueType() != VT) - return SDValue(); +// Ensure the input types match. +if (LHS.getValueType() != VT || RHS.getValueType() != VT) + return SDValue(); - unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU; - return DAG.getNode(Opc, DL, VT, LHS, RHS); +unsigned Opc = ExtOpc == ISD::SIGN_EXTEND ? ISD::MULHS : ISD::MULHU; +return DAG.getNode(Opc, DL, VT, LHS, RHS); } // Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes @@ -47893,194 +47922,194 @@ // (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))), // (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B)))))))) static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, - const X86Subtarget &Subtarget, - const SDLoc &DL) { - if (!VT.isVector() || !Subtarget.hasSSSE3()) - return SDValue(); - - unsigned NumElems = VT.getVectorNumElements(); - EVT ScalarVT = VT.getVectorElementType(); - if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems)) - return SDValue(); - - SDValue SSatVal = detectSSatPattern(In, VT); - if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) - return SDValue(); - - // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs - // of multiplies from even/odd elements. - SDValue N0 = SSatVal.getOperand(0); - SDValue N1 = SSatVal.getOperand(1); - - if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) - return SDValue(); - - SDValue N00 = N0.getOperand(0); - SDValue N01 = N0.getOperand(1); - SDValue N10 = N1.getOperand(0); - SDValue N11 = N1.getOperand(1); + const X86Subtarget &Subtarget, + const SDLoc &DL) { +if (!VT.isVector() || !Subtarget.hasSSSE3()) + return SDValue(); - // TODO: Handle constant vectors and use knownbits/computenumsignbits? - // Canonicalize zero_extend to LHS. - if (N01.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N00, N01); - if (N11.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N10, N11); +unsigned NumElems = VT.getVectorNumElements(); +EVT ScalarVT = VT.getVectorElementType(); +if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems)) + return SDValue(); - // Ensure we have a zero_extend and a sign_extend. - if (N00.getOpcode() != ISD::ZERO_EXTEND || - N01.getOpcode() != ISD::SIGN_EXTEND || - N10.getOpcode() != ISD::ZERO_EXTEND || - N11.getOpcode() != ISD::SIGN_EXTEND) - return SDValue(); +SDValue SSatVal = detectSSatPattern(In, VT); +if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) + return SDValue(); - // Peek through the extends. - N00 = N00.getOperand(0); - N01 = N01.getOperand(0); - N10 = N10.getOperand(0); - N11 = N11.getOperand(0); +// Ok this is a signed saturation of an ADD. 
See if this ADD is adding pairs +// of multiplies from even/odd elements. +SDValue N0 = SSatVal.getOperand(0); +SDValue N1 = SSatVal.getOperand(1); - // Ensure the extend is from vXi8. - if (N00.getValueType().getVectorElementType() != MVT::i8 || - N01.getValueType().getVectorElementType() != MVT::i8 || - N10.getValueType().getVectorElementType() != MVT::i8 || - N11.getValueType().getVectorElementType() != MVT::i8) - return SDValue(); +if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + return SDValue(); - // All inputs should be build_vectors. - if (N00.getOpcode() != ISD::BUILD_VECTOR || - N01.getOpcode() != ISD::BUILD_VECTOR || - N10.getOpcode() != ISD::BUILD_VECTOR || - N11.getOpcode() != ISD::BUILD_VECTOR) - return SDValue(); +SDValue N00 = N0.getOperand(0); +SDValue N01 = N0.getOperand(1); +SDValue N10 = N1.getOperand(0); +SDValue N11 = N1.getOperand(1); + +// TODO: Handle constant vectors and use knownbits/computenumsignbits? +// Canonicalize zero_extend to LHS. +if (N01.getOpcode() == ISD::ZERO_EXTEND) + std::swap(N00, N01); +if (N11.getOpcode() == ISD::ZERO_EXTEND) + std::swap(N10, N11); + +// Ensure we have a zero_extend and a sign_extend. +if (N00.getOpcode() != ISD::ZERO_EXTEND || + N01.getOpcode() != ISD::SIGN_EXTEND || + N10.getOpcode() != ISD::ZERO_EXTEND || + N11.getOpcode() != ISD::SIGN_EXTEND) + return SDValue(); - // N00/N10 are zero extended. N01/N11 are sign extended. +// Peek through the extends. +N00 = N00.getOperand(0); +N01 = N01.getOperand(0); +N10 = N10.getOperand(0); +N11 = N11.getOperand(0); + +// Ensure the extend is from vXi8. +if (N00.getValueType().getVectorElementType() != MVT::i8 || + N01.getValueType().getVectorElementType() != MVT::i8 || + N10.getValueType().getVectorElementType() != MVT::i8 || + N11.getValueType().getVectorElementType() != MVT::i8) + return SDValue(); - // For each element, we need to ensure we have an odd element from one vector - // multiplied by the odd element of another vector and the even element from - // one of the same vectors being multiplied by the even element from the - // other vector. So we need to make sure for each element i, this operator - // is being performed: - // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] - SDValue ZExtIn, SExtIn; - for (unsigned i = 0; i != NumElems; ++i) { - SDValue N00Elt = N00.getOperand(i); - SDValue N01Elt = N01.getOperand(i); - SDValue N10Elt = N10.getOperand(i); - SDValue N11Elt = N11.getOperand(i); - // TODO: Be more tolerant to undefs. - if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || - N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || - N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || - N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) - return SDValue(); - auto *ConstN00Elt = dyn_cast(N00Elt.getOperand(1)); - auto *ConstN01Elt = dyn_cast(N01Elt.getOperand(1)); - auto *ConstN10Elt = dyn_cast(N10Elt.getOperand(1)); - auto *ConstN11Elt = dyn_cast(N11Elt.getOperand(1)); - if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt) - return SDValue(); - unsigned IdxN00 = ConstN00Elt->getZExtValue(); - unsigned IdxN01 = ConstN01Elt->getZExtValue(); - unsigned IdxN10 = ConstN10Elt->getZExtValue(); - unsigned IdxN11 = ConstN11Elt->getZExtValue(); - // Add is commutative so indices can be reordered. - if (IdxN00 > IdxN10) { - std::swap(IdxN00, IdxN10); - std::swap(IdxN01, IdxN11); - } - // N0 indices be the even element. N1 indices must be the next odd element. 
- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || - IdxN01 != 2 * i || IdxN11 != 2 * i + 1) - return SDValue(); - SDValue N00In = N00Elt.getOperand(0); - SDValue N01In = N01Elt.getOperand(0); - SDValue N10In = N10Elt.getOperand(0); - SDValue N11In = N11Elt.getOperand(0); - // First time we find an input capture it. - if (!ZExtIn) { - ZExtIn = N00In; - SExtIn = N01In; - } - if (ZExtIn != N00In || SExtIn != N01In || - ZExtIn != N10In || SExtIn != N11In) - return SDValue(); - } +// All inputs should be build_vectors. +if (N00.getOpcode() != ISD::BUILD_VECTOR || + N01.getOpcode() != ISD::BUILD_VECTOR || + N10.getOpcode() != ISD::BUILD_VECTOR || + N11.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); - auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, - ArrayRef Ops) { - // Shrink by adding truncate nodes and let DAGCombine fold with the - // sources. - EVT InVT = Ops[0].getValueType(); - assert(InVT.getScalarType() == MVT::i8 && - "Unexpected scalar element type"); - assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); - EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, - InVT.getVectorNumElements() / 2); - return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]); - }; - return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn }, - PMADDBuilder); +// N00/N10 are zero extended. N01/N11 are sign extended. + +// For each element, we need to ensure we have an odd element from one vector +// multiplied by the odd element of another vector and the even element from +// one of the same vectors being multiplied by the even element from the +// other vector. So we need to make sure for each element i, this operator +// is being performed: +// A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1] +SDValue ZExtIn, SExtIn; +for (unsigned i = 0; i != NumElems; ++i) { + SDValue N00Elt = N00.getOperand(i); + SDValue N01Elt = N01.getOperand(i); + SDValue N10Elt = N10.getOperand(i); + SDValue N11Elt = N11.getOperand(i); + // TODO: Be more tolerant to undefs. + if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + auto *ConstN00Elt = dyn_cast(N00Elt.getOperand(1)); + auto *ConstN01Elt = dyn_cast(N01Elt.getOperand(1)); + auto *ConstN10Elt = dyn_cast(N10Elt.getOperand(1)); + auto *ConstN11Elt = dyn_cast(N11Elt.getOperand(1)); + if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt) + return SDValue(); + unsigned IdxN00 = ConstN00Elt->getZExtValue(); + unsigned IdxN01 = ConstN01Elt->getZExtValue(); + unsigned IdxN10 = ConstN10Elt->getZExtValue(); + unsigned IdxN11 = ConstN11Elt->getZExtValue(); + // Add is commutative so indices can be reordered. + if (IdxN00 > IdxN10) { + std::swap(IdxN00, IdxN10); + std::swap(IdxN01, IdxN11); + } + // N0 indices be the even element. N1 indices must be the next odd element. + if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || + IdxN01 != 2 * i || IdxN11 != 2 * i + 1) + return SDValue(); + SDValue N00In = N00Elt.getOperand(0); + SDValue N01In = N01Elt.getOperand(0); + SDValue N10In = N10Elt.getOperand(0); + SDValue N11In = N11Elt.getOperand(0); + // First time we find an input capture it. 
+ if (!ZExtIn) { + ZExtIn = N00In; + SExtIn = N01In; + } + if (ZExtIn != N00In || SExtIn != N01In || + ZExtIn != N10In || SExtIn != N11In) + return SDValue(); +} + +auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL, + ArrayRef Ops) { + // Shrink by adding truncate nodes and let DAGCombine fold with the + // sources. + EVT InVT = Ops[0].getValueType(); + assert(InVT.getScalarType() == MVT::i8 && + "Unexpected scalar element type"); + assert(InVT == Ops[1].getValueType() && "Operands' types mismatch"); + EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, + InVT.getVectorNumElements() / 2); + return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]); +}; +return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn }, + PMADDBuilder); } static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { - EVT VT = N->getValueType(0); - SDValue Src = N->getOperand(0); - SDLoc DL(N); + const X86Subtarget &Subtarget) { +EVT VT = N->getValueType(0); +SDValue Src = N->getOperand(0); +SDLoc DL(N); - // Attempt to pre-truncate inputs to arithmetic ops instead. - if (SDValue V = combineTruncatedArithmetic(N, DAG, Subtarget, DL)) - return V; +// Attempt to pre-truncate inputs to arithmetic ops instead. +if (SDValue V = combineTruncatedArithmetic(N, DAG, Subtarget, DL)) + return V; - // Try to detect AVG pattern first. - if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL)) - return Avg; +// Try to detect AVG pattern first. +if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL)) + return Avg; - // Try to detect PMADD - if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL)) - return PMAdd; +// Try to detect PMADD +if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL)) + return PMAdd; - // Try to combine truncation with signed/unsigned saturation. - if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget)) - return Val; +// Try to combine truncation with signed/unsigned saturation. +if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget)) + return Val; - // Try to combine PMULHUW/PMULHW for vXi16. - if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget)) - return V; +// Try to combine PMULHUW/PMULHW for vXi16. +if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget)) + return V; - // The bitcast source is a direct mmx result. - // Detect bitcasts between i32 to x86mmx - if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) { - SDValue BCSrc = Src.getOperand(0); - if (BCSrc.getValueType() == MVT::x86mmx) - return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc); - } +// The bitcast source is a direct mmx result. +// Detect bitcasts between i32 to x86mmx +if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) { + SDValue BCSrc = Src.getOperand(0); + if (BCSrc.getValueType() == MVT::x86mmx) + return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc); +} - // Try to truncate extended sign/zero bits with PACKSS/PACKUS. - if (SDValue V = combineVectorSignBitsTruncation(N, DL, DAG, Subtarget)) - return V; +// Try to truncate extended sign/zero bits with PACKSS/PACKUS. 
+if (SDValue V = combineVectorSignBitsTruncation(N, DL, DAG, Subtarget)) + return V; - return combineVectorTruncation(N, DAG, Subtarget); +return combineVectorTruncation(N, DAG, Subtarget); } static SDValue combineVTRUNC(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { - EVT VT = N->getValueType(0); - SDValue In = N->getOperand(0); - SDLoc DL(N); + TargetLowering::DAGCombinerInfo &DCI) { +EVT VT = N->getValueType(0); +SDValue In = N->getOperand(0); +SDLoc DL(N); - if (auto SSatVal = detectSSatPattern(In, VT)) - return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal); - if (auto USatVal = detectUSatPattern(In, VT, DAG, DL)) - return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal); +if (auto SSatVal = detectSSatPattern(In, VT)) + return DAG.getNode(X86ISD::VTRUNCS, DL, VT, SSatVal); +if (auto USatVal = detectUSatPattern(In, VT, DAG, DL)) + return DAG.getNode(X86ISD::VTRUNCUS, DL, VT, USatVal); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - APInt DemandedMask(APInt::getAllOnes(VT.getScalarSizeInBits())); - if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI)) - return SDValue(N, 0); +const TargetLowering &TLI = DAG.getTargetLoweringInfo(); +APInt DemandedMask(APInt::getAllOnes(VT.getScalarSizeInBits())); +if (TLI.SimplifyDemandedBits(SDValue(N, 0), DemandedMask, DCI)) + return SDValue(N, 0); - return SDValue(); +return SDValue(); } /// Returns the negated value if the node \p N flips sign of FP value. @@ -48093,68 +48122,69 @@ /// This also recognizes splat of a negated value and returns the splat of that /// value. static SDValue isFNEG(SelectionDAG &DAG, SDNode *N, unsigned Depth = 0) { - if (N->getOpcode() == ISD::FNEG) - return N->getOperand(0); +if (N->getOpcode() == ISD::FNEG) + return N->getOperand(0); - // Don't recurse exponentially. - if (Depth > SelectionDAG::MaxRecursionDepth) - return SDValue(); +// Don't recurse exponentially. +if (Depth > SelectionDAG::MaxRecursionDepth) + return SDValue(); - unsigned ScalarSize = N->getValueType(0).getScalarSizeInBits(); +unsigned ScalarSize = N->getValueType(0).getScalarSizeInBits(); - SDValue Op = peekThroughBitcasts(SDValue(N, 0)); - EVT VT = Op->getValueType(0); +SDValue Op = peekThroughBitcasts(SDValue(N, 0)); +EVT VT = Op->getValueType(0); - // Make sure the element size doesn't change. - if (VT.getScalarSizeInBits() != ScalarSize) - return SDValue(); +// Make sure the element size doesn't change. +if (VT.getScalarSizeInBits() != ScalarSize) + return SDValue(); - unsigned Opc = Op.getOpcode(); - switch (Opc) { - case ISD::VECTOR_SHUFFLE: { - // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate - // of this is VECTOR_SHUFFLE(-VEC1, UNDEF). The mask can be anything here. - if (!Op.getOperand(1).isUndef()) - return SDValue(); - if (SDValue NegOp0 = isFNEG(DAG, Op.getOperand(0).getNode(), Depth + 1)) - if (NegOp0.getValueType() == VT) // FIXME: Can we do better? - return DAG.getVectorShuffle(VT, SDLoc(Op), NegOp0, DAG.getUNDEF(VT), - cast(Op)->getMask()); - break; - } - case ISD::INSERT_VECTOR_ELT: { - // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF, - // -V, INDEX). 
- SDValue InsVector = Op.getOperand(0); - SDValue InsVal = Op.getOperand(1); - if (!InsVector.isUndef()) - return SDValue(); - if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode(), Depth + 1)) - if (NegInsVal.getValueType() == VT.getVectorElementType()) // FIXME - return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector, - NegInsVal, Op.getOperand(2)); - break; - } - case ISD::FSUB: - case ISD::XOR: - case X86ISD::FXOR: { - SDValue Op1 = Op.getOperand(1); - SDValue Op0 = Op.getOperand(0); +unsigned Opc = Op.getOpcode(); +switch (Opc) { +case ISD::VECTOR_SHUFFLE: { + // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate + // of this is VECTOR_SHUFFLE(-VEC1, UNDEF). The mask can be anything here. + if (!Op.getOperand(1).isUndef()) + return SDValue(); + if (SDValue NegOp0 = isFNEG(DAG, Op.getOperand(0).getNode(), Depth + 1)) + if (NegOp0.getValueType() == VT) // FIXME: Can we do better? + return DAG.getVectorShuffle(VT, SDLoc(Op), NegOp0, DAG.getUNDEF(VT), + cast(Op)->getMask()); + break; +} +case ISD::INSERT_VECTOR_ELT: { + // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF, + // -V, INDEX). + SDValue InsVector = Op.getOperand(0); + SDValue InsVal = Op.getOperand(1); + if (!InsVector.isUndef()) + return SDValue(); + if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode(), Depth + 1)) + if (NegInsVal.getValueType() == VT.getVectorElementType()) // FIXME + return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector, + NegInsVal, Op.getOperand(2)); + break; +} +case ISD::FSUB: +case ISD::XOR: +case X86ISD::FXOR: { + SDValue Op1 = Op.getOperand(1); + SDValue Op0 = Op.getOperand(0); - // For XOR and FXOR, we want to check if constant - // bits of Op1 are sign bit masks. For FSUB, we - // have to check if constant bits of Op0 are sign - // bit masks and hence we swap the operands. - if (Opc == ISD::FSUB) - std::swap(Op0, Op1); + // For XOR and FXOR, we want to check if constant + // bits of Op1 are sign bit masks. For FSUB, we + // have to check if constant bits of Op0 are sign + // bit masks and hence we swap the operands. + if (Opc == ISD::FSUB) + std::swap(Op0, Op1); - APInt UndefElts; - SmallVector EltBits; - // Extract constant bits and see if they are all - // sign bit masks. Ignore the undef elements. - if (getTargetConstantBitsFromNode(Op1, ScalarSize, UndefElts, EltBits, - /* AllowWholeUndefs */ true, - /* AllowPartialUndefs */ false)) { + APInt UndefElts; + SmallVector EltBits; + // Extract constant bits and see if they are all + // sign bit masks. Ignore the undef elements. + if (getTargetConstantBitsFromNode(Op1, ScalarSize, UndefElts, EltBits, + /* AllowWholeUndefs */ + true, + /* AllowPartialUndefs */ false)) { for (unsigned I = 0, E = EltBits.size(); I < E; I++) if (!UndefElts[I] && !EltBits[I].isSignMask()) return SDValue(); diff --git a/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll @@ -0,0 +1,277 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx512fp16 | FileCheck %s + +; FADD(acc, FMA(a, b, 0.0)) can be combined to FMA(a, b, acc) if the fast-math flag set. 
+define dso_local <32 x half> @test1(<32 x half> %acc, <32 x half> %a, <32 x half> %b) { +; CHECK-LABEL: test1: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfcmaddcph %zmm2, %zmm1, %zmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <32 x half> %a to <16 x float> + %1 = bitcast <32 x half> %b to <16 x float> + %2 = tail call fast <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4) + %3 = bitcast <16 x float> %2 to <32 x half> + %add.i = fadd fast <32 x half> %3, %acc + ret <32 x half> %add.i +} + +define dso_local <32 x half> @test2(<32 x half> %acc, <32 x half> %a, <32 x half> %b) { +; CHECK-LABEL: test2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfmaddcph %zmm2, %zmm1, %zmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <32 x half> %a to <16 x float> + %1 = bitcast <32 x half> %b to <16 x float> + %2 = tail call fast <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4) + %3 = bitcast <16 x float> %2 to <32 x half> + %add.i = fadd fast <32 x half> %3, %acc + ret <32 x half> %add.i +} + +define dso_local <16 x half> @test3(<16 x half> %acc, <16 x half> %a, <16 x half> %b) { +; CHECK-LABEL: test3: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfcmaddcph %ymm2, %ymm1, %ymm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <16 x half> %a to <8 x float> + %1 = bitcast <16 x half> %b to <8 x float> + %2 = tail call fast <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1) + %3 = bitcast <8 x float> %2 to <16 x half> + %add.i = fadd fast <16 x half> %3, %acc + ret <16 x half> %add.i +} + +define dso_local <16 x half> @test4(<16 x half> %acc, <16 x half> %a, <16 x half> %b) { +; CHECK-LABEL: test4: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfmaddcph %ymm2, %ymm1, %ymm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <16 x half> %a to <8 x float> + %1 = bitcast <16 x half> %b to <8 x float> + %2 = tail call fast <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1) + %3 = bitcast <8 x float> %2 to <16 x half> + %add.i = fadd fast <16 x half> %3, %acc + ret <16 x half> %add.i +} + +define dso_local <8 x half> @test5(<8 x half> %acc, <8 x half> %a, <8 x half> %b) { +; CHECK-LABEL: test5: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfcmaddcph %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <8 x half> %a to <4 x float> + %1 = bitcast <8 x half> %b to <4 x float> + %2 = tail call fast <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1) + %3 = bitcast <4 x float> %2 to <8 x half> + %add.i = fadd fast <8 x half> %3, %acc + ret <8 x half> %add.i +} + +define dso_local <8 x half> @test6(<8 x half> %acc, <8 x half> %a, <8 x half> %b) { +; CHECK-LABEL: test6: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vfmaddcph %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <8 x half> %a to <4 x float> + %1 = bitcast <8 x half> %b to <4 x float> + %2 = tail call fast <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1) + %3 = bitcast <4 x float> %2 to <8 x half> + %add.i = fadd fast <8 x half> %3, %acc + ret <8 x half> %add.i +} + +; FADD(acc, FMA(a, b, 0.0)) shouldn't be combined to FMA(a, b, acc) if the fast-math flag unset. 
+define dso_local <32 x half> @test7(<32 x half> %acc, <32 x half> %a, <32 x half> %b) { +; CHECK-LABEL: test7: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfcmaddcph %zmm2, %zmm1, %zmm3 +; CHECK-NEXT: vaddph %zmm0, %zmm3, %zmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <32 x half> %a to <16 x float> + %1 = bitcast <32 x half> %b to <16 x float> + %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4) + %3 = bitcast <16 x float> %2 to <32 x half> + %add.i = fadd <32 x half> %3, %acc + ret <32 x half> %add.i +} + +define dso_local <32 x half> @test8(<32 x half> %acc, <32 x half> %a, <32 x half> %b) { +; CHECK-LABEL: test8: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfmaddcph %zmm2, %zmm1, %zmm3 +; CHECK-NEXT: vaddph %zmm0, %zmm3, %zmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <32 x half> %a to <16 x float> + %1 = bitcast <32 x half> %b to <16 x float> + %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> zeroinitializer, <16 x float> %0, <16 x float> %1, i16 -1, i32 4) + %3 = bitcast <16 x float> %2 to <32 x half> + %add.i = fadd <32 x half> %3, %acc + ret <32 x half> %add.i +} + +define dso_local <16 x half> @test9(<16 x half> %acc, <16 x half> %a, <16 x half> %b) { +; CHECK-LABEL: test9: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfcmaddcph %ymm2, %ymm1, %ymm3 +; CHECK-NEXT: vaddph %ymm0, %ymm3, %ymm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <16 x half> %a to <8 x float> + %1 = bitcast <16 x half> %b to <8 x float> + %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1) + %3 = bitcast <8 x float> %2 to <16 x half> + %add.i = fadd <16 x half> %3, %acc + ret <16 x half> %add.i +} + +define dso_local <16 x half> @test10(<16 x half> %acc, <16 x half> %a, <16 x half> %b) { +; CHECK-LABEL: test10: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfmaddcph %ymm2, %ymm1, %ymm3 +; CHECK-NEXT: vaddph %ymm0, %ymm3, %ymm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <16 x half> %a to <8 x float> + %1 = bitcast <16 x half> %b to <8 x float> + %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> zeroinitializer, <8 x float> %0, <8 x float> %1, i8 -1) + %3 = bitcast <8 x float> %2 to <16 x half> + %add.i = fadd <16 x half> %3, %acc + ret <16 x half> %add.i +} + +define dso_local <8 x half> @test11(<8 x half> %acc, <8 x half> %a, <8 x half> %b) { +; CHECK-LABEL: test11: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfcmaddcph %xmm2, %xmm1, %xmm3 +; CHECK-NEXT: vaddph %xmm0, %xmm3, %xmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <8 x half> %a to <4 x float> + %1 = bitcast <8 x half> %b to <4 x float> + %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1) + %3 = bitcast <4 x float> %2 to <8 x half> + %add.i = fadd <8 x half> %3, %acc + ret <8 x half> %add.i +} + +define dso_local <8 x half> @test12(<8 x half> %acc, <8 x half> %a, <8 x half> %b) { +; CHECK-LABEL: test12: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vxorps %xmm3, %xmm3, %xmm3 +; CHECK-NEXT: vfmaddcph %xmm2, %xmm1, %xmm3 +; CHECK-NEXT: vaddph %xmm0, %xmm3, %xmm0 +; CHECK-NEXT: retq +entry: + %0 = bitcast <8 x half> %a to <4 x float> 
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> zeroinitializer, <4 x float> %0, <4 x float> %1, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+; FADD(acc, FMA(a, b, -0.0)) can be combined to FMA(a, b, acc) even if the fast-math flag of FMA is not set.
+; The following test cases cannot be generated by the front-end.
+define dso_local <32 x half> @test13(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test13:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %0, <16 x float> %1, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd fast <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test14(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test14:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %zmm2, %zmm1, %zmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %0, <16 x float> %1, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd fast <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test15(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test15:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <8 x float> %0, <8 x float> %1, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd fast <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test16(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test16:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %ymm2, %ymm1, %ymm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <8 x float> %0, <8 x float> %1, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd fast <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test17(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test17:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>, <4 x float> %0, <4 x float> %1, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd fast <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test18(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test18:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %xmm2, %xmm1, %xmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>, <4 x float> %0, <4 x float> %1, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd fast <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)