Index: ../lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- ../lib/Target/X86/X86ISelLowering.cpp
+++ ../lib/Target/X86/X86ISelLowering.cpp
@@ -4866,23 +4866,23 @@
   return true;
 }
 
-static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) {
-  MaskNode = peekThroughBitcasts(MaskNode);
+static const Constant *getTargetConstantFromNode(SDValue Op) {
+  Op = peekThroughBitcasts(Op);
 
-  auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode);
-  if (!MaskLoad)
+  auto *Load = dyn_cast<LoadSDNode>(Op);
+  if (!Load)
     return nullptr;
 
-  SDValue Ptr = MaskLoad->getBasePtr();
+  SDValue Ptr = Load->getBasePtr();
   if (Ptr->getOpcode() == X86ISD::Wrapper ||
       Ptr->getOpcode() == X86ISD::WrapperRIP)
     Ptr = Ptr->getOperand(0);
 
-  auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr);
-  if (!MaskCP || MaskCP->isMachineConstantPoolEntry())
+  auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr);
+  if (!CNode || CNode->isMachineConstantPoolEntry())
     return nullptr;
 
-  return dyn_cast<Constant>(MaskCP->getConstVal());
+  return dyn_cast<Constant>(CNode->getConstVal());
 }
 
 /// Calculates the shuffle mask corresponding to the target-specific opcode.
@@ -4992,7 +4992,7 @@
       DecodeVPERMILPMask(VT, RawMask, Mask);
       break;
     }
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodeVPERMILPMask(C, MaskEltSize, Mask);
      break;
    }
@@ -5006,7 +5006,7 @@
      DecodePSHUFBMask(RawMask, Mask);
      break;
    }
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodePSHUFBMask(C, Mask);
      break;
    }
@@ -5055,7 +5055,7 @@
      DecodeVPERMIL2PMask(VT, CtrlImm, RawMask, Mask);
      break;
    }
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, Mask);
      break;
    }
@@ -5070,7 +5070,7 @@
      DecodeVPPERMMask(RawMask, Mask);
      break;
    }
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodeVPPERMMask(C, Mask);
      break;
    }
@@ -5087,7 +5087,7 @@
      DecodeVPERMVMask(RawMask, Mask);
      break;
    }
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodeVPERMVMask(C, VT, Mask);
      break;
    }
@@ -5099,7 +5099,7 @@
    Ops.push_back(N->getOperand(0));
    Ops.push_back(N->getOperand(2));
    SDValue MaskNode = N->getOperand(1);
-    if (auto *C = getTargetShuffleMaskConstant(MaskNode)) {
+    if (auto *C = getTargetConstantFromNode(MaskNode)) {
      DecodeVPERMV3Mask(C, VT, Mask);
      break;
    }
@@ -30358,6 +30358,18 @@
   case X86ISD::FNMSUB:
     return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
                        Arg.getOperand(1), Arg.getOperand(2));
+  case X86ISD::FMADD_RND:
+    return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
+                       Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+  case X86ISD::FMSUB_RND:
+    return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
+                       Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+  case X86ISD::FNMADD_RND:
+    return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
+                       Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
+  case X86ISD::FNMSUB_RND:
+    return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
+                       Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
   }
   return SDValue();
 }
@@ -30386,6 +30398,45 @@
   }
   return SDValue();
 }
+
+/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
+bool isFNEG(const SDNode *N) {
+  if (N->getOpcode() == ISD::FNEG)
+    return true;
+
+  if (N->getOpcode() == X86ISD::FXOR) {
+    unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
+    SDValue Op1 = N->getOperand(1);
+
+    auto isSignBitValue = [&](const ConstantFP *C) {
+      return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
+    };
+
+    // There is more than one way to represent the same constant on
+    // the different X86 targets. The type of the node may also depend on size.
+    //  - load scalar value and broadcast
+    //  - BUILD_VECTOR node
+    //  - load from a constant pool.
+    // We check all variants here.
+    if (Op1.getOpcode() == X86ISD::VBROADCAST) {
+      if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
+        return isSignBitValue(cast<ConstantFP>(C));
+
+    } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
+      if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
+        return isSignBitValue(CN->getConstantFPValue());
+
+    } else if (auto *C = getTargetConstantFromNode(Op1)) {
+      if (C->getType()->isVectorTy()) {
+        if (auto *SplatV = C->getSplatValue())
+          return isSignBitValue(cast<ConstantFP>(SplatV));
+      } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
+        return isSignBitValue(FPConst);
+    }
+  }
+  return false;
+}
+
 /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
 static SDValue combineFOr(SDNode *N, SelectionDAG &DAG,
                           const X86Subtarget &Subtarget) {
@@ -30401,6 +30452,9 @@
   if (C->getValueAPF().isPosZero())
     return N->getOperand(0);
 
+  if (isFNEG(N))
+    if (SDValue NewVal = combineFneg(N, DAG, Subtarget))
+      return NewVal;
   return lowerX86FPLogicOp(N, DAG, Subtarget);
 }
 
@@ -30810,9 +30864,9 @@
   SDValue B = N->getOperand(1);
   SDValue C = N->getOperand(2);
 
-  bool NegA = (A.getOpcode() == ISD::FNEG);
-  bool NegB = (B.getOpcode() == ISD::FNEG);
-  bool NegC = (C.getOpcode() == ISD::FNEG);
+  bool NegA = isFNEG(A.getNode());
+  bool NegB = isFNEG(B.getNode());
+  bool NegC = isFNEG(C.getNode());
 
   // Negative multiplication when NegA xor NegB
   bool NegMul = (NegA != NegB);
@@ -30823,13 +30877,22 @@
   if (NegC)
     C = C.getOperand(0);
 
-  unsigned Opcode;
+  unsigned NewOpcode;
   if (!NegMul)
-    Opcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB;
+    NewOpcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB;
   else
-    Opcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB;
+    NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB;
 
-  return DAG.getNode(Opcode, dl, VT, A, B, C);
+  if (N->getOpcode() == X86ISD::FMADD_RND) {
+    switch (NewOpcode) {
+    case X86ISD::FMADD:  NewOpcode = X86ISD::FMADD_RND; break;
+    case X86ISD::FMSUB:  NewOpcode = X86ISD::FMSUB_RND; break;
+    case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break;
+    case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break;
+    }
+    return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3));
+  }
+  return DAG.getNode(NewOpcode, dl, VT, A, B, C);
 }
 
 static SDValue combineZext(SDNode *N, SelectionDAG &DAG,
@@ -31559,6 +31622,8 @@
   case X86ISD::VPERM2X128:
   case X86ISD::VZEXT_MOVL:
   case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget);
+  case X86ISD::FMADD:
+  case X86ISD::FMADD_RND:
  case ISD::FMA:            return combineFMA(N, DAG, Subtarget);
  case ISD::MGATHER:
  case ISD::MSCATTER:       return combineGatherScatter(N, DAG);
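Note (commentary, not part of the patch): the FXOR form matched by isFNEG above is negation through the IEEE-754 sign bit; XOR-ing a value with a splat of 0x80000000 (per f32 lane) flips only the sign. A minimal standalone C++ sketch of that identity, using no LLVM APIs and with all names illustrative:

  // sign_bit_xor.cpp: check that XOR of the sign bit equals FNEG.
  #include <cassert>
  #include <cstdint>
  #include <cstring>

  static float xorSignBit(float X) {
    uint32_t Bits;
    std::memcpy(&Bits, &X, sizeof(Bits)); // bitcast float -> i32
    Bits ^= UINT32_C(0x80000000);         // flip only the sign bit
    std::memcpy(&X, &Bits, sizeof(X));    // bitcast i32 -> float
    return X;
  }

  int main() {
    for (float V : {1.5f, -2.25f, 0.0f, 3.0e-38f})
      assert(xorSignBit(V) == -V); // holds for every finite value
    return 0;
  }

That single constant can reach the DAG in several shapes, which is why isFNEG looks through a VBROADCAST of a scalar load, a BUILD_VECTOR splat, and a plain constant-pool load.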
Index: ../test/CodeGen/X86/avx2-fma-fneg-combine.ll
===================================================================
--- ../test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ ../test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx2 -mattr=+fma | FileCheck %s
+
+; This test checks combinations of FNEG and FMA intrinsics
+
+define <8 x float> @test1(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
+; CHECK-LABEL: test1:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %sub.i = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
+  %0 = tail call <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.i) #2
+  ret <8 x float> %0
+}
+
+declare <8 x float> @llvm.x86.fma.vfmadd.ps.256(<8 x float>, <8 x float>, <8 x float>)
+
+define <4 x float> @test2(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: test2:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmsub213ps %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float> %a, <4 x float> %b, <4 x float> %c) #2
+  %sub.i = fsub <4 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %0
+  ret <4 x float> %sub.i
+}
+
+declare <4 x float> @llvm.x86.fma.vfmadd.ps(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+
+define <4 x float> @test3(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
+; CHECK-LABEL: test3:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmadd213ss %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    vbroadcastss {{.*}}(%rip), %xmm1
+; CHECK-NEXT:    vxorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <4 x float> @llvm.x86.fma.vfnmadd.ss(<4 x float> %a, <4 x float> %b, <4 x float> %c) #2
+  %sub.i = fsub <4 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %0
+  ret <4 x float> %sub.i
+}
+
+declare <4 x float> @llvm.x86.fma.vfnmadd.ss(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+
+define <8 x float> @test4(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
+; CHECK-LABEL: test4:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmadd213ps %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %c) #2
+  %sub.i = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %0
+  ret <8 x float> %sub.i
+}
+
+define <8 x float> @test5(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
+; CHECK-LABEL: test5:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vbroadcastss {{.*}}(%rip), %ymm3
+; CHECK-NEXT:    vxorps %ymm3, %ymm2, %ymm2
+; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %sub.c = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
+  %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.c) #2
+  ret <8 x float> %0
+}
+
+declare <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float>, <8 x float>, <8 x float>)
+
+
+define <2 x double> @test6(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: test6:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmsub213pd %xmm2, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %a, <2 x double> %b, <2 x double> %c) #2
+  %sub.i = fsub <2 x double> <double -0.000000e+00, double -0.000000e+00>, %0
+  ret <2 x double> %sub.i
+}
+
+declare <2 x double> @llvm.x86.fma.vfmadd.pd(<2 x double> %a, <2 x double> %b, <2 x double> %c)
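Note (commentary, not part of the patch): the instruction swaps checked above follow the standard FMA negation algebra, for example -(a*b + c) == -(a*b) - c, so FNEG(FMADD) becomes FNMSUB. A scalar sanity check of all four mappings, illustrative only, with std::fma from <cmath> standing in for the fused x86 operations:

  // fma_negation.cpp: verify the four FNEG(FMA) opcode mappings numerically.
  #include <cassert>
  #include <cmath>

  static double fmadd(double A, double B, double C)  { return std::fma(A, B, C);   } //  a*b + c
  static double fmsub(double A, double B, double C)  { return std::fma(A, B, -C);  } //  a*b - c
  static double fnmadd(double A, double B, double C) { return std::fma(-A, B, C);  } // -(a*b) + c
  static double fnmsub(double A, double B, double C) { return std::fma(-A, B, -C); } // -(a*b) - c

  int main() {
    const double A = 1.25, B = -3.0, C = 0.5;
    assert(-fmadd(A, B, C)  == fnmsub(A, B, C)); // FNEG(FMADD)  -> FNMSUB
    assert(-fmsub(A, B, C)  == fnmadd(A, B, C)); // FNEG(FMSUB)  -> FNMADD
    assert(-fnmadd(A, B, C) == fmsub(A, B, C));  // FNEG(FNMADD) -> FMSUB
    assert(-fnmsub(A, B, C) == fmadd(A, B, C));  // FNEG(FNMSUB) -> FMADD
    return 0;
  }

The same table drives both directions in the patch: combineFneg rewrites a negated FMA result, and combineFMA strips FNEG from the operands and picks the matching opcode.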
Index: ../test/CodeGen/X86/fma-fneg-combine.ll
===================================================================
--- ../test/CodeGen/X86/fma-fneg-combine.ll
+++ ../test/CodeGen/X86/fma-fneg-combine.ll
@@ -7,8 +7,7 @@
 define <16 x float> @test1(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test1:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm2, %zmm2
-; CHECK-NEXT:    vfmadd213ps %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    vfmsub213ps %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %sub.i = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
@@ -24,8 +23,7 @@
 define <16 x float> @test2(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test2:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfmadd213ps %zmm2, %zmm1, %zmm0
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT:    vfnmsub213ps %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfmadd.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 -1, i32 4) #2
@@ -36,8 +34,7 @@
 define <16 x float> @test3(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test3:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmadd213ps %zmm2, %zmm1, %zmm0
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT:    vfmsub213ps %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfnmadd.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 -1, i32 4) #2
@@ -48,8 +45,7 @@
 define <16 x float> @test4(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test4:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213ps %zmm2, %zmm1, %zmm0
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT:    vfmadd213ps %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfnmsub.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 -1, i32 4) #2
@@ -60,8 +56,7 @@
 define <16 x float> @test5(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test5:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm2, %zmm2
-; CHECK-NEXT:    vfmadd213ps {ru-sae}, %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    vfmsub213ps {ru-sae}, %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %sub.i = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
@@ -72,8 +67,7 @@
 define <16 x float> @test6(<16 x float> %a, <16 x float> %b, <16 x float> %c) {
 ; CHECK-LABEL: test6:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfnmsub213ps {ru-sae}, %zmm2, %zmm1, %zmm0
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; CHECK-NEXT:    vfmadd213ps {ru-sae}, %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfnmsub.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 -1, i32 2) #2
@@ -85,8 +79,7 @@
 define <8 x float> @test7(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
 ; CHECK-LABEL: test7:
 ; CHECK:       # BB#0: # %entry
-; CHECK-NEXT:    vfmsub213ps %ymm2, %ymm1, %ymm0
-; CHECK-NEXT:    vxorps {{.*}}(%rip){1to8}, %ymm0, %ymm0
+; CHECK-NEXT:    vfnmadd213ps %ymm2, %ymm1, %ymm0
 ; CHECK-NEXT:    retq
 entry:
   %0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %c) #2
@@ -108,3 +101,44 @@
 
 declare <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float>, <8 x float>, <8 x float>)
+
+define <8 x double> @test9(<8 x double> %a, <8 x double> %b, <8 x double> %c) {
+; CHECK-LABEL: test9:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmsub213pd %zmm2, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8 -1, i32 4) #2
+  %sub.i = fsub <8 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
+  ret <8 x double> %sub.i
+}
+
+declare <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8, i32)
+
+define <4 x double> @test10(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: test10:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmsub213pd %ymm2, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8 -1) #2
+  %sub.i = fsub <4 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
+  ret <4 x double> %sub.i
+}
+
+declare <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8)
+
+define <2 x double> @test11(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: test11:
+; CHECK:       # BB#0: # %entry
+; CHECK-NEXT:    vfnmsub213sd %xmm2, %xmm0, %xmm1
+; CHECK-NEXT:    vmovaps %xmm1, %xmm0
+; CHECK-NEXT:    retq
+entry:
+  %0 = tail call <2 x double> @llvm.x86.avx512.mask.vfmadd.sd(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 -1, i32 4) #2
+  %sub.i = fsub <2 x double> <double -0.000000e+00, double -0.000000e+00>, %0
+  ret <2 x double> %sub.i
+}
+
+declare <2 x double> @llvm.x86.avx512.mask.vfmadd.sd(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8, i32)
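Note (commentary, not part of the patch): isFNEG deliberately accepts only the exact sign-bit constant rather than any fsub from zero, because subtracting from +0.0 is not a negation. For x == +0.0, 0.0 - x yields +0.0 while -x yields -0.0. A minimal check, illustrative only:

  // pos_zero_not_fneg.cpp: fsub from +0.0 differs from FNEG at x == +0.0.
  #include <cassert>
  #include <cmath>

  int main() {
    double X = 0.0;       // +0.0
    double Sub = 0.0 - X; // +0.0: subtraction from +0.0 keeps the sign
    double Neg = -X;      // -0.0: true negation flips the sign bit
    assert(!std::signbit(Sub) && std::signbit(Neg));
    return 0;
  }

This is also why the IR tests above spell the negation as fsub from a splat of -0.000000e+00: fsub from -0.0 is the IR's canonical negation, so the DAG combines can fold it.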