Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp =================================================================== --- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp +++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp @@ -5633,6 +5633,24 @@ } return CastBitData(UndefSrcElts, SrcEltBits); } + if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) { + unsigned SrcEltSizeInBits = VT.getScalarSizeInBits(); + unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits; + + APInt UndefSrcElts(NumSrcElts, 0); + SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0)); + for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { + const SDValue &Src = Op.getOperand(i); + if (Src.isUndef()) { + UndefSrcElts.setBit(i); + continue; + } + auto *Cst = cast<ConstantFPSDNode>(Src); + APInt RawBits = Cst->getValueAPF().bitcastToAPInt(); + SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits); + } + return CastBitData(UndefSrcElts, SrcEltBits); + } // Extract constant bits from constant pool vector. if (auto *Cst = getTargetConstantFromNode(Op)) { @@ -36971,29 +36989,72 @@ /// Returns the negated value if the node \p N flips sign of FP value. /// -/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000). +/// FP-negation node may have different forms: FNEG(x), FXOR (x, 0x80000000) +/// or FSUB(0, x) /// AVX512F does not have FXOR, so FNEG is lowered as /// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))). /// In this case we go though all bitcasts. -static SDValue isFNEG(SDNode *N) { +/// This also recognizes splat of a negated value and returns the splat of that +/// value. 
+static SDValue isFNEG(SelectionDAG &DAG, SDNode *N) { if (N->getOpcode() == ISD::FNEG) return N->getOperand(0); SDValue Op = peekThroughBitcasts(SDValue(N, 0)); - if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR) + auto VT = Op->getValueType(0); + if (auto SVOp = dyn_cast<ShuffleVectorSDNode>(Op.getNode())) { + // For a VECTOR_SHUFFLE(VEC1, VEC2), if the VEC2 is undef, then the negate + // of this is VECTOR_SHUFFLE(-VEC1, UNDEF). The mask can be anything here. + if (!SVOp->getOperand(1).isUndef()) + return SDValue(); + if (SDValue NegOp0 = isFNEG(DAG, SVOp->getOperand(0).getNode())) + return DAG.getVectorShuffle(VT, SDLoc(SVOp), NegOp0, DAG.getUNDEF(VT), + SVOp->getMask()); + return SDValue(); + } + unsigned Opc = Op.getOpcode(); + if (Opc == ISD::INSERT_VECTOR_ELT) { + // Negate of INSERT_VECTOR_ELT(UNDEF, V, INDEX) is INSERT_VECTOR_ELT(UNDEF, + // -V, INDEX). + SDValue InsVector = Op.getOperand(0); + SDValue InsVal = Op.getOperand(1); + if (!InsVector.isUndef()) + return SDValue(); + if (SDValue NegInsVal = isFNEG(DAG, InsVal.getNode())) + return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(Op), VT, InsVector, + NegInsVal, Op.getOperand(2)); + return SDValue(); + } + + if (Opc != X86ISD::FXOR && Opc != ISD::XOR && Opc != ISD::FSUB) return SDValue(); SDValue Op1 = peekThroughBitcasts(Op.getOperand(1)); if (!Op1.getValueType().isFloatingPoint()) return SDValue(); - // Extract constant bits and see if they are all sign bit masks. + SDValue Op0 = peekThroughBitcasts(Op.getOperand(0)); + + // For XOR and FXOR, we want to check if constant bits of Op1 are sign bit + // masks. For FSUB, we have to check if constant bits of Op0 are sign bit + // masks and hence we swap the operands. + if (Opc == ISD::FSUB) + std::swap(Op0, Op1); + APInt UndefElts; SmallVector<APInt, 16> EltBits; + // Extract constant bits and see if they are all sign bit masks. Ignore the + // undef elements. 
if (getTargetConstantBitsFromNode(Op1, Op1.getScalarValueSizeInBits(), - UndefElts, EltBits, false, false)) - if (llvm::all_of(EltBits, [](APInt &I) { return I.isSignMask(); })) - return peekThroughBitcasts(Op.getOperand(0)); + UndefElts, EltBits, + /* AllowWholeUndefs */ true, + /* AllowPartialUndefs */ false)) { + for (unsigned I = 0, E = EltBits.size(); I < E; I++) + if (!UndefElts[I] && !EltBits[I].isSignMask()) + return SDValue(); + + return peekThroughBitcasts(Op0); + } return SDValue(); } @@ -37002,8 +37063,9 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { EVT OrigVT = N->getValueType(0); - SDValue Arg = isFNEG(N); - assert(Arg.getNode() && "N is expected to be an FNEG node"); + SDValue Arg = isFNEG(DAG, N); + if (!Arg) + return SDValue(); EVT VT = Arg.getValueType(); EVT SVT = VT.getScalarType(); @@ -37118,9 +37180,7 @@ if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget)) return FPLogic; - if (isFNEG(N)) - return combineFneg(N, DAG, Subtarget); - return SDValue(); + return combineFneg(N, DAG, Subtarget); } static SDValue combineBEXTR(SDNode *N, SelectionDAG &DAG, @@ -37253,9 +37313,8 @@ if (isNullFPScalarOrVectorConst(N->getOperand(1))) return N->getOperand(0); - if (isFNEG(N)) - if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) - return NewVal; + if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) + return NewVal; return lowerX86FPLogicOp(N, DAG, Subtarget); } @@ -37940,7 +37999,7 @@ SDValue C = N->getOperand(2); auto invertIfNegative = [&DAG](SDValue &V) { - if (SDValue NegVal = isFNEG(V.getNode())) { + if (SDValue NegVal = isFNEG(DAG, V.getNode())) { V = DAG.getBitcast(V.getValueType(), NegVal); return true; } @@ -37948,7 +38007,7 @@ // new extract from the FNEG input. 
if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT && isNullConstant(V.getOperand(1))) { - if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) { + if (SDValue NegVal = isFNEG(DAG, V.getOperand(0).getNode())) { NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal); V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(), NegVal, V.getOperand(1)); @@ -37981,7 +38040,7 @@ SDLoc dl(N); EVT VT = N->getValueType(0); - SDValue NegVal = isFNEG(N->getOperand(2).getNode()); + SDValue NegVal = isFNEG(DAG, N->getOperand(2).getNode()); if (!NegVal) return SDValue(); Index: llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll =================================================================== --- llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll +++ llvm/trunk/test/CodeGen/X86/avx2-fma-fneg-combine.ll @@ -118,20 +118,14 @@ define <8 x float> @test7(float %a, <8 x float> %b, <8 x float> %c) { ; X32-LABEL: test7: ; X32: # %bb.0: # %entry -; X32-NEXT: vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero -; X32-NEXT: vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero -; X32-NEXT: vsubps %ymm2, %ymm3, %ymm2 -; X32-NEXT: vbroadcastss %xmm2, %ymm2 -; X32-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1 +; X32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %ymm2 +; X32-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1 ; X32-NEXT: retl ; ; X64-LABEL: test7: ; X64: # %bb.0: # %entry -; X64-NEXT: # kill: def $xmm0 killed $xmm0 def $ymm0 -; X64-NEXT: vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero -; X64-NEXT: vsubps %ymm0, %ymm3, %ymm0 ; X64-NEXT: vbroadcastss %xmm0, %ymm0 -; X64-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2 +; X64-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2 ; X64-NEXT: retq entry: %0 = insertelement <8 x float> undef, float %a, i32 0 @@ -145,19 +139,14 @@ define <8 x float> @test8(float %a, <8 x float> %b, <8 x float> %c) { ; X32-LABEL: test8: ; X32: # %bb.0: # %entry -; X32-NEXT: vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero -; X32-NEXT: 
vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0] -; X32-NEXT: vxorps %xmm3, %xmm2, %xmm2 -; X32-NEXT: vbroadcastss %xmm2, %ymm2 -; X32-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1 +; X32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %ymm2 +; X32-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm2 * ymm0) + ymm1 ; X32-NEXT: retl ; ; X64-LABEL: test8: ; X64: # %bb.0: # %entry -; X64-NEXT: vbroadcastss {{.*#+}} xmm3 = [-0,-0,-0,-0] -; X64-NEXT: vxorps %xmm3, %xmm0, %xmm0 ; X64-NEXT: vbroadcastss %xmm0, %ymm0 -; X64-NEXT: vfmadd213ps {{.*#+}} ymm0 = (ymm1 * ymm0) + ymm2 +; X64-NEXT: vfnmadd213ps {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm2 ; X64-NEXT: retq entry: %0 = fsub float -0.0, %a