Index: include/llvm/CodeGen/SelectionDAG.h =================================================================== --- include/llvm/CodeGen/SelectionDAG.h +++ include/llvm/CodeGen/SelectionDAG.h @@ -772,7 +772,8 @@ // Specialize based on number of operands. SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT); - SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N); + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N, + const SDNodeFlags *Flags = nullptr); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, SDValue N2, const SDNodeFlags *Flags = nullptr); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -1032,6 +1032,10 @@ if (N) N->addUse(*this); } +static bool isUnaryOpWithFlags(unsigned Opcode) { + return false; +} + /// Returns true if the opcode is a binary operation with flags. static bool isBinOpWithFlags(unsigned Opcode) { switch (Opcode) { @@ -1054,15 +1058,34 @@ } } -/// This class is an extension of BinarySDNode -/// used from those opcodes that have associated extra flags. 
-class BinaryWithFlagsSDNode : public SDNode { +class GenericFlagsSDNode : public SDNode { public: SDNodeFlags Flags; + GenericFlagsSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, + SDVTList VTs, const SDNodeFlags &NodeFlags) + : SDNode(Opc, Order, dl, VTs), Flags(NodeFlags) {} + static bool classof(const SDNode *N) { + return isUnaryOpWithFlags(N->getOpcode()) || + isBinOpWithFlags(N->getOpcode()); + } +}; + +class UnaryWithFlagsSDNode : public GenericFlagsSDNode { +public: + UnaryWithFlagsSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, + SDVTList VTs, const SDNodeFlags &NodeFlags) + : GenericFlagsSDNode(Opc, Order, dl, VTs, NodeFlags) {} + static bool classof(const SDNode *N) { + return isUnaryOpWithFlags(N->getOpcode()); + } +}; + +class BinaryWithFlagsSDNode : public GenericFlagsSDNode { +public: BinaryWithFlagsSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, const SDNodeFlags &NodeFlags) - : SDNode(Opc, Order, dl, VTs), Flags(NodeFlags) {} + : GenericFlagsSDNode(Opc, Order, dl, VTs, NodeFlags) {} static bool classof(const SDNode *N) { return isBinOpWithFlags(N->getOpcode()); Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3269,8 +3269,8 @@ if (getTarget().Options.NoNaNsFPMath) return true; - if (const BinaryWithFlagsSDNode *BF = dyn_cast<BinaryWithFlagsSDNode>(Op)) - return BF->Flags.hasNoNaNs(); + if (const GenericFlagsSDNode *N = dyn_cast<GenericFlagsSDNode>(Op)) + return N->Flags.hasNoNaNs(); // If the value is a constant, we can obviously see if it is a NaN or not. if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Op)) @@ -3387,7 +3387,7 @@ } SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, - SDValue Operand) { + SDValue Operand, const SDNodeFlags *Flags) { // Constant fold unary operations with an integer constant operand. 
Even // opaque constant will be folded, because the folding of unary operations // doesn't create new constants with different values. Nevertheless, the @@ -3731,10 +3731,18 @@ FoldingSetNodeID ID; AddNodeIDNode(ID, Opcode, VTs, Ops); void *IP = nullptr; - if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + if (Flags) + E->intersectFlagsWith(Flags); return SDValue(E, 0); + } + + if (Flags && isUnaryOpWithFlags(Opcode)) + N = newSDNode<UnaryWithFlagsSDNode>(Opcode, DL.getIROrder(), + DL.getDebugLoc(), VTs, *Flags); + else + N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs); - N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs); createOperands(N, Ops); CSEMap.InsertNode(N, IP); } else { @@ -6013,7 +6021,7 @@ unsigned NumOps = Ops.size(); switch (NumOps) { case 0: return getNode(Opcode, DL, VT); - case 1: return getNode(Opcode, DL, VT, Ops[0]); + case 1: return getNode(Opcode, DL, VT, Ops[0], Flags); case 2: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Flags); case 3: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Ops[2]); default: break; @@ -7423,13 +7431,13 @@ } const SDNodeFlags *SDNode::getFlags() const { - if (auto *FlagsNode = dyn_cast<BinaryWithFlagsSDNode>(this)) + if (auto *FlagsNode = dyn_cast<GenericFlagsSDNode>(this)) return &FlagsNode->Flags; return nullptr; } void SDNode::intersectFlagsWith(const SDNodeFlags *Flags) { - if (auto *FlagsNode = dyn_cast<BinaryWithFlagsSDNode>(this)) + if (auto *FlagsNode = dyn_cast<GenericFlagsSDNode>(this)) FlagsNode->Flags.intersectWith(Flags); }