diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -538,9 +538,7 @@ SDValue visitVECREDUCE(SDNode *N); SDValue visitVPOp(SDNode *N); - template <class MatchContextClass> SDValue visitFADDForFMACombine(SDNode *N); - template <class MatchContextClass> SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -864,38 +862,68 @@ void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; -class EmptyMatchContext { +class MatchContext { +protected: SelectionDAG &DAG; const TargetLowering &TLI; public: - EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) + MatchContext(SelectionDAG &DAG, const TargetLowering &TLI) : DAG(DAG), TLI(TLI) {} + // Owned via std::unique_ptr<MatchContext>; virtual for derived contexts. + virtual ~MatchContext() = default; - bool match(SDValue OpN, unsigned Opcode) const { + static std::unique_ptr<MatchContext> + get(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root); + + virtual bool match(SDValue OpN, unsigned Opcode) const { return Opcode == OpN->getOpcode(); } // Same as SelectionDAG::getNode(). 
- template <typename... ArgT> SDValue getNode(ArgT &&...Args) { - return DAG.getNode(std::forward<ArgT>(Args)...); + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue Operand) { + return DAG.getNode(Opcode, DL, VT, Operand); } - bool isOperationLegalOrCustom(unsigned Op, EVT VT, - bool LegalOnly = false) const { + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2) { + return DAG.getNode(Opcode, DL, VT, N1, N2); + } + + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3); + } + + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue Operand, SDNodeFlags Flags) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDNodeFlags Flags) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + + virtual SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDNodeFlags Flags) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, Flags); + } + + virtual bool isOperationLegalOrCustom(unsigned Op, EVT VT, + bool LegalOnly = false) const { return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); } }; -class VPMatchContext { - SelectionDAG &DAG; - const TargetLowering &TLI; +class VPMatchContext : public MatchContext { SDValue RootMaskOp; SDValue RootVectorLenOp; public: VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) - : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { + : MatchContext(DAG, TLI) { assert(Root->isVPOpcode()); if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode())) RootMaskOp = Root->getOperand(*RootMaskPos); @@ -907,12 +935,12 @@ /// whether \p OpVal is a node that is functionally compatible with the /// NodeType \p Opc - bool match(SDValue OpVal, unsigned Opc) const + bool match(SDValue OpVal, unsigned Opc) const 
override { if (!OpVal->isVPOpcode()) return OpVal->getOpcode() == Opc; - auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), - !OpVal->getFlags().hasNoFPExcept()); + auto BaseOpc = ISD::getBaseOpcodeForVP( + OpVal->getOpcode(), !OpVal->getFlags().hasNoFPExcept()); if (BaseOpc != Opc) return false; @@ -936,7 +964,8 @@ // TODO emit VP intrinsics where MaskOp/VectorLenOp != null // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return // DAG.getNode(Opcode, DL, VT); } - SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) { + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, + SDValue Operand) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 1 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); @@ -945,7 +974,7 @@ } SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2) { + SDValue N2) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 2 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); @@ -954,7 +983,7 @@ } SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2, SDValue N3) { + SDValue N2, SDValue N3) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 3 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); @@ -963,25 +992,25 @@ } SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, - SDNodeFlags Flags) { + SDNodeFlags Flags) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 1 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2); - return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, - Flags); + return DAG.getNode(VPOpcode, DL, VT, + {Operand, RootMaskOp, RootVectorLenOp}, Flags); } SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2, SDNodeFlags Flags) { + SDValue N2, 
SDNodeFlags Flags) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 2 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3); - return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, - Flags); + return DAG.getNode(VPOpcode, DL, VT, + {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); } SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, - SDValue N2, SDValue N3, SDNodeFlags Flags) { + SDValue N2, SDValue N3, SDNodeFlags Flags) override { unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode); assert(ISD::getVPMaskIdx(VPOpcode) == 3 && ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4); @@ -990,12 +1019,19 @@ } bool isOperationLegalOrCustom(unsigned Op, EVT VT, - bool LegalOnly = false) const { + bool LegalOnly = false) const override { unsigned VPOp = ISD::getVPForBaseOpcode(Op); return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly); } }; +std::unique_ptr<MatchContext> +MatchContext::get(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) { + if (Root->isVPOpcode()) + return std::make_unique<VPMatchContext>(DAG, TLI, Root); + return std::make_unique<MatchContext>(DAG, TLI); +} + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -15141,26 +15177,24 @@ } /// Try to perform FMA combining on a given FADD node. -template <class MatchContextClass> SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - MatchContextClass matcher(DAG, TLI, N); + std::unique_ptr<MatchContext> matcher = MatchContext::get(DAG, TLI, N); const TargetOptions &Options = DAG.getTarget().Options; - bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; - // Floating-point multiply-add with intermediate rounding. - // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext. 
+ // FIXME: Make isFMADLegal have specific behavior when N is a vp node. // FIXME: Add VP_FMAD opcode. 
- bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N)); + bool HasFMAD = + !N->isVPOpcode() && (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && - (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)); + (!LegalOperations || matcher->isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) @@ -15182,13 +15216,13 @@ bool Aggressive = TLI.enableAggressiveFMAFusion(VT); auto isFusedOp = [&](SDValue N) { - return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD); + return matcher->match(N, ISD::FMA) || matcher->match(N, ISD::FMAD); }; // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { - if (!matcher.match(N, ISD::FMUL)) + if (!matcher->match(N, ISD::FMUL)) return false; return AllowFusionGlobally || N->getFlags().hasAllowContract(); }; @@ -15201,15 +15235,15 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), - N0.getOperand(1), N1); + return matcher->getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), + N0.getOperand(1), N1); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), - N1.getOperand(1), N0); + return matcher->getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), + N1.getOperand(1), N0); } // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E) @@ -15250,29 +15284,29 @@ // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (matcher.match(N0, ISD::FP_EXTEND)) { + if (matcher->match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (matcher.match(N1, ISD::FP_EXTEND)) { + if (matcher->match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); } } @@ -15282,15 +15316,15 @@ // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, X, Y, - matcher.getNode(PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, U), - matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); + matcher->getNode(PreferredFusedOpcode, SL, VT, + matcher->getNode(ISD::FP_EXTEND, SL, VT, U), + matcher->getNode(ISD::FP_EXTEND, SL, VT, V), Z)); }; if (isFusedOp(N0)) { SDValue N02 = N0.getOperand(2); - if (matcher.match(N02, ISD::FP_EXTEND)) { + if (matcher->match(N02, 
ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, @@ -15309,13 +15343,13 @@ // interesting for all targets, especially GPUs. auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, X), - matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), - matcher.getNode(PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, U), - matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z)); + matcher->getNode(ISD::FP_EXTEND, SL, VT, X), + matcher->getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher->getNode(PreferredFusedOpcode, SL, VT, + matcher->getNode(ISD::FP_EXTEND, SL, VT, U), + matcher->getNode(ISD::FP_EXTEND, SL, VT, V), Z)); }; if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -15371,26 +15405,24 @@ } /// Try to perform FMA combining on a given FSUB node. -template <class MatchContextClass> SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - MatchContextClass matcher(DAG, TLI, N); + std::unique_ptr<MatchContext> matcher = MatchContext::get(DAG, TLI, N); const TargetOptions &Options = DAG.getTarget().Options; - bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>; - // Floating-point multiply-add with intermediate rounding. - // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext. + // FIXME: Make isFMADLegal have specific behavior when N is a vp node. // FIXME: Add VP_FMAD opcode. - bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N)); + bool HasFMAD = + !N->isVPOpcode() && (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. 
bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && - (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)); + (!LegalOperations || matcher->isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) @@ -15415,7 +15447,7 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { - if (!matcher.match(N, ISD::FMUL)) + if (!matcher->match(N, ISD::FMUL)) return false; return AllowFusionGlobally || N->getFlags().hasAllowContract(); }; @@ -15423,9 +15455,9 @@ // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) { if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) { - return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0), - XY.getOperand(1), - matcher.getNode(ISD::FNEG, SL, VT, Z)); + return matcher->getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0), + XY.getOperand(1), + matcher->getNode(ISD::FNEG, SL, VT, Z)); } return SDValue(); }; @@ -15434,9 +15466,9 @@ // Note: Commutes FSUB operands. 
auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) { if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)), + matcher->getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)), YZ.getOperand(1), X); } return SDValue(); @@ -15462,46 +15494,46 @@ } // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z)) - if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) && + if (matcher->match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) && (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) { SDValue N00 = N0.getOperand(0).getOperand(0); SDValue N01 = N0.getOperand(0).getOperand(1); - return matcher.getNode(PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, N00), N01, - matcher.getNode(ISD::FNEG, SL, VT, N1)); + return matcher->getNode(PreferredFusedOpcode, SL, VT, + matcher->getNode(ISD::FNEG, SL, VT, N00), N01, + matcher->getNode(ISD::FNEG, SL, VT, N1)); } // Look through FP_EXTEND nodes to do more combining. // fold (fsub (fpext (fmul x, y)), z) // -> (fma (fpext x), (fpext y), (fneg z)) - if (matcher.match(N0, ISD::FP_EXTEND)) { + if (matcher->match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - matcher.getNode(ISD::FNEG, SL, VT, N1)); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), + matcher->getNode(ISD::FNEG, SL, VT, N1)); } } // fold (fsub x, (fpext (fmul y, z))) // -> (fma (fneg (fpext y)), (fpext z), x) // Note: Commutes FSUB operands. 
- if (matcher.match(N1, ISD::FP_EXTEND)) { + if (matcher->match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode( + matcher->getNode( ISD::FNEG, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0); } } @@ -15511,19 +15543,19 @@ // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent // from implementing the canonicalization in visitFSUB. - if (matcher.match(N0, ISD::FP_EXTEND)) { + if (matcher->match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); - if (matcher.match(N00, ISD::FNEG)) { + if (matcher->match(N00, ISD::FNEG)) { SDValue N000 = N00.getOperand(0); if (isContractableFMUL(N000) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return matcher.getNode( + return matcher->getNode( ISD::FNEG, SL, VT, - matcher.getNode( + matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), N1)); } } @@ -15535,19 +15567,19 @@ // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent // from implementing the canonicalization in visitFSUB. 
- if (matcher.match(N0, ISD::FNEG)) { + if (matcher->match(N0, ISD::FNEG)) { SDValue N00 = N0.getOperand(0); - if (matcher.match(N00, ISD::FP_EXTEND)) { + if (matcher->match(N00, ISD::FP_EXTEND)) { SDValue N000 = N00.getOperand(0); if (isContractableFMUL(N000) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N000.getValueType())) { - return matcher.getNode( + return matcher->getNode( ISD::FNEG, SL, VT, - matcher.getNode( + matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)), N1)); } } @@ -15563,7 +15595,7 @@ }; auto isFusedOp = [&](SDValue N) { - return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD); + return matcher->match(N, ISD::FMA) || matcher->match(N, ISD::FMAD); }; // More folding opportunities when target permits. @@ -15574,12 +15606,12 @@ if (CanFuse && isFusedOp(N0) && isContractableAndReassociableFMUL(N0.getOperand(2)) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - matcher.getNode(PreferredFusedOpcode, SL, VT, - N0.getOperand(2).getOperand(0), - N0.getOperand(2).getOperand(1), - matcher.getNode(ISD::FNEG, SL, VT, N1))); + matcher->getNode(PreferredFusedOpcode, SL, VT, + N0.getOperand(2).getOperand(0), + N0.getOperand(2).getOperand(1), + matcher->getNode(ISD::FNEG, SL, VT, N1))); } // fold (fsub x, (fma y, z, (fmul u, v))) @@ -15589,30 +15621,30 @@ N1->hasOneUse() && NoSignedZero) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + matcher->getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), 
N1.getOperand(1), - matcher.getNode(PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0)); + matcher->getNode(PreferredFusedOpcode, SL, VT, + matcher->getNode(ISD::FNEG, SL, VT, N20), N21, N0)); } // fold (fsub (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) if (isFusedOp(N0) && N0->hasOneUse()) { SDValue N02 = N0.getOperand(2); - if (matcher.match(N02, ISD::FP_EXTEND)) { + if (matcher->match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableAndReassociableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N020.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - matcher.getNode( + matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)), - matcher.getNode(ISD::FNEG, SL, VT, N1))); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)), + matcher->getNode(ISD::FNEG, SL, VT, N1))); } } } @@ -15623,29 +15655,29 @@ // FIXME: This turns two single-precision and one double-precision // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. 
- if (matcher.match(N0, ISD::FP_EXTEND)) { + if (matcher->match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isFusedOp(N00)) { SDValue N002 = N00.getOperand(2); if (isContractableAndReassociableFMUL(N002) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), - matcher.getNode( + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), + matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)), - matcher.getNode(ISD::FNEG, SL, VT, N1))); + matcher->getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)), + matcher->getNode(ISD::FNEG, SL, VT, N1))); } } } // fold (fsub x, (fma y, z, (fpext (fmul u, v)))) // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x)) - if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) && + if (isFusedOp(N1) && matcher->match(N1.getOperand(2), ISD::FP_EXTEND) && N1->hasOneUse()) { SDValue N120 = N1.getOperand(2).getOperand(0); if (isContractableAndReassociableFMUL(N120) && @@ -15653,15 +15685,16 @@ N120.getValueType())) { SDValue N1200 = N120.getOperand(0); SDValue N1201 = N120.getOperand(1); - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + matcher->getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1), - matcher.getNode( - PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0)); + matcher->getNode(PreferredFusedOpcode, SL, 
VT, + matcher->getNode(ISD::FNEG, SL, VT, + matcher->getNode(ISD::FP_EXTEND, + SL, VT, N1200)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N1201), + N0)); } } @@ -15671,7 +15704,7 @@ // FIXME: This turns two single-precision and one double-precision // operation into two double-precision operations, which might not be // interesting for all targets, especially GPUs. - if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) { + if (matcher->match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) { SDValue CvtSrc = N1.getOperand(0); SDValue N100 = CvtSrc.getOperand(0); SDValue N101 = CvtSrc.getOperand(1); @@ -15681,16 +15714,17 @@ CvtSrc.getValueType())) { SDValue N1020 = N102.getOperand(0); SDValue N1021 = N102.getOperand(1); - return matcher.getNode( + return matcher->getNode( PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N101), - matcher.getNode( - PreferredFusedOpcode, SL, VT, - matcher.getNode(ISD::FNEG, SL, VT, - matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)), - matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0)); + matcher->getNode(ISD::FNEG, SL, VT, + matcher->getNode(ISD::FP_EXTEND, SL, VT, N100)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N101), + matcher->getNode(PreferredFusedOpcode, SL, VT, + matcher->getNode(ISD::FNEG, SL, VT, + matcher->getNode(ISD::FP_EXTEND, + SL, VT, N1020)), + matcher->getNode(ISD::FP_EXTEND, SL, VT, N1021), + N0)); } } } @@ -15797,7 +15831,7 @@ SelectionDAG::FlagInserter FlagsInserter(DAG, N); // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -15989,7 +16023,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { 
AddToWorklist(Fused.getNode()); return Fused; } @@ -16099,7 
+16133,7 @@ return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1); // FSUB -> FMA combines: - if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) { + if (SDValue Fused = visitFSUBForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -25497,7 +25531,7 @@ SelectionDAG::FlagInserter FlagsInserter(DAG, N); // FSUB -> FMA combines: - if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) { + if (SDValue Fused = visitFSUBForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; }