diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3543,9 +3543,39 @@ /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated /// expression. - virtual SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, - bool LegalOperations, bool ForCodeSize, - unsigned Depth = 0) const; + virtual SDValue negateExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, + bool OptForSize, unsigned Depth = 0) const; + + /// Return the newly negated expression if the cost is not expensive, or an + /// empty SDValue otherwise. \p Cost is set to the computed cost, i.e. + /// whether the negation is cheaper, neutral or expensive. + SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, + bool OptForSize, NegatibleCost &Cost, + unsigned Depth = 0) const { + Cost = getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth); + if (Cost != NegatibleCost::Expensive) + return negateExpression(Op, DAG, LegalOps, OptForSize, Depth); + return SDValue(); + } + + /// This is the helper function to return the newly negated expression only + /// when the cost is cheaper. + SDValue getCheaperNegatedExpression(SDValue Op, SelectionDAG &DAG, + bool LegalOps, bool OptForSize, + unsigned Depth = 0) const { + if (getNegatibleCost(Op, DAG, LegalOps, OptForSize, Depth) == + NegatibleCost::Cheaper) + return negateExpression(Op, DAG, LegalOps, OptForSize, Depth); + return SDValue(); + } + + /// This is the helper function to return the newly negated expression if + /// the cost is not expensive. 
+ SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, + bool OptForSize, unsigned Depth = 0) const { + NegatibleCost Cost = NegatibleCost::Expensive; + return getNegatedExpression(Op, DAG, LegalOps, OptForSize, Cost, Depth); + } //===--------------------------------------------------------------------===// // Lowering methods - These methods must be implemented by targets so that diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -527,7 +527,6 @@ bool isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS, SDValue &CC, bool MatchStrict = false) const; bool isOneUseSetCC(SDValue N) const; - bool isCheaperToUseNegatedFPOps(SDValue X, SDValue Y); SDValue SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp, unsigned HiOp); @@ -12392,20 +12391,16 @@ return NewSel; // fold (fadd A, (fneg B)) -> (fsub A, B) - if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && - TLI.getNegatibleCost(N1, DAG, LegalOperations, ForCodeSize) == - TargetLowering::NegatibleCost::Cheaper) - return DAG.getNode( - ISD::FSUB, DL, VT, N0, - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags); + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) + if (SDValue NegN1 = TLI.getCheaperNegatedExpression( + N1, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FSUB, DL, VT, N0, NegN1, Flags); // fold (fadd (fneg A), B) -> (fsub B, A) - if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && - TLI.getNegatibleCost(N0, DAG, LegalOperations, ForCodeSize) == - TargetLowering::NegatibleCost::Cheaper) - return DAG.getNode( - ISD::FSUB, DL, VT, N1, - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize), Flags); + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) + if (SDValue NegN0 = 
TLI.getCheaperNegatedExpression( + N0, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FSUB, DL, VT, N1, NegN0, Flags); auto isFMulNegTwo = [](SDValue FMul) { if (!FMul.hasOneUse() || FMul.getOpcode() != ISD::FMUL) @@ -12587,9 +12582,9 @@ if (N0CFP && N0CFP->isZero()) { if (N0CFP->isNegative() || (Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros())) { - if (TLI.getNegatibleCost(N1, DAG, LegalOperations, ForCodeSize) != - TargetLowering::NegatibleCost::Expensive) - return TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + if (SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize)) + return NegN1; if (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT)) return DAG.getNode(ISD::FNEG, DL, VT, N1, Flags); } @@ -12607,11 +12602,9 @@ } // fold (fsub A, (fneg B)) -> (fadd A, B) - if (TLI.getNegatibleCost(N1, DAG, LegalOperations, ForCodeSize) != - TargetLowering::NegatibleCost::Expensive) - return DAG.getNode( - ISD::FADD, DL, VT, N0, - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize), Flags); + if (SDValue NegN1 = + TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1, Flags); // FSUB -> FMA combines: if (SDValue Fused = visitFSUBForFMACombine(N)) { @@ -12622,25 +12615,6 @@ return SDValue(); } -/// Return true if both inputs are at least as cheap in negated form and at -/// least one input is strictly cheaper in negated form. -bool DAGCombiner::isCheaperToUseNegatedFPOps(SDValue X, SDValue Y) { - TargetLowering::NegatibleCost LHSNeg = - TLI.getNegatibleCost(X, DAG, LegalOperations, ForCodeSize); - if (TargetLowering::NegatibleCost::Expensive == LHSNeg) - return false; - - TargetLowering::NegatibleCost RHSNeg = - TLI.getNegatibleCost(Y, DAG, LegalOperations, ForCodeSize); - if (TargetLowering::NegatibleCost::Expensive == RHSNeg) - return false; - - // Both negated operands are at least as cheap as their counterparts. 
- // Check to see if at least one is cheaper negated. - return (TargetLowering::NegatibleCost::Cheaper == LHSNeg || - TargetLowering::NegatibleCost::Cheaper == RHSNeg); -} - SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12715,13 +12689,16 @@ return DAG.getNode(ISD::FNEG, DL, VT, N0); // -N0 * -N1 --> N0 * N1 - if (isCheaperToUseNegatedFPOps(N0, N1)) { - SDValue NegN0 = - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); - SDValue NegN1 = - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + SDValue NegN0 = + TLI.getCheaperNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + SDValue NegN1 = + TLI.getCheaperNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + if (NegN0 && !NegN1) + NegN1 = TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + else if (NegN1 && !NegN0) + NegN0 = TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + if (NegN0 && NegN1) return DAG.getNode(ISD::FMUL, DL, VT, NegN0, NegN1, Flags); - } // fold (fmul X, (select (fcmp X > 0.0), -1.0, 1.0)) -> (fneg (fabs X)) // fold (fmul X, (select (fcmp X > 0.0), 1.0, -1.0)) -> (fabs X) @@ -12800,13 +12777,16 @@ } // (-N0 * -N1) + N2 --> (N0 * N1) + N2 - if (isCheaperToUseNegatedFPOps(N0, N1)) { - SDValue NegN0 = - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); - SDValue NegN1 = - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + SDValue NegN0 = + TLI.getCheaperNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + SDValue NegN1 = + TLI.getCheaperNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + if (NegN0 && !NegN1) + NegN1 = TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + else if (NegN1 && !NegN0) + NegN0 = TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + if (NegN0 && NegN1) return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2, Flags); - } if (UnsafeFPMath) { if (N0CFP && 
N0CFP->isZero()) @@ -12892,13 +12872,10 @@ // fold ((fma (fneg X), Y, (fneg Z)) -> fneg (fma X, Y, Z)) // fold ((fma X, (fneg Y), (fneg Z)) -> fneg (fma X, Y, Z)) - if (!TLI.isFNegFree(VT) && - TLI.getNegatibleCost(SDValue(N, 0), DAG, LegalOperations, ForCodeSize) == - TargetLowering::NegatibleCost::Cheaper) - return DAG.getNode(ISD::FNEG, DL, VT, - TLI.getNegatedExpression(SDValue(N, 0), DAG, - LegalOperations, ForCodeSize), - Flags); + if (!TLI.isFNegFree(VT)) + if (SDValue Neg = TLI.getCheaperNegatedExpression( + SDValue(N, 0), DAG, LegalOperations, ForCodeSize)) + return DAG.getNode(ISD::FNEG, DL, VT, Neg, Flags); return SDValue(); } @@ -13074,13 +13051,16 @@ } // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y) - if (isCheaperToUseNegatedFPOps(N0, N1)) { - SDValue Neg0 = - TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); - SDValue Neg1 = - TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + SDValue Neg0 = + TLI.getCheaperNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + SDValue Neg1 = + TLI.getCheaperNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + if (Neg0 && !Neg1) + Neg1 = TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize); + else if (Neg1 && !Neg0) + Neg0 = TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + if (Neg0 && Neg1) return DAG.getNode(ISD::FDIV, SDLoc(N), VT, Neg0, Neg1, Flags); - } return SDValue(); } @@ -13626,9 +13606,9 @@ if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0); - if (TLI.getNegatibleCost(N0, DAG, LegalOperations, ForCodeSize) != - TargetLowering::NegatibleCost::Expensive) - return TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize); + if (SDValue NegN0 = + TLI.getNegatedExpression(N0, DAG, LegalOperations, ForCodeSize)) + return NegN0; // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0 // FIXME: This is duplicated in getNegatibleCost, but getNegatibleCost doesn't diff --git 
a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5571,8 +5571,7 @@ // to negate it even it has multiple uses. bool IsFreeConstant = Op.getOpcode() == ISD::ConstantFP && - !getNegatedExpression(Op, DAG, LegalOperations, ForCodeSize) - .use_empty(); + !negateExpression(Op, DAG, LegalOperations, ForCodeSize).use_empty(); if (!IsFreeExtend && !IsFreeConstant) return NegatibleCost::Expensive; @@ -5687,15 +5686,15 @@ return NegatibleCost::Expensive; } -SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, - bool LegalOps, bool OptForSize, - unsigned Depth) const { +SDValue TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG, + bool LegalOps, bool OptForSize, + unsigned Depth) const { // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return Op.getOperand(0); assert(Depth <= SelectionDAG::MaxRecursionDepth && - "getNegatedExpression doesn't match getNegatibleCost"); + "negateExpression doesn't match getNegatibleCost"); // Pre-increment recursion depth for use in recursive calls. 
++Depth; @@ -5732,14 +5731,14 @@ // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y) NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth); if (CostX != NegatibleCost::Expensive) - return DAG.getNode( - ISD::FSUB, DL, VT, - getNegatedExpression(X, DAG, LegalOps, OptForSize, Depth), Y, Flags); + return DAG.getNode(ISD::FSUB, DL, VT, + negateExpression(X, DAG, LegalOps, OptForSize, Depth), + Y, Flags); // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X) - return DAG.getNode( - ISD::FSUB, DL, VT, - getNegatedExpression(Y, DAG, LegalOps, OptForSize, Depth), X, Flags); + return DAG.getNode(ISD::FSUB, DL, VT, + negateExpression(Y, DAG, LegalOps, OptForSize, Depth), X, + Flags); } case ISD::FSUB: { SDValue X = Op.getOperand(0), Y = Op.getOperand(1); @@ -5757,14 +5756,14 @@ // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth); if (CostX != NegatibleCost::Expensive) - return DAG.getNode( - Opcode, DL, VT, - getNegatedExpression(X, DAG, LegalOps, OptForSize, Depth), Y, Flags); + return DAG.getNode(Opcode, DL, VT, + negateExpression(X, DAG, LegalOps, OptForSize, Depth), + Y, Flags); // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y)) - return DAG.getNode( - Opcode, DL, VT, X, - getNegatedExpression(Y, DAG, LegalOps, OptForSize, Depth), Flags); + return DAG.getNode(Opcode, DL, VT, X, + negateExpression(Y, DAG, LegalOps, OptForSize, Depth), + Flags); } case ISD::FMA: case ISD::FMAD: { @@ -5773,30 +5772,30 @@ "Expected NSZ fp-flag"); SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2); - SDValue NegZ = getNegatedExpression(Z, DAG, LegalOps, OptForSize, Depth); + SDValue NegZ = negateExpression(Z, DAG, LegalOps, OptForSize, Depth); NegatibleCost CostX = getNegatibleCost(X, DAG, LegalOps, OptForSize, Depth); NegatibleCost CostY = getNegatibleCost(Y, DAG, LegalOps, OptForSize, Depth); if (CostX <= CostY) { // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z)) - SDValue 
NegX = getNegatedExpression(X, DAG, LegalOps, OptForSize, Depth); + SDValue NegX = negateExpression(X, DAG, LegalOps, OptForSize, Depth); return DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags); } // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z)) - SDValue NegY = getNegatedExpression(Y, DAG, LegalOps, OptForSize, Depth); + SDValue NegY = negateExpression(Y, DAG, LegalOps, OptForSize, Depth); return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags); } case ISD::FP_EXTEND: case ISD::FSIN: - return DAG.getNode(Opcode, DL, VT, - getNegatedExpression(Op.getOperand(0), DAG, LegalOps, - OptForSize, Depth)); + return DAG.getNode( + Opcode, DL, VT, + negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth)); case ISD::FP_ROUND: - return DAG.getNode(ISD::FP_ROUND, DL, VT, - getNegatedExpression(Op.getOperand(0), DAG, LegalOps, - OptForSize, Depth), - Op.getOperand(1)); + return DAG.getNode( + ISD::FP_ROUND, DL, VT, + negateExpression(Op.getOperand(0), DAG, LegalOps, OptForSize, Depth), + Op.getOperand(1)); } llvm_unreachable("Unknown code"); diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -821,9 +821,9 @@ /// If getNegatibleCost returns Neutral/Cheaper, return the newly negated /// expression. 
- SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, - bool LegalOperations, bool ForCodeSize, - unsigned Depth) const override; + SDValue negateExpression(SDValue Op, SelectionDAG &DAG, + bool LegalOperations, bool ForCodeSize, + unsigned Depth) const override; MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -44027,10 +44027,9 @@ bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize(); bool LegalOperations = !DCI.isBeforeLegalizeOps(); - if (TLI.getNegatibleCost(Arg, DAG, LegalOperations, CodeSize) != - TargetLowering::NegatibleCost::Expensive) - return DAG.getBitcast( - OrigVT, TLI.getNegatedExpression(Arg, DAG, LegalOperations, CodeSize)); + if (SDValue NegArg = + TLI.getNegatedExpression(Arg, DAG, LegalOperations, CodeSize)) + return DAG.getBitcast(OrigVT, NegArg); return SDValue(); } @@ -44082,10 +44081,10 @@ Depth); } -SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, - bool LegalOperations, - bool ForCodeSize, - unsigned Depth) const { +SDValue X86TargetLowering::negateExpression(SDValue Op, SelectionDAG &DAG, + bool LegalOperations, + bool ForCodeSize, + unsigned Depth) const { // fneg patterns are removable even if they have multiple uses. if (SDValue Arg = isFNEG(DAG, Op.getNode(), Depth)) return DAG.getBitcast(Op.getValueType(), Arg); @@ -44110,13 +44109,9 @@ // This is always negatible for free but we might be able to remove some // extra operand negations as well. 
SmallVector NewOps(Op.getNumOperands(), SDValue()); - for (int i = 0; i != 3; ++i) { - NegatibleCost V = getNegatibleCost(Op.getOperand(i), DAG, LegalOperations, - ForCodeSize, Depth + 1); - if (V == NegatibleCost::Cheaper) - NewOps[i] = getNegatedExpression(Op.getOperand(i), DAG, LegalOperations, - ForCodeSize, Depth + 1); - } + for (int i = 0; i != 3; ++i) + NewOps[i] = getCheaperNegatedExpression( + Op.getOperand(i), DAG, LegalOperations, ForCodeSize, Depth + 1); bool NegA = !!NewOps[0]; bool NegB = !!NewOps[1]; @@ -44131,13 +44126,13 @@ } case X86ISD::FRCP: return DAG.getNode(Opc, SDLoc(Op), VT, - getNegatedExpression(Op.getOperand(0), DAG, - LegalOperations, ForCodeSize, - Depth + 1)); + negateExpression(Op.getOperand(0), DAG, LegalOperations, + ForCodeSize, Depth + 1)); + break; } - return TargetLowering::getNegatedExpression(Op, DAG, LegalOperations, - ForCodeSize, Depth); + return TargetLowering::negateExpression(Op, DAG, LegalOperations, ForCodeSize, + Depth); } static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG, @@ -45030,9 +45025,9 @@ auto invertIfNegative = [&DAG, &TLI, &DCI](SDValue &V) { bool CodeSize = DAG.getMachineFunction().getFunction().hasOptSize(); bool LegalOperations = !DCI.isBeforeLegalizeOps(); - if (TLI.getNegatibleCost(V, DAG, LegalOperations, CodeSize) == - TargetLowering::NegatibleCost::Cheaper) { - V = TLI.getNegatedExpression(V, DAG, LegalOperations, CodeSize); + if (SDValue NegV = TLI.getCheaperNegatedExpression(V, DAG, LegalOperations, + CodeSize)) { + V = NegV; return true; } // Look through extract_vector_elts. 
If it comes from an FNEG, create a @@ -45040,12 +45035,10 @@ if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT && isNullConstant(V.getOperand(1))) { SDValue Vec = V.getOperand(0); - if (TLI.getNegatibleCost(Vec, DAG, LegalOperations, CodeSize) == - TargetLowering::NegatibleCost::Cheaper) { - SDValue NegVal = - TLI.getNegatedExpression(Vec, DAG, LegalOperations, CodeSize); + if (SDValue NegV = TLI.getCheaperNegatedExpression( + Vec, DAG, LegalOperations, CodeSize)) { V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(), - NegVal, V.getOperand(1)); + NegV, V.getOperand(1)); return true; } } @@ -45087,11 +45080,11 @@ bool LegalOperations = !DCI.isBeforeLegalizeOps(); SDValue N2 = N->getOperand(2); - if (TLI.getNegatibleCost(N2, DAG, LegalOperations, CodeSize) != - TargetLowering::NegatibleCost::Cheaper) - return SDValue(); - SDValue NegN2 = TLI.getNegatedExpression(N2, DAG, LegalOperations, CodeSize); + SDValue NegN2 = + TLI.getCheaperNegatedExpression(N2, DAG, LegalOperations, CodeSize); + if (!NegN2) + return SDValue(); unsigned NewOpcode = negateFMAOpcode(N->getOpcode(), false, true, false); if (N->getNumOperands() == 4)