Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -31117,6 +31117,129 @@
   return SDValue();
 }
 
+/// Returns the negated value if the node \p N flips sign of FP value.
+///
+/// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000)
+/// or FSUB(-0.0, x). FSUB(+0.0, x) is not a negation because it maps
+/// x == +0.0 to +0.0 rather than -0.0, so it is rejected.
+/// AVX512F does not have FXOR, so FNEG is lowered as
+/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
+/// In this case we go through all bitcasts.
+static SDValue isFNEG(SDNode *N) {
+  if (N->getOpcode() == ISD::FNEG)
+    return N->getOperand(0);
+
+  SDValue Op = peekThroughBitcasts(SDValue(N, 0));
+  auto Opc = Op.getOpcode();
+  if (Opc != X86ISD::FXOR && Opc != ISD::XOR && Opc != ISD::FSUB)
+    return SDValue();
+
+  SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
+  if (!Op1.getValueType().isFloatingPoint())
+    return SDValue();
+
+  SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+  unsigned EltBits = Op1.getScalarValueSizeInBits();
+  auto isSignMask = [&](const ConstantFP *C) {
+    return C->getValueAPF().bitcastToAPInt() == APInt::getSignMask(EltBits);
+  };
+
+  // XOR/FXOR negate when the constant is the sign mask; FSUB negates only
+  // when the minuend is -0.0 (+0.0 - x gives +0.0, not -0.0, for x == +0.0).
+  auto IsNeg = [&](const ConstantFP *Val) {
+    if (Opc == ISD::FSUB)
+      return Val->isZero() && Val->isNegative();
+    return isSignMask(Val);
+  };
+
+  // There is more than one way to represent the same constant on
+  // the different X86 targets. The type of the node may also depend on size.
+  //  - load scalar value and broadcast
+  //  - BUILD_VECTOR node
+  //  - load from a constant pool.
+  // We check all variants here.
+  auto Negate = [&](SDValue Src, SDValue Cst) {
+    if (Cst.getOpcode() == X86ISD::VBROADCAST) {
+      // Use dyn_cast_or_null so that a non-FP constant is rejected instead
+      // of tripping the assertion inside cast_or_null.
+      if (auto *C = dyn_cast_or_null<ConstantFP>(
+              getTargetConstantFromNode(Cst.getOperand(0))))
+        if (IsNeg(C))
+          return Src;
+
+    } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Cst)) {
+      if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+        if (IsNeg(CN->getConstantFPValue()))
+          return Src;
+
+    } else if (auto *C = getTargetConstantFromNode(Cst)) {
+      if (C->getType()->isVectorTy()) {
+        if (auto *SplatV = dyn_cast_or_null<ConstantFP>(C->getSplatValue()))
+          if (IsNeg(SplatV))
+            return Src;
+      } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+        if (IsNeg(FPConst))
+          return Src;
+    }
+    return SDValue();
+  };
+
+  // For FSUB the constant is the first operand and the negated value the
+  // second; for XOR/FXOR it is the other way round.
+  if (Opc == ISD::FSUB)
+    return Negate(Op1, Op0);
+  return Negate(Op0, Op1);
+}
+
+/// Try to sink a negate below a splat shuffle if the shuffle feeds an FMA.
+///
+/// This converts the sequence:
+///    t1 = vneg t0
+///    t2 = vector_shuffle<...> t1, undef
+/// into
+///    t3 = vector_shuffle<...> t0, undef
+///    t4 = vneg t3
+///
+/// when one of the uses of t2 is an fma and hasAnyFMA() is true for the
+/// sub-target. The idea here is that the vneg can be combined with fma to
+/// generate fnmadd/fmsub etc.
+static SDValue combineShuffleToNegate(SelectionDAG &DAG, SDNode *N,
+                                      const X86Subtarget &Subtarget) {
+  // This transform is intended to enable combining of VNEG and FMA.
+  if (!Subtarget.hasAnyFMA())
+    return SDValue();
+
+  // N is not guaranteed to be a generic shuffle node, so test the node kind
+  // before looking at its second operand.
+  auto *SVOp = dyn_cast<ShuffleVectorSDNode>(N);
+  if (!SVOp || !SVOp->isSplat() || !N->getOperand(1).isUndef())
+    return SDValue();
+
+  // Check if the shuffle is used by an FMA.
+  bool UsedByFMA = false;
+  for (auto *User : N->uses())
+    if (User->getOpcode() == ISD::FMA) {
+      UsedByFMA = true;
+      break;
+    }
+  if (!UsedByFMA)
+    return SDValue();
+
+  // isFNEG() looks through bitcasts, so the value it returns can have a
+  // different type than the shuffle, while getVectorShuffle() requires the
+  // input type to match the result type.
+  EVT VT = N->getValueType(0);
+  SDValue NegOp0 = isFNEG(N->getOperand(0).getNode());
+  if (!NegOp0.getNode() || NegOp0.getValueType() != VT)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue Shuffle =
+      DAG.getVectorShuffle(VT, DL, NegOp0, DAG.getUNDEF(VT), SVOp->getMask());
+  return DAG.getNode(ISD::FNEG, DL, VT, Shuffle);
+}
+
 static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
                               TargetLowering::DAGCombinerInfo &DCI,
                               const X86Subtarget &Subtarget) {
@@ -31131,6 +31254,9 @@
 
     if (SDValue HAddSub = foldShuffleOfHorizOp(N))
       return HAddSub;
+
+    if (SDValue NegShuffle = combineShuffleToNegate(DAG, N, Subtarget))
+      return NegShuffle;
   }
 
   // During Type Legalization, when promoting illegal vector types,
@@ -36694,59 +36820,6 @@
   return combineVectorTruncation(N, DAG, Subtarget);
 }
 
-/// Returns the negated value if the node \p N flips sign of FP value.
-///
-/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
-/// AVX512F does not have FXOR, so FNEG is lowered as
-/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
-/// In this case we go though all bitcasts.
-static SDValue isFNEG(SDNode *N) {
-  if (N->getOpcode() == ISD::FNEG)
-    return N->getOperand(0);
-
-  SDValue Op = peekThroughBitcasts(SDValue(N, 0));
-  if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
-    return SDValue();
-
-  SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
-  if (!Op1.getValueType().isFloatingPoint())
-    return SDValue();
-
-  SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
-
-  unsigned EltBits = Op1.getScalarValueSizeInBits();
-  auto isSignMask = [&](const ConstantFP *C) {
-    return C->getValueAPF().bitcastToAPInt() == APInt::getSignMask(EltBits);
-  };
-
-  // There is more than one way to represent the same constant on
-  // the different X86 targets. The type of the node may also depend on size.
-  //  - load scalar value and broadcast
-  //  - BUILD_VECTOR node
-  //  - load from a constant pool.
-  // We check all variants here.
-  if (Op1.getOpcode() == X86ISD::VBROADCAST) {
-    if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
-      if (isSignMask(cast<ConstantFP>(C)))
-        return Op0;
-
-  } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
-    if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
-      if (isSignMask(CN->getConstantFPValue()))
-        return Op0;
-
-  } else if (auto *C = getTargetConstantFromNode(Op1)) {
-    if (C->getType()->isVectorTy()) {
-      if (auto *SplatV = C->getSplatValue())
-        if (isSignMask(cast<ConstantFP>(SplatV)))
-          return Op0;
-    } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
-      if (isSignMask(FPConst))
-        return Op0;
-  }
-  return SDValue();
-}
-
 /// Do target-specific dag combines on floating point negations.
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
Index: test/CodeGen/X86/avx2-fma-fneg-combine.ll
===================================================================
--- test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -115,3 +115,25 @@
 
 declare <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %a, <2 x double> %b, <2 x double> %c)
 
+define <8 x float> @test7(float %a, <8 x float> %b, <8 x float> %c) {
+; X32-LABEL: test7:
+; X32:       # %bb.0: # %entry
+; X32-NEXT:    vbroadcastss {{[0-9]+}}(%esp), %ymm2
+; X32-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1
+; X32-NEXT:    retl
+;
+; X64-LABEL: test7:
+; X64:       # %bb.0: # %entry
+; X64-NEXT:    vbroadcastss %xmm0, %ymm0
+; X64-NEXT:    vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2
+; X64-NEXT:    retq
+entry:
+  %0 = insertelement <8 x float> undef, float %a, i32 0
+  %1 = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %0
+  %2 = shufflevector <8 x float> %1, <8 x float> undef, <8 x i32> zeroinitializer
+  %3 = tail call <8 x float> @llvm.fma.v8f32(<8 x float> %2, <8 x float> %b, <8 x float> %c)
+  ret <8 x float> %3
+
+}
+
+declare <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)