Index: llvm/include/llvm/CodeGen/SelectionDAG.h =================================================================== --- llvm/include/llvm/CodeGen/SelectionDAG.h +++ llvm/include/llvm/CodeGen/SelectionDAG.h @@ -331,6 +331,28 @@ virtual void anchor(); }; + /// Helper to insert SDNodeFlags automatically during transformations. Uses + /// RAII to save and restore the flag inserter for the current scope. + class FlagInserter { + SelectionDAG &DAG; + SDNodeFlags Flags; + FlagInserter *LastInserter; + bool Disabled; + + public: + FlagInserter(SelectionDAG &SDAG, SDNode *N) + : DAG(SDAG), Flags(N->getFlags()), LastInserter(SDAG.getFlagInserter()), + Disabled(false) { + SDAG.setFlagInserter(this); + } + + FlagInserter(const FlagInserter &) = delete; + FlagInserter &operator=(const FlagInserter &) = delete; + ~FlagInserter() { DAG.setFlagInserter(LastInserter); } + + const SDNodeFlags getFlags() const { return Flags; } + }; + /// When true, additional steps are taken to /// ensure that getConstant() and similar functions return DAG nodes that /// have legal types. This is important after type legalization since @@ -433,6 +455,9 @@ ProfileSummaryInfo *getPSI() const { return PSI; } BlockFrequencyInfo *getBFI() const { return BFI; } + FlagInserter *getFlagInserter() { return Inserter; } + void setFlagInserter(FlagInserter *FI) { Inserter = FI; } + /// Just dump dot graph to a user-provided path and title. /// This doesn't open the dot viewer program and /// helps visualization when outside debugging session. 
@@ -945,21 +970,31 @@ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef Ops); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, - ArrayRef Ops, const SDNodeFlags Flags = SDNodeFlags()); + ArrayRef Ops, const SDNodeFlags Flags); SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef ResultTys, ArrayRef Ops); SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList, - ArrayRef Ops, const SDNodeFlags Flags = SDNodeFlags()); + ArrayRef Ops, const SDNodeFlags Flags); + + // Use flags from current flag inserter. + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + ArrayRef Ops); + SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList, + ArrayRef Ops); + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand); + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2); + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3); // Specialize based on number of operands. SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, - const SDNodeFlags Flags = SDNodeFlags()); + const SDNodeFlags Flags); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2, const SDNodeFlags Flags = SDNodeFlags()); + SDValue N2, const SDNodeFlags Flags); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2, SDValue N3, - const SDNodeFlags Flags = SDNodeFlags()); + SDValue N2, SDValue N3, const SDNodeFlags Flags); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, SDValue N2, SDValue N3, SDValue N4); SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, @@ -1468,8 +1503,10 @@ SDValue Operand, SDValue Subreg); /// Get the specified node if it's already available, or else return NULL. 
- SDNode *getNodeIfExists(unsigned Opcode, SDVTList VTList, ArrayRef Ops, - const SDNodeFlags Flags = SDNodeFlags()); + SDNode *getNodeIfExists(unsigned Opcode, SDVTList VTList, + ArrayRef Ops, const SDNodeFlags Flags); + SDNode *getNodeIfExists(unsigned Opcode, SDVTList VTList, + ArrayRef Ops); /// Creates a SDDbgValue node. SDDbgValue *getDbgValue(DIVariable *Var, DIExpression *Expr, SDNode *N, @@ -1998,6 +2035,8 @@ std::map, SDNode *> TargetExternalSymbols; DenseMap MCSymbols; + + FlagInserter *Inserter = nullptr; }; template <> struct GraphTraits : public GraphTraits { Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12724,6 +12724,7 @@ SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) return R; @@ -12931,6 +12932,7 @@ SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) return R; @@ -13011,6 +13013,7 @@ SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; const SDNodeFlags Flags = N->getFlags(); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) return R; @@ -13153,6 +13156,7 @@ EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + SelectionDAG::FlagInserter FlagsInserter(DAG, N); // FMA nodes have flags that propagate to the created nodes. 
const SDNodeFlags Flags = N->getFlags(); @@ -13356,6 +13360,7 @@ SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; SDNodeFlags Flags = N->getFlags(); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) return R; @@ -13497,6 +13502,7 @@ ConstantFPSDNode *N1CFP = dyn_cast(N1); EVT VT = N->getValueType(0); SDNodeFlags Flags = N->getFlags(); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags)) return R; @@ -13514,6 +13520,7 @@ SDValue DAGCombiner::visitFSQRT(SDNode *N) { SDNodeFlags Flags = N->getFlags(); const TargetOptions &Options = DAG.getTarget().Options; + SelectionDAG::FlagInserter FlagsInserter(DAG, N); // Require 'ninf' flag since sqrt(+Inf) = +Inf, but the estimation goes as: // sqrt(+Inf) == rsqrt(+Inf) * +Inf = 0 * +Inf = NaN @@ -13598,6 +13605,8 @@ } SDValue DAGCombiner::visitFPOW(SDNode *N) { + SelectionDAG::FlagInserter FlagsInserter(DAG, N); + ConstantFPSDNode *ExponentC = isConstOrConstSplatFP(N->getOperand(1)); if (!ExponentC) return SDValue(); @@ -14029,6 +14038,7 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); // Constant fold FNEG. 
if (isConstantFPBuildVectorOrConstantFP(N0)) @@ -14063,6 +14073,7 @@ EVT VT = N->getValueType(0); const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); @@ -14104,6 +14115,7 @@ SDValue DAGCombiner::visitFABS(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); + SelectionDAG::FlagInserter FlagsInserter(DAG, N); // fold (fabs c1) -> fabs(c1) if (isConstantFPBuildVectorOrConstantFP(N0)) Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -4394,6 +4394,14 @@ return V; } +SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue Operand) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNode(Opcode, DL, VT, Operand, Flags); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, const SDNodeFlags Flags) { // Constant fold unary operations with an integer constant operand. 
Even @@ -5202,6 +5210,14 @@ return V; } +SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue N1, SDValue N2) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNode(Opcode, DL, VT, N1, N2, Flags); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, SDValue N2, const SDNodeFlags Flags) { ConstantSDNode *N1C = dyn_cast(N1); @@ -5645,6 +5661,14 @@ return V; } +SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue N1, SDValue N2, SDValue N3) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNode(Opcode, DL, VT, N1, N2, N3, Flags); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, SDValue N2, SDValue N3, const SDNodeFlags Flags) { @@ -7467,6 +7491,14 @@ return getNode(Opcode, DL, VT, NewOps); } +SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + ArrayRef Ops) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNode(Opcode, DL, VT, Ops, Flags); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef Ops, const SDNodeFlags Flags) { unsigned NumOps = Ops.size(); @@ -7538,6 +7570,14 @@ return getNode(Opcode, DL, getVTList(ResultTys), Ops); } +SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList, + ArrayRef Ops) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNode(Opcode, DL, VTList, Ops, Flags); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList, ArrayRef Ops, const SDNodeFlags Flags) { if (VTList.NumVTs == 1) @@ -8234,6 +8274,14 @@ /// getNodeIfExists - Get the specified node if it's already available, or /// else return NULL. 
+SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList, + ArrayRef Ops) { + SDNodeFlags Flags; + if (Inserter) + Flags = Inserter->getFlags(); + return getNodeIfExists(Opcode, VTList, Ops, Flags); +} + SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList, ArrayRef Ops, const SDNodeFlags Flags) { Index: llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll =================================================================== --- llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll +++ llvm/test/CodeGen/X86/sqrt-fastmath-mir.ll @@ -22,7 +22,7 @@ ; CHECK: liveins: $xmm0 ; CHECK: [[COPY:%[0-9]+]]:fr32 = COPY $xmm0 ; CHECK: [[DEF:%[0-9]+]]:fr32 = IMPLICIT_DEF - ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = VRSQRTSSr killed [[DEF]], [[COPY]] + ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = ninf VRSQRTSSr killed [[DEF]], [[COPY]] ; CHECK: %3:fr32 = ninf nofpexcept VMULSSrr [[COPY]], [[VRSQRTSSr]], implicit $mxcsr ; CHECK: [[VMOVSSrm_alt:%[0-9]+]]:fr32 = VMOVSSrm_alt $rip, 1, $noreg, %const.0, $noreg :: (load 4 from constant-pool) ; CHECK: %5:fr32 = ninf nofpexcept VFMADD213SSr [[VRSQRTSSr]], killed %3, [[VMOVSSrm_alt]], implicit $mxcsr @@ -67,7 +67,7 @@ ; CHECK: liveins: $xmm0 ; CHECK: [[COPY:%[0-9]+]]:fr32 = COPY $xmm0 ; CHECK: [[DEF:%[0-9]+]]:fr32 = IMPLICIT_DEF - ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = VRSQRTSSr killed [[DEF]], [[COPY]] + ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = ninf VRSQRTSSr killed [[DEF]], [[COPY]] ; CHECK: %3:fr32 = ninf nofpexcept VMULSSrr [[COPY]], [[VRSQRTSSr]], implicit $mxcsr ; CHECK: [[VMOVSSrm_alt:%[0-9]+]]:fr32 = VMOVSSrm_alt $rip, 1, $noreg, %const.0, $noreg :: (load 4 from constant-pool) ; CHECK: %5:fr32 = ninf nofpexcept VFMADD213SSr [[VRSQRTSSr]], killed %3, [[VMOVSSrm_alt]], implicit $mxcsr @@ -96,7 +96,7 @@ ; CHECK: liveins: $xmm0 ; CHECK: [[COPY:%[0-9]+]]:fr32 = COPY $xmm0 ; CHECK: [[DEF:%[0-9]+]]:fr32 = IMPLICIT_DEF - ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = VRSQRTSSr killed [[DEF]], [[COPY]] + ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = nnan ninf nsz 
arcp contract afn reassoc VRSQRTSSr killed [[DEF]], [[COPY]] ; CHECK: %3:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr [[COPY]], [[VRSQRTSSr]], implicit $mxcsr ; CHECK: [[VMOVSSrm_alt:%[0-9]+]]:fr32 = VMOVSSrm_alt $rip, 1, $noreg, %const.0, $noreg :: (load 4 from constant-pool) ; CHECK: %5:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VFMADD213SSr [[VRSQRTSSr]], killed %3, [[VMOVSSrm_alt]], implicit $mxcsr @@ -120,7 +120,7 @@ ; CHECK: liveins: $xmm0 ; CHECK: [[COPY:%[0-9]+]]:fr32 = COPY $xmm0 ; CHECK: [[DEF:%[0-9]+]]:fr32 = IMPLICIT_DEF - ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = VRSQRTSSr killed [[DEF]], [[COPY]] + ; CHECK: [[VRSQRTSSr:%[0-9]+]]:fr32 = nnan ninf nsz arcp contract afn reassoc VRSQRTSSr killed [[DEF]], [[COPY]] ; CHECK: %3:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VMULSSrr [[COPY]], [[VRSQRTSSr]], implicit $mxcsr ; CHECK: [[VMOVSSrm_alt:%[0-9]+]]:fr32 = VMOVSSrm_alt $rip, 1, $noreg, %const.0, $noreg :: (load 4 from constant-pool) ; CHECK: %5:fr32 = nnan ninf nsz arcp contract afn reassoc nofpexcept VFMADD213SSr [[VRSQRTSSr]], killed %3, [[VMOVSSrm_alt]], implicit $mxcsr