Index: ../lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- ../lib/Target/X86/X86ISelLowering.cpp
+++ ../lib/Target/X86/X86ISelLowering.cpp
@@ -29190,28 +29190,6 @@
   return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
 }
 
-static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
-                          TargetLowering::DAGCombinerInfo &DCI,
-                          const X86Subtarget &Subtarget) {
-  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
-    return Cmp;
-
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
-  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
-    return RV;
-
-  if (Subtarget.hasCMov())
-    if (SDValue RV = combineIntegerAbs(N, DAG))
-      return RV;
-
-  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
-    return FPLogic;
-
-  return SDValue();
-}
-
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
 /// X86ISD::AVG instruction.
@@ -30320,12 +30298,70 @@
   return combineVectorTruncation(N, DAG, Subtarget);
 }
 
+/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
+/// \p Op0 - a pointer to the negated operand.
+///
+/// AVX512F does not have FXOR, so FNEG is lowered as
+/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
+/// In this case we go through all bitcasts.
+bool isFNEG(SDNode *N, SDValue *Op0 = nullptr) {
+  if (N->getOpcode() == ISD::FNEG) {
+    if (Op0)
+      *Op0 = N->getOperand(0);
+    return true;
+  }
+
+  SDValue Op = peekThroughBitcasts(SDValue(N, 0));
+  if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
+    return false;
+
+  SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
+  if (!Op1.getValueType().isFloatingPoint())
+    return false;
+
+  if (Op0)
+    *Op0 = peekThroughBitcasts(Op.getOperand(0));
+
+  unsigned EltBits = Op1.getValueType().getScalarSizeInBits();
+  auto isSignBitValue = [&](const ConstantFP *C) {
+    return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
+  };
+
+  // There is more than one way to represent the same constant on
+  // the different X86 targets. The type of the node may also depend on size.
+  //  - load scalar value and broadcast
+  //  - BUILD_VECTOR node
+  //  - load from a constant pool.
+  // We check all variants here.
+  if (Op1.getOpcode() == X86ISD::VBROADCAST) {
+    if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
+      return isSignBitValue(cast<ConstantFP>(C));
+
+  } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
+    if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+      return isSignBitValue(CN->getConstantFPValue());
+
+  } else if (auto *C = getTargetConstantFromNode(Op1)) {
+    if (C->getType()->isVectorTy()) {
+      if (auto *SplatV = C->getSplatValue())
+        return isSignBitValue(cast<ConstantFP>(SplatV));
+    } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+      return isSignBitValue(FPConst);
+  }
+  return false;
+}
+
 /// Do target-specific dag combines on floating point negations.
 static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
+  EVT OrigVT = N->getValueType(0);
+  SDValue Arg;
+  bool IsFNEGNode = isFNEG(N, &Arg);
+  if (!IsFNEGNode)
+    llvm_unreachable("Expected FNEG node");
+
+  EVT VT = Arg.getValueType();
   EVT SVT = VT.getScalarType();
-  SDValue Arg = N->getOperand(0);
   SDLoc DL(N);
 
   // Let legalize expand this if it isn't a legal type yet.
@@ -30338,8 +30374,9 @@
   if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
       Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
     SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-    return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                       Arg.getOperand(1), Zero);
+    SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
+                                  Arg.getOperand(1), Zero);
+    return DAG.getBitcast(OrigVT, NewNode);
   }
 
   // If we're negating a FMA node, then we can adjust the
@@ -30347,29 +30384,41 @@
   if (Arg.hasOneUse()) {
     switch (Arg.getOpcode()) {
     case X86ISD::FMADD:
-      return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
+                      Arg.getOperand(1), Arg.getOperand(2)));
     case X86ISD::FMSUB:
-      return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
+                      Arg.getOperand(1), Arg.getOperand(2)));
     case X86ISD::FNMADD:
-      return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
+                      Arg.getOperand(1), Arg.getOperand(2)));
     case X86ISD::FNMSUB:
-      return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
+                      Arg.getOperand(1), Arg.getOperand(2)));
     case X86ISD::FMADD_RND:
-      return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FNMSUB_RND, DL, VT,
+                      Arg.getOperand(0), Arg.getOperand(1),
+                      Arg.getOperand(2), Arg.getOperand(3)));
     case X86ISD::FMSUB_RND:
-      return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FNMADD_RND, DL, VT,
+                      Arg.getOperand(0), Arg.getOperand(1),
+                      Arg.getOperand(2), Arg.getOperand(3)));
     case X86ISD::FNMADD_RND:
-      return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FMSUB_RND, DL, VT,
+                      Arg.getOperand(0), Arg.getOperand(1),
+                      Arg.getOperand(2), Arg.getOperand(3)));
     case X86ISD::FNMSUB_RND:
-      return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
-                         Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+      return DAG.getBitcast(OrigVT,
+          DAG.getNode(X86ISD::FMADD_RND, DL, VT,
+                      Arg.getOperand(0), Arg.getOperand(1),
+                      Arg.getOperand(2), Arg.getOperand(3)));
     }
   }
   return SDValue();
@@ -30399,42 +30448,28 @@
   return SDValue();
 }
 
-/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
-bool isFNEG(const SDNode *N) {
-  if (N->getOpcode() == ISD::FNEG)
-    return true;
+static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
+                          TargetLowering::DAGCombinerInfo &DCI,
+                          const X86Subtarget &Subtarget) {
+  if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
+    return Cmp;
 
-  if (N->getOpcode() == X86ISD::FXOR) {
-    unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
-    SDValue Op1 = N->getOperand(1);
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
 
-    auto isSignBitValue = [&](const ConstantFP *C) {
-      return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
-    };
+  if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
+    return RV;
 
-    // There is more than one way to represent the same constant on
-    // the different X86 targets. The type of the node may also depend on size.
-    //  - load scalar value and broadcast
-    //  - BUILD_VECTOR node
-    //  - load from a constant pool.
-    // We check all variants here.
-    if (Op1.getOpcode() == X86ISD::VBROADCAST) {
-      if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
-        return isSignBitValue(cast<ConstantFP>(C));
-
-    } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
-      if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
-        return isSignBitValue(CN->getConstantFPValue());
-
-    } else if (auto *C = getTargetConstantFromNode(Op1)) {
-      if (C->getType()->isVectorTy()) {
-        if (auto *SplatV = C->getSplatValue())
-          return isSignBitValue(cast<ConstantFP>(SplatV));
-      } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
-        return isSignBitValue(FPConst);
-    }
-  }
-  return false;
+  if (Subtarget.hasCMov())
+    if (SDValue RV = combineIntegerAbs(N, DAG))
+      return RV;
+
+  if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
+    return FPLogic;
+
+  if (isFNEG(N))
+    return combineFneg(N, DAG, Subtarget);
+  return SDValue();
 }
 
 /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
@@ -30864,18 +30899,12 @@
   SDValue B = N->getOperand(1);
   SDValue C = N->getOperand(2);
 
-  bool NegA = isFNEG(A.getNode());
-  bool NegB = isFNEG(B.getNode());
-  bool NegC = isFNEG(C.getNode());
+  bool NegA = isFNEG(A.getNode(), &A);
+  bool NegB = isFNEG(B.getNode(), &B);
+  bool NegC = isFNEG(C.getNode(), &C);
 
   // Negative multiplication when NegA xor NegB
   bool NegMul = (NegA != NegB);
-  if (NegA)
-    A = A.getOperand(0);
-  if (NegB)
-    B = B.getOperand(0);
-  if (NegC)
-    C = C.getOperand(0);
 
   unsigned NewOpcode;
   if (!NegMul)
Index: ../test/CodeGen/X86/fma-fneg-combine.ll
===================================================================
--- ../test/CodeGen/X86/fma-fneg-combine.ll
+++ ../test/CodeGen/X86/fma-fneg-combine.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s --check-prefix=CHECK --check-prefix=SKX
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f -mattr=+fma | FileCheck %s --check-prefix=CHECK --check-prefix=KNL
 
 ; This test checks combinations of FNEG and FMA intrinsics on AVX-512 target
 ; PR28892
@@ -88,11 +89,18 @@
 }
 
 define <8 x float> @test8(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
-; CHECK-LABEL: test8:
-; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
-; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
+; SKX-LABEL: test8:
+; SKX:       # BB#0: # %entry
+; SKX-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
+; SKX-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; SKX-NEXT:    retq
+;
+; KNL-LABEL: test8:
+; KNL:       # BB#0: # %entry
+; KNL-NEXT:    vbroadcastss {{.*}}(%rip), %ymm3
+; KNL-NEXT:    vxorps %ymm3, %ymm2, %ymm2
+; KNL-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; KNL-NEXT:    retq
 entry:
   %sub.c = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
   %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.c) #2
@@ -115,25 +123,18 @@
 
 declare <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8, i32)
 
-define <4 x double> @test10(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
-; CHECK-LABEL: test10:
-; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213pd %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    retq
-entry:
-  %0 = tail call <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8 -1) #2
-  %sub.i = fsub <4 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
-  ret <4 x double> %sub.i
-}
-
-declare <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8)
-
-define <2 x double> @test11(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
-; CHECK-LABEL: test11:
-; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
-; CHECK-NEXT:    vmovaps %xmm1, %xmm0
-; CHECK-NEXT:    retq
+define <2 x double> @test10(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; SKX-LABEL: test10:
+; SKX:       # BB#0: # %entry
+; SKX-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
+; SKX-NEXT:    vmovaps %xmm1, %xmm0
+; SKX-NEXT:    retq
+;
+; KNL-LABEL: test10:
+; KNL:       # BB#0: # %entry
+; KNL-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
+; KNL-NEXT:    vmovaps %zmm1, %zmm0
+; KNL-NEXT:    retq
 entry:
   %0 = tail call <2 x double> @llvm.x86.avx512.mask.vfmadd.sd(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 -1, i32 4) #2
   %sub.i = fsub <2 x double> <double -0.000000e+00, double -0.000000e+00>, %0
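
Reviewer note, not part of the patch: the case this change unblocks is a wide FNEG that legalization rewrites as an integer XOR between bitcasts (e.g. 512-bit vectors on AVX512F, which has no FXOR), a form the old isFNEG() could not see through. The following hand-reduced IR sketch illustrates the pattern; the function name and the llc invocation are hypothetical and are not taken from the patch's tests. With the updated isFNEG() peeking through bitcasts, this should select a single vfnmsub-style instruction instead of a vfmadd followed by a sign-bit vpxord.

; Assumed invocation: llc -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f < %s
declare <16 x float> @llvm.fma.v16f32(<16 x float>, <16 x float>, <16 x float>)

define <16 x float> @fneg_of_fma(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
entry:
  ; fma(a, b, c)
  %fma = call <16 x float> @llvm.fma.v16f32(<16 x float> %a, <16 x float> %b, <16 x float> %c)
  ; fsub -0.0, x is the canonical fneg; for v16f32 on AVX512F it is lowered via
  ; (bitcast (xor (bitcast x), SignMask)), the form isFNEG() now recognizes,
  ; and -(a*b+c) folds to fnmsub(a, b, c)
  %neg = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %fma
  ret <16 x float> %neg
}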