Index: llvm/include/llvm/CodeGen/TargetLowering.h =================================================================== --- llvm/include/llvm/CodeGen/TargetLowering.h +++ llvm/include/llvm/CodeGen/TargetLowering.h @@ -3440,8 +3440,16 @@ /// Return 1 if we can compute the negated form of the specified expression /// for the same cost as the expression itself, or 2 if we can compute the /// negated form more cheaply than the expression itself. Else return 0. + /// + /// EnableUseCheck specifies whether the number of uses of a value affects + /// whether negation is considered free. This is needed because the number of uses + /// of any value may change as we rewrite the expression. Therefore, when + /// called from getNegatedExpression(), we must explicitly set EnableUseCheck + /// to false to avoid getting a different answer than when called from other + /// contexts. virtual char isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations, bool ForCodeSize, + bool EnableUseCheck = true, unsigned Depth = 0) const; /// If isNegatibleForFree returns true, return the newly negated expression. Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5413,18 +5413,21 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations, bool ForCodeSize, + bool EnableUseCheck, unsigned Depth) const { // fneg is removable even if it has multiple uses. if (Op.getOpcode() == ISD::FNEG) return 2; - // Don't allow anything with multiple uses unless we know it is free. + // If the caller requires checking uses, don't allow anything with multiple + // uses unless we know it is free. 
EVT VT = Op.getValueType(); const SDNodeFlags Flags = Op->getFlags(); const TargetOptions &Options = DAG.getTarget().Options; - if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND && - isFPExtFree(VT, Op.getOperand(0).getValueType()))) - return 0; + if (EnableUseCheck) + if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND && + isFPExtFree(VT, Op.getOperand(0).getValueType()))) + return 0; // Don't recurse exponentially. if (Depth > SelectionDAG::MaxRecursionDepth) @@ -5468,11 +5471,11 @@ // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1)) + ForCodeSize, EnableUseCheck, Depth + 1)) return V; // fold (fneg (fadd A, B)) -> (fsub (fneg B), A) return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); case ISD::FSUB: // We can't turn -(A-B) into B-A when we honor signed zeros. if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros()) @@ -5485,7 +5488,7 @@ case ISD::FDIV: // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y)) if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1)) + ForCodeSize, EnableUseCheck, Depth + 1)) return V; // Ignore X * 2.0 because that is expected to be canonicalized to X + X. @@ -5494,7 +5497,7 @@ return 0; return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); case ISD::FMA: case ISD::FMAD: { @@ -5504,15 +5507,15 @@ // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z)) // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z)) char V2 = isNegatibleForFree(Op.getOperand(2), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); if (!V2) return 0; // One of Op0/Op1 must be cheaply negatible, then select the cheapest. 
char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); char V01 = std::max(V0, V1); return V01 ? std::max(V01, V2) : 0; } @@ -5521,7 +5524,7 @@ case ISD::FP_ROUND: case ISD::FSIN: return isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); } return 0; @@ -5565,7 +5568,7 @@ // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, - Depth + 1)) + false, Depth + 1)) return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), getNegatedExpression(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, @@ -5592,7 +5595,7 @@ case ISD::FDIV: // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, - Depth + 1)) + false, Depth + 1)) return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), getNegatedExpression(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, @@ -5616,9 +5619,9 @@ ForCodeSize, Depth + 1); char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, false, Depth + 1); char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, false, Depth + 1); if (V0 >= V1) { // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z)) SDValue Neg0 = getNegatedExpression( Index: llvm/lib/Target/X86/X86ISelLowering.h =================================================================== --- llvm/lib/Target/X86/X86ISelLowering.h +++ llvm/lib/Target/X86/X86ISelLowering.h @@ -806,7 +806,8 @@ /// for the same cost as the expression itself, or 2 if we can compute the /// negated form more cheaply than the expression itself. Else return 0. 
char isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations, - bool ForCodeSize, unsigned Depth) const override; + bool ForCodeSize, bool EnableUseCheck, + unsigned Depth) const override; /// If isNegatibleForFree returns true, return the newly negated expression. SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG, Index: llvm/lib/Target/X86/X86ISelLowering.cpp =================================================================== --- llvm/lib/Target/X86/X86ISelLowering.cpp +++ llvm/lib/Target/X86/X86ISelLowering.cpp @@ -41714,6 +41714,7 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations, bool ForCodeSize, + bool EnableUseCheck, unsigned Depth) const { // fneg patterns are removable even if they have multiple uses. if (isFNEG(DAG, Op.getNode(), Depth)) @@ -41742,7 +41743,7 @@ // extra operand negations as well. for (int i = 0; i != 3; ++i) { char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, EnableUseCheck, Depth + 1); if (V == 2) return V; } @@ -41751,7 +41752,8 @@ } return TargetLowering::isNegatibleForFree(Op, DAG, LegalOperations, - ForCodeSize, Depth); + ForCodeSize, EnableUseCheck, + Depth); } SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, @@ -41783,7 +41785,7 @@ SmallVector NewOps(Op.getNumOperands(), SDValue()); for (int i = 0; i != 3; ++i) { char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations, - ForCodeSize, Depth + 1); + ForCodeSize, false, Depth + 1); if (V == 2) NewOps[i] = getNegatedExpression(Op.getOperand(i), DAG, LegalOperations, ForCodeSize, Depth + 1); Index: llvm/test/CodeGen/AArch64/arm64-fmadd.ll =================================================================== --- llvm/test/CodeGen/AArch64/arm64-fmadd.ll +++ llvm/test/CodeGen/AArch64/arm64-fmadd.ll @@ -88,5 +88,23 @@ ret double %0 } +; This would crash while trying getNegatedExpression(). 
+ +define float @negated_constant(float %x) { +; CHECK-LABEL: negated_constant: +; CHECK: // %bb.0: +; CHECK-NEXT: mov w8, #-1037565952 +; CHECK-NEXT: mov w9, #1109917696 +; CHECK-NEXT: fmov s1, w8 +; CHECK-NEXT: fmul s1, s0, s1 +; CHECK-NEXT: fmov s2, w9 +; CHECK-NEXT: fmadd s0, s0, s2, s1 +; CHECK-NEXT: ret + %m = fmul float %x, 42.0 + %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m) + %nfma = fneg float %fma + ret float %nfma +} + declare float @llvm.fma.f32(float, float, float) nounwind readnone declare double @llvm.fma.f64(double, double, double) nounwind readnone Index: llvm/test/CodeGen/X86/fma-fneg-combine-2.ll =================================================================== --- llvm/test/CodeGen/X86/fma-fneg-combine-2.ll +++ llvm/test/CodeGen/X86/fma-fneg-combine-2.ll @@ -86,4 +86,24 @@ ret float %1 } +; This would crash while trying getNegatedExpression(). + +define float @negated_constant(float %x) { +; FMA3-LABEL: negated_constant: +; FMA3: # %bb.0: +; FMA3-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1 +; FMA3-NEXT: vfmadd132ss {{.*#+}} xmm0 = (xmm0 * mem) + xmm1 +; FMA3-NEXT: retq +; +; FMA4-LABEL: negated_constant: +; FMA4: # %bb.0: +; FMA4-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1 +; FMA4-NEXT: vfmaddss %xmm1, {{.*}}(%rip), %xmm0, %xmm0 +; FMA4-NEXT: retq + %m = fmul float %x, 42.0 + %fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m) + %nfma = fneg float %fma + ret float %nfma +} + declare float @llvm.fma.f32(float, float, float)