diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1689,6 +1689,11 @@
 /// Returns true if \p V is a constant min signed integer value.
 bool isMinSignedConstant(SDValue V);
 
+/// Returns true if \p V is a neutral (identity) element of the binary
+/// operation \p Opcode under the fast-math \p Flags, i.e. (Opcode x, V) may
+/// be folded to x. Only scalar constant nodes are recognized.
+bool isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V);
+
 /// Return the non-bitcasted source operand of \p V if it exists.
 /// If \p V is not a bitcasted value, it is returned as-is.
 SDValue peekThroughBitcasts(SDValue V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10735,6 +10735,58 @@
   return Const != nullptr && Const->isMinSignedValue();
 }
 
+/// Check whether \p V is the identity ("neutral") element of the binary
+/// operation \p Opcode under the fast-math \p Flags. Only scalar
+/// ConstantSDNode / ConstantFPSDNode operands are inspected; any other node
+/// (including vector splats) and any unlisted opcode yields false.
+bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V) {
+  if (auto *Const = dyn_cast<ConstantSDNode>(V)) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case ISD::OR:
+    case ISD::XOR:
+    case ISD::UMAX:
+      return Const->isZero();
+    case ISD::MUL:
+      return Const->isOne();
+    case ISD::AND:
+    case ISD::UMIN:
+      return Const->isAllOnes();
+    case ISD::SMAX:
+      return Const->isMinSignedValue();
+    case ISD::SMIN:
+      return Const->isMaxSignedValue();
+    }
+  } else if (auto *ConstFP = dyn_cast<ConstantFPSDNode>(V)) {
+    switch (Opcode) {
+    case ISD::FADD:
+      // x + -0.0 == x for every x, including +0.0. +0.0 is only an identity
+      // when the sign of a zero result is insignificant (nsz).
+      if (Flags.hasNoSignedZeros())
+        return ConstFP->isZero();
+      return ConstFP->isExactlyValue(-0.0);
+    case ISD::FMUL:
+      return ConstFP->isExactlyValue(1.0);
+    case ISD::FMINNUM:
+    case ISD::FMAXNUM: {
+      // The identity for fminnum is a quiet NaN by default, +Inf under nnan,
+      // and the largest finite value of V's type under nnan + ninf; fmaxnum
+      // uses the negated value.
+      EVT VT = V.getValueType();
+      const fltSemantics &Semantics = SelectionDAG::EVTToAPFloatSemantics(VT);
+      APFloat NeutralAF = !Flags.hasNoNaNs() ? APFloat::getQNaN(Semantics)
+                          : !Flags.hasNoInfs() ? APFloat::getInf(Semantics)
+                                               : APFloat::getLargest(Semantics);
+      if (Opcode == ISD::FMAXNUM)
+        NeutralAF.changeSign();
+
+      return ConstFP->isExactlyValue(NeutralAF);
+    }
+    }
+  }
+  return false;
+}
+
 SDValue llvm::peekThroughBitcasts(SDValue V) {
   while (V.getOpcode() == ISD::BITCAST)
     V = V.getOperand(0);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7968,17 +7968,9 @@
   if (!isOneConstant(ScalarV.getOperand(2)))
     return SDValue();
 
-  // TODO: Deal with value other than neutral element.
-  auto IsRVVNeutralElement = [Opc, &DAG](SDNode *N, SDValue V) {
-    if (Opc == ISD::FADD && N->getFlags().hasNoSignedZeros() &&
-        isNullFPConstant(V))
-      return true;
-    return DAG.getNeutralElement(Opc, SDLoc(V), V.getSimpleValueType(),
-                                 N->getFlags()) == V;
-  };
-
   // Check the scalar of ScalarV is neutral element
-  if (!IsRVVNeutralElement(N, ScalarV.getOperand(1)))
+  // TODO: Deal with value other than neutral element.
+  if (!isNeutralConstant(Opc, N->getFlags(), ScalarV.getOperand(1)))
     return SDValue();
 
   if (!ScalarV.hasOneUse())