Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -29233,28 +29233,6 @@
   return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
 }
 
-static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
-                          TargetLowering::DAGCombinerInfo &DCI,
-                          const X86Subtarget &Subtarget) {
-  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
-    return Cmp;
-
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
-  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
-    return RV;
-
-  if (Subtarget.hasCMov())
-    if (SDValue RV = combineIntegerAbs(N, DAG))
-      return RV;
-
-  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
-    return FPLogic;
-
-  return SDValue();
-}
-
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
 /// X86ISD::AVG instruction.
@@ -30363,12 +30341,68 @@
   return combineVectorTruncation(N, DAG, Subtarget);
 }
 
+/// Returns the negated value if the node \p N flips the sign of an FP value.
+///
+/// An FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
+/// AVX512F does not have FXOR, so FNEG is lowered as
+/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
+/// In this case we go through all bitcasts.
+static SDValue isFNEG(SDNode *N) {
+  if (N->getOpcode() == ISD::FNEG)
+    return N->getOperand(0);
+
+  SDValue Op = peekThroughBitcasts(SDValue(N, 0));
+  if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
+    return SDValue();
+
+  SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
+  if (!Op1.getValueType().isFloatingPoint())
+    return SDValue();
+
+  SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+  unsigned EltBits = Op1.getValueType().getScalarSizeInBits();
+  auto isSignBitValue = [&](const ConstantFP *C) {
+    return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
+  };
+
+  // There is more than one way to represent the same constant on
+  // the different X86 targets. The type of the node may also depend on size.
+  //  - load scalar value and broadcast
+  //  - BUILD_VECTOR node
+  //  - load from a constant pool.
+  // We check all variants here.
+  if (Op1.getOpcode() == X86ISD::VBROADCAST) {
+    if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
+      if (isSignBitValue(cast<ConstantFP>(C)))
+        return Op0;
+
+  } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
+    if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+      if (isSignBitValue(CN->getConstantFPValue()))
+        return Op0;
+
+  } else if (auto *C = getTargetConstantFromNode(Op1)) {
+    if (C->getType()->isVectorTy()) {
+      if (auto *SplatV = C->getSplatValue())
+        if (isSignBitValue(cast<ConstantFP>(SplatV)))
+          return Op0;
+    } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+      if (isSignBitValue(FPConst))
+        return Op0;
+  }
+  return SDValue();
+}
+
 /// Do target-specific dag combines on floating point negations.
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
+  EVT OrigVT = N->getValueType(0);
+  SDValue Arg = isFNEG(N);
+  assert(Arg.getNode() && "N is expected to be an FNEG node");
+
+  EVT VT = Arg.getValueType();
   EVT SVT = VT.getScalarType();
-  SDValue Arg = N->getOperand(0);
   SDLoc DL(N);
 
   // Let legalize expand this if it isn't a legal type yet.
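The check the new isFNEG() builds on is isSignBitValue: an IEEE-754 sign flip is an XOR with a mask whose only set bit is the sign bit, so a matching FP constant's bit pattern must equal APInt::getSignBit(EltBits). Below is a minimal standalone sketch of that identity; it is not part of the patch, and isSignMask is a hypothetical helper that mirrors the lambda:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

// Hypothetical helper mirroring the isSignBitValue lambda in the patch.
static bool isSignMask(const APFloat &C, unsigned EltBits) {
  return C.bitcastToAPInt() == APInt::getSignBit(EltBits);
}

int main() {
  assert(isSignMask(APFloat(-0.0f), 32));  // -0.0f is 0x80000000
  assert(!isSignMask(APFloat(0.0f), 32));  // +0.0f is all zeros
  assert(isSignMask(APFloat(-0.0), 64));   // -0.0 is 0x8000000000000000
  return 0;
}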
@@ -30381,40 +30415,30 @@
   if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
       Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
     SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-    return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                       Arg.getOperand(1), Zero);
+    SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
+                                  Arg.getOperand(1), Zero);
+    return DAG.getBitcast(OrigVT, NewNode);
   }
 
   // If we're negating a FMA node, then we can adjust the
   // instruction to include the extra negation.
+  unsigned NewOpcode = 0;
   if (Arg.hasOneUse()) {
     switch (Arg.getOpcode()) {
-    case X86ISD::FMADD:
-      return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FMSUB:
-      return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FNMADD:
-      return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FNMSUB:
-      return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
-    case X86ISD::FMADD_RND:
-      return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FMSUB_RND:
-      return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FNMADD_RND:
-      return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
-    case X86ISD::FNMSUB_RND:
-      return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+    case X86ISD::FMADD:      NewOpcode = X86ISD::FNMSUB;     break;
+    case X86ISD::FMSUB:      NewOpcode = X86ISD::FNMADD;     break;
+    case X86ISD::FNMADD:     NewOpcode = X86ISD::FMSUB;      break;
+    case X86ISD::FNMSUB:     NewOpcode = X86ISD::FMADD;      break;
+    case X86ISD::FMADD_RND:  NewOpcode = X86ISD::FNMSUB_RND; break;
+    case X86ISD::FMSUB_RND:  NewOpcode = X86ISD::FNMADD_RND; break;
+    case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND;  break;
+    case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND;  break;
    }
  }
+  if (NewOpcode)
+    return DAG.getBitcast(OrigVT, DAG.getNode(NewOpcode, DL, VT,
+                                              Arg.getNode()->ops()));
+
   return SDValue();
 }
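The opcode table above encodes a simple algebraic fact: negating an FMA result flips the sign of both the product and the addend, so FMADD pairs with FNMSUB, FMSUB with FNMADD, and likewise for the _RND variants. A standalone sketch (not part of the patch) that checks these identities with plain doubles; the chosen values are exactly representable, so comparing with == is safe:

#include <cassert>

// Scalar models of the four X86ISD FMA variants.
static double fmadd(double a, double b, double c)  { return  a * b + c; }
static double fmsub(double a, double b, double c)  { return  a * b - c; }
static double fnmadd(double a, double b, double c) { return -(a * b) + c; }
static double fnmsub(double a, double b, double c) { return -(a * b) - c; }

int main() {
  double a = 1.5, b = -2.25, c = 0.75;
  assert(-fmadd(a, b, c)  == fnmsub(a, b, c)); // -(a*b + c) == -(a*b) - c
  assert(-fmsub(a, b, c)  == fnmadd(a, b, c)); // -(a*b - c) == -(a*b) + c
  assert(-fnmadd(a, b, c) == fmsub(a, b, c));
  assert(-fnmsub(a, b, c) == fmadd(a, b, c));
  return 0;
}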
@@ -30442,42 +30466,28 @@
   return SDValue();
 }
 
-/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
-bool isFNEG(const SDNode *N) {
-  if (N->getOpcode() == ISD::FNEG)
-    return true;
+static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
+                          TargetLowering::DAGCombinerInfo &DCI,
+                          const X86Subtarget &Subtarget) {
+  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
+    return Cmp;
 
-  if (N->getOpcode() == X86ISD::FXOR) {
-    unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
-    SDValue Op1 = N->getOperand(1);
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
 
-    auto isSignBitValue = [&](const ConstantFP *C) {
-      return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
-    };
+  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
+    return RV;
 
-    // There is more than one way to represent the same constant on
-    // the different X86 targets. The type of the node may also depend on size.
-    //  - load scalar value and broadcast
-    //  - BUILD_VECTOR node
-    //  - load from a constant pool.
-    // We check all variants here.
-    if (Op1.getOpcode() == X86ISD::VBROADCAST) {
-      if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
-        return isSignBitValue(cast<ConstantFP>(C));
-
-    } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
-      if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
-        return isSignBitValue(CN->getConstantFPValue());
-
-    } else if (auto *C = getTargetConstantFromNode(Op1)) {
-      if (C->getType()->isVectorTy()) {
-        if (auto *SplatV = C->getSplatValue())
-          return isSignBitValue(cast<ConstantFP>(SplatV));
-      } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
-        return isSignBitValue(FPConst);
-    }
-  }
-  return false;
+  if (Subtarget.hasCMov())
+    if (SDValue RV = combineIntegerAbs(N, DAG))
+      return RV;
+
+  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
+    return FPLogic;
+
+  if (isFNEG(N))
+    return combineFneg(N, DAG, Subtarget);
+  return SDValue();
 }
 
 /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
@@ -30907,18 +30917,20 @@
   SDValue B = N->getOperand(1);
   SDValue C = N->getOperand(2);
 
-  bool NegA = isFNEG(A.getNode());
-  bool NegB = isFNEG(B.getNode());
-  bool NegC = isFNEG(C.getNode());
+  auto invertIfNegative = [](SDValue &V) {
+    if (SDValue NegVal = isFNEG(V.getNode())) {
+      V = NegVal;
+      return true;
+    }
+    return false;
+  };
+
+  bool NegA = invertIfNegative(A);
+  bool NegB = invertIfNegative(B);
+  bool NegC = invertIfNegative(C);
 
   // Negative multiplication when NegA xor NegB
   bool NegMul = (NegA != NegB);
-  if (NegA)
-    A = A.getOperand(0);
-  if (NegB)
-    B = B.getOperand(0);
-  if (NegC)
-    C = C.getOperand(0);
 
   unsigned NewOpcode;
   if (!NegMul)
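In the combineFMA hunk above, invertIfNegative strips an isFNEG() match off each operand and records which operands were negated; NegMul = (NegA != NegB) then says whether the product's sign flips, and NegC whether the addend is subtracted. A standalone sketch of the selection this sets up (the tail of the function is not shown in the hunk; FmaKind and selectFma are illustrative names, not patch code):

#include <cassert>

enum FmaKind { FMADD, FMSUB, FNMADD, FNMSUB };

// Pick the FMA variant that absorbs the operand negations.
static FmaKind selectFma(bool NegA, bool NegB, bool NegC) {
  bool NegMul = (NegA != NegB); // product flips iff exactly one factor flips
  if (!NegMul)
    return NegC ? FMSUB : FMADD;
  return NegC ? FNMSUB : FNMADD;
}

int main() {
  assert(selectFma(false, false, false) == FMADD);  //   a*b  + c
  assert(selectFma(false, false, true)  == FMSUB);  //   a*b  - c
  assert(selectFma(true,  false, false) == FNMADD); // -(a*b) + c
  assert(selectFma(true,  true,  true)  == FMSUB);  // (-a)*(-b) - c == a*b - c
  return 0;
}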
Index: llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll
+++ llvm/trunk/test/CodeGen/X86/fma-fneg-combine.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s --check-prefix=CHECK --check-prefix=SKX
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f -mattr=+fma | FileCheck %s --check-prefix=CHECK --check-prefix=KNL
 
 ; This test checks combinations of FNEG and FMA intrinsics on AVX-512 target
 ; PR28892
@@ -88,11 +89,18 @@
 }
 
 define <8 x float> @test8(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
-; CHECK-LABEL: test8:
-; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
-; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
+; SKX-LABEL: test8:
+; SKX:       # BB#0: # %entry
+; SKX-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
+; SKX-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; SKX-NEXT:    retq
+;
+; KNL-LABEL: test8:
+; KNL:       # BB#0: # %entry
+; KNL-NEXT:    vbroadcastss {{.*}}(%rip), %ymm3
+; KNL-NEXT:    vxorps %ymm3, %ymm2, %ymm2
+; KNL-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; KNL-NEXT:    retq
 entry:
   %sub.c = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
   %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.c) #2
@@ -115,22 +123,9 @@
 declare <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8, i32)
 
-define <4 x double> @test10(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+define <2 x double> @test10(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
 ; CHECK-LABEL: test10:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213pd %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
-entry:
-  %0 = tail call <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8 -1) #2
-  %sub.i = fsub <4 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
-  ret <4 x double> %sub.i
-}
-
-declare <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8)
-
-define <2 x double> @test11(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
-; CHECK-LABEL: test11:
-; CHECK:       # BB#0: # %entry
 ; CHECK-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
 ; CHECK-NEXT:    vmovaps %xmm1, %xmm0
 ; CHECK-NEXT:    retq
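The new KNL run line covers the configuration this patch targets: without AVX512VL the sign-bit constant cannot be folded as an embedded broadcast, so test8 materializes it with vbroadcastss and negates with a plain vxorps; the fneg therefore reaches the combiner as a bitcasted integer XOR that isFNEG() has to peek through. A standalone C++ sketch of that bit-level equivalence (fnegViaXor is an illustrative helper, not patch or test code):

#include <cassert>
#include <cstdint>
#include <cstring>

// fneg(x) implemented the way the backend lowers it without FXOR:
// bitcast to integer, XOR the sign bit, bitcast back.
static float fnegViaXor(float X) {
  uint32_t Bits;
  std::memcpy(&Bits, &X, sizeof(Bits)); // bitcast float -> i32
  Bits ^= UINT32_C(0x80000000);         // flip only the sign bit
  float R;
  std::memcpy(&R, &Bits, sizeof(R));    // bitcast i32 -> float
  return R;
}

int main() {
  assert(fnegViaXor(1.0f) == -1.0f);
  assert(fnegViaXor(-2.5f) == 2.5f);
  return 0;
}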