diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1689,6 +1689,13 @@
 /// Returns true if \p V is a constant min signed integer value.
 bool isMinSignedConstant(SDValue V);
 
+/// Returns true if \p V is a constant neutral (identity) element of the
+/// binary opcode \p Opc under the node flags \p Flags.
+/// When \p OperandNo is 0, it checks that \p V is a left identity. Otherwise,
+/// it checks that \p V is a right identity.
+bool isNeutralConstant(unsigned Opc, SDNodeFlags Flags, SDValue V,
+                       unsigned OperandNo);
+
 /// Return the non-bitcasted source operand of \p V if it exists.
 /// If \p V is not a bitcasted value, it is returned as-is.
 SDValue peekThroughBitcasts(SDValue V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10735,6 +10735,78 @@
   return Const != nullptr && Const->isMinSignedValue();
 }
 
+/// Returns true if \p V is a constant that is a neutral (identity) element of
+/// the binary opcode \p Opcode, i.e. the operation yields its other operand
+/// unchanged. \p OperandNo is the operand position of \p V: 0 tests for a
+/// left identity, 1 for a right identity. \p Flags provides the fast-math
+/// flags consulted by the floating-point cases. Unlisted opcodes and
+/// non-constant values yield false.
+bool llvm::isNeutralConstant(unsigned Opcode, SDNodeFlags Flags, SDValue V,
+                             unsigned OperandNo) {
+  if (auto *Const = dyn_cast<ConstantSDNode>(V)) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case ISD::OR:
+    case ISD::XOR:
+    case ISD::UMAX:
+      return Const->isZero();
+    case ISD::MUL:
+      return Const->isOne();
+    case ISD::AND:
+    case ISD::UMIN:
+      return Const->isAllOnes();
+    case ISD::SMAX:
+      return Const->isMinSignedValue();
+    case ISD::SMIN:
+      return Const->isMaxSignedValue();
+    case ISD::SUB:
+    case ISD::SHL:
+    case ISD::SRA:
+    case ISD::SRL:
+      // Not commutative: zero is only accepted as the right operand.
+      return OperandNo == 1 && Const->isZero();
+    case ISD::UDIV:
+    case ISD::SDIV:
+      // Not commutative: one is only accepted as the right operand.
+      return OperandNo == 1 && Const->isOne();
+    }
+  } else if (auto *ConstFP = dyn_cast<ConstantFPSDNode>(V)) {
+    switch (Opcode) {
+    case ISD::FADD:
+      // x + -0.0 == x for every x; +0.0 only qualifies when the sign of zero
+      // may be ignored (nsz).
+      return ConstFP->isZero() &&
+             (Flags.hasNoSignedZeros() || ConstFP->isNegative());
+    case ISD::FSUB:
+      // x - +0.0 == x, but 0.0 - x == -x: zero is a right identity only.
+      // -0.0 additionally qualifies when the sign of zero may be ignored.
+      return OperandNo == 1 && ConstFP->isZero() &&
+             (Flags.hasNoSignedZeros() || !ConstFP->isNegative());
+    case ISD::FMUL:
+      return ConstFP->isExactlyValue(1.0);
+    case ISD::FDIV:
+      return OperandNo == 1 && ConstFP->isExactlyValue(1.0);
+    case ISD::FMINNUM:
+    case ISD::FMAXNUM: {
+      // Neutral element for fminnum is NaN, Inf or FLT_MAX, depending on FMF.
+      // For fmaxnum the same value with its sign flipped is used.
+      EVT VT = V.getValueType();
+      const fltSemantics &Semantics = SelectionDAG::EVTToAPFloatSemantics(VT);
+      APFloat NeutralAF = !Flags.hasNoNaNs()
+                              ? APFloat::getQNaN(Semantics)
+                              : !Flags.hasNoInfs()
+                                    ? APFloat::getInf(Semantics)
+                                    : APFloat::getLargest(Semantics);
+      if (Opcode == ISD::FMAXNUM)
+        NeutralAF.changeSign();
+
+      return ConstFP->isExactlyValue(NeutralAF);
+    }
+    }
+  }
+  return false;
+}
+
 SDValue llvm::peekThroughBitcasts(SDValue V) {
   while (V.getOpcode() == ISD::BITCAST)
     V = V.getOperand(0);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7629,17 +7629,10 @@
   if (!isOneConstant(ScalarV.getOperand(2)))
     return SDValue();
 
-  // TODO: Deal with value other than neutral element.
-  auto IsRVVNeutralElement = [Opc, &DAG](SDNode *N, SDValue V) {
-    if (Opc == ISD::FADD && N->getFlags().hasNoSignedZeros() &&
-        isNullFPConstant(V))
-      return true;
-    return DAG.getNeutralElement(Opc, SDLoc(V), V.getSimpleValueType(),
-                                 N->getFlags()) == V;
-  };
-
   // Check the scalar of ScalarV is neutral element
-  if (!IsRVVNeutralElement(N, ScalarV.getOperand(1)))
+  // TODO: Deal with value other than neutral element.
+  if (!isNeutralConstant(N->getOpcode(), N->getFlags(), ScalarV.getOperand(1),
+                         /*OperandNo=*/0))
     return SDValue();
 
   if (!ScalarV.hasOneUse())