Index: llvm/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -331,6 +331,43 @@
     virtual void anchor();
   };
 
+  /// Help to insert SDNodeFlags automatically in transforming. Use
+  /// RAII to save and resume flags in current scope.
+  ///
+  /// While an inserter is installed and not disabled, getNode and
+  /// getNodeIfExists use the inserter's flags in place of the SDNodeFlags
+  /// argument passed by the caller. Inserters nest: the destructor
+  /// reinstalls the inserter that was active at construction time.
+  class FlagInserter {
+    SelectionDAG &DAG;
+    SDNodeFlags Flags;
+    FlagInserter *LastInserter;
+    bool Disabled = false;
+
+  public:
+    /// Install \p NodeFlags as the flags for nodes created in this scope.
+    FlagInserter(SelectionDAG &SDAG, SDNodeFlags NodeFlags)
+        : DAG(SDAG), Flags(NodeFlags), LastInserter(SDAG.getFlagInserter()) {
+      SDAG.setFlagInserter(this);
+    }
+
+    /// Install the flags of \p N as the flags for nodes created in this scope.
+    FlagInserter(SelectionDAG &SDAG, SDNode *N)
+        : FlagInserter(SDAG, N->getFlags()) {}
+
+    FlagInserter(const FlagInserter &) = delete;
+    FlagInserter &operator=(const FlagInserter &) = delete;
+    ~FlagInserter() { DAG.setFlagInserter(LastInserter); }
+
+    SDNodeFlags getFlags() const { return Flags; }
+
+    // In some cases we don't want default flags in the scope. Disable
+    // the inserter to give flags explicitly.
+    void disable() { Disabled = true; }
+    void enable() { Disabled = false; }
+    bool isDisabled() const { return Disabled; }
+  };
+
   /// When true, additional steps are taken to
   /// ensure that getConstant() and similar functions return DAG nodes that
   /// have legal types. This is important after type legalization since
@@ -433,6 +470,10 @@
   ProfileSummaryInfo *getPSI() const { return PSI; }
   BlockFrequencyInfo *getBFI() const { return BFI; }
 
+  /// Get/set the innermost active FlagInserter (nullptr when none is active).
+  FlagInserter *getFlagInserter() { return Inserter; }
+  void setFlagInserter(FlagInserter *FI) { Inserter = FI; }
+
   /// Just dump dot graph to a user-provided path and title.
   /// This doesn't open the dot viewer program and
   /// helps visualization when outside debugging session.
@@ -1998,6 +2039,9 @@
   std::map<std::pair<std::string, unsigned>, SDNode *> TargetExternalSymbols;
 
   DenseMap<MCSymbol *, SDNode *> MCSymbols;
+
+  /// Innermost active FlagInserter; installed/restored by FlagInserter.
+  FlagInserter *Inserter = nullptr;
 };
 
 template <> struct GraphTraits<SelectionDAG*> : public GraphTraits<SDNode*> {
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13128,6 +13128,7 @@
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
   const TargetOptions &Options = DAG.getTarget().Options;
+  SelectionDAG::FlagInserter FlagsInserter(DAG, N);
 
   // FMA nodes have flags that propagate to the created nodes.
   const SDNodeFlags Flags = N->getFlags();
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4410,6 +4410,9 @@
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue Operand, const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   // Constant fold unary operations with an integer constant operand. Even
   // opaque constant will be folded, because the folding of unary operations
   // doesn't create new constants with different values. Nevertheless, the
@@ -4805,12 +4808,12 @@
     AddNodeIDNode(ID, Opcode, VTs, Ops);
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
+      E->intersectFlagsWith(NFlags);
       return SDValue(E, 0);
     }
 
     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
+    N->setFlags(NFlags);
     createOperands(N, Ops);
     CSEMap.InsertNode(N, IP);
   } else {
@@ -5218,6 +5221,9 @@
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, SDValue N2, const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
@@ -5284,7 +5290,7 @@
     assert(VT.isInteger() && "This operator does not apply to FP types!");
     assert(N1.getValueType() == N2.getValueType() &&
            N1.getValueType() == VT && "Binary operator types must match!");
-    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
+    if (N2C && (N1.getOpcode() == ISD::VSCALE) && NFlags.hasNoSignedWrap()) {
       APInt MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getAPIntValue();
       APInt N2CImm = N2C->getAPIntValue();
       return getVScale(DL, VT, MulImm * N2CImm);
@@ -5316,7 +5322,7 @@
     assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
     assert(N1.getValueType() == N2.getValueType() &&
            N1.getValueType() == VT && "Binary operator types must match!");
-    if (SDValue V = simplifyFPBinop(Opcode, N1, N2, Flags))
+    if (SDValue V = simplifyFPBinop(Opcode, N1, N2, NFlags))
       return V;
     break;
   case ISD::FCOPYSIGN: // N1 and result must match. N1/N2 need not match.
@@ -5326,7 +5332,7 @@
            "Invalid FCOPYSIGN!");
     break;
   case ISD::SHL:
-    if (N2C && (N1.getOpcode() == ISD::VSCALE) && Flags.hasNoSignedWrap()) {
+    if (N2C && (N1.getOpcode() == ISD::VSCALE) && NFlags.hasNoSignedWrap()) {
       APInt MulImm = cast<ConstantSDNode>(N1->getOperand(0))->getAPIntValue();
       APInt ShiftImm = N2C->getAPIntValue();
       return getVScale(DL, VT, MulImm << ShiftImm);
@@ -5640,12 +5646,12 @@
     AddNodeIDNode(ID, Opcode, VTs, Ops);
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
+      E->intersectFlagsWith(NFlags);
       return SDValue(E, 0);
     }
 
     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
+    N->setFlags(NFlags);
     createOperands(N, Ops);
     CSEMap.InsertNode(N, IP);
   } else {
@@ -5662,6 +5668,9 @@
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, SDValue N2, SDValue N3,
                               const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   // Perform various simplifications.
   switch (Opcode) {
   case ISD::FMA: {
@@ -5789,12 +5798,12 @@
     AddNodeIDNode(ID, Opcode, VTs, Ops);
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
+      E->intersectFlagsWith(NFlags);
       return SDValue(E, 0);
     }
 
     N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
+    N->setFlags(NFlags);
     createOperands(N, Ops);
     CSEMap.InsertNode(N, IP);
   } else {
@@ -7483,12 +7492,15 @@
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   unsigned NumOps = Ops.size();
   switch (NumOps) {
   case 0: return getNode(Opcode, DL, VT);
-  case 1: return getNode(Opcode, DL, VT, Ops[0], Flags);
-  case 2: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Flags);
-  case 3: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Ops[2], Flags);
+  case 1: return getNode(Opcode, DL, VT, Ops[0], NFlags);
+  case 2: return getNode(Opcode, DL, VT, Ops[0], Ops[1], NFlags);
+  case 3: return getNode(Opcode, DL, VT, Ops[0], Ops[1], Ops[2], NFlags);
   default: break;
   }
 
@@ -7540,7 +7552,7 @@
     createOperands(N, Ops);
   }
 
-  N->setFlags(Flags);
+  N->setFlags(NFlags);
   InsertNode(N);
   SDValue V(N, 0);
   NewSDValueDbgMsg(V, "Creating new node: ", this);
@@ -7554,6 +7566,9 @@
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   if (VTList.NumVTs == 1)
-    return getNode(Opcode, DL, VTList.VTs[0], Ops);
+    return getNode(Opcode, DL, VTList.VTs[0], Ops, NFlags);
 
@@ -7629,7 +7644,7 @@
     createOperands(N, Ops);
   }
 
-  N->setFlags(Flags);
+  N->setFlags(NFlags);
   InsertNode(N);
   SDValue V(N, 0);
   NewSDValueDbgMsg(V, "Creating new node: ", this);
@@ -8251,12 +8266,15 @@
 SDNode *SelectionDAG::getNodeIfExists(unsigned Opcode, SDVTList VTList,
                                       ArrayRef<SDValue> Ops,
                                       const SDNodeFlags Flags) {
+  SDNodeFlags NFlags = Flags;
+  if (Inserter && !Inserter->isDisabled())
+    NFlags = Inserter->getFlags();
   if (VTList.VTs[VTList.NumVTs - 1] != MVT::Glue) {
     FoldingSetNodeID ID;
     AddNodeIDNode(ID, Opcode, VTList, Ops);
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) {
-      E->intersectFlagsWith(Flags);
+      E->intersectFlagsWith(NFlags);
       return E;
     }
   }
Index: llvm/test/CodeGen/PowerPC/fma-combine.ll
===================================================================
--- llvm/test/CodeGen/PowerPC/fma-combine.ll
+++ llvm/test/CodeGen/PowerPC/fma-combine.ll
@@ -243,17 +243,18 @@
 define double @fma_flag_propagation(double %a) {
 ; CHECK-FAST-LABEL: fma_flag_propagation:
 ; CHECK-FAST:       # %bb.0: # %entry
-; CHECK-FAST-NEXT:    xssubdp 1, 1, 1
+; CHECK-FAST-NEXT:    xxlxor 1, 1, 1
 ; CHECK-FAST-NEXT:    blr
 ;
 ; CHECK-FAST-NOVSX-LABEL: fma_flag_propagation:
 ; CHECK-FAST-NOVSX:       # %bb.0: # %entry
-; CHECK-FAST-NOVSX-NEXT:    fsub 1, 1, 1
+; CHECK-FAST-NOVSX-NEXT:    addis 3, 2, .LCPI6_0@toc@ha
+; CHECK-FAST-NOVSX-NEXT:    lfs 1, .LCPI6_0@toc@l(3)
 ; CHECK-FAST-NOVSX-NEXT:    blr
 ;
 ; CHECK-LABEL: fma_flag_propagation:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    xssubdp 1, 1, 1
+; CHECK-NEXT:    xxlxor 1, 1, 1
 ; CHECK-NEXT:    blr
 entry:
   %0 = fneg double %a