diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -852,9 +852,11 @@ class EmptyMatchContext { SelectionDAG &DAG; + const TargetLowering &TLI; public: - EmptyMatchContext(SelectionDAG &DAG, SDNode *Root) : DAG(DAG) {} + EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) + : DAG(DAG), TLI(TLI) {} bool match(SDValue OpN, unsigned Opcode) const { return Opcode == OpN->getOpcode(); @@ -864,16 +866,22 @@ template SDValue getNode(ArgT &&...Args) { return DAG.getNode(std::forward(Args)...); } + + bool isOperationLegalOrCustom(unsigned Op, EVT VT, + bool LegalOnly = false) const { + return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly); + } }; class VPMatchContext { SelectionDAG &DAG; + const TargetLowering &TLI; SDValue RootMaskOp; SDValue RootVectorLenOp; public: - VPMatchContext(SelectionDAG &DAG, SDNode *Root) - : DAG(DAG), RootMaskOp(), RootVectorLenOp() { + VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root) + : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() { assert(Root->isVPOpcode()); if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode())) RootMaskOp = Root->getOperand(*RootMaskPos); @@ -966,6 +974,12 @@ return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); } + + bool isOperationLegalOrCustom(unsigned Op, EVT VT, + bool LegalOnly = false) const { + unsigned VPOp = ISD::getVPForBaseOpcode(Op); + return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly); + } }; } // end anonymous namespace @@ -15048,7 +15062,7 @@ SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); - MatchContextClass matcher(DAG, N); + MatchContextClass matcher(DAG, TLI, N); const TargetOptions &Options = DAG.getTarget().Options; bool UseVP = std::is_same_v; @@ -15059,9 +15073,9 @@ bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N)); // Floating-point multiply-add without intermediate rounding. - unsigned FMAOpc = UseVP ? ISD::VP_FMA : ISD::FMA; - bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && - (!LegalOperations || TLI.isOperationLegalOrCustom(FMAOpc, VT)); + bool HasFMA = + TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && + (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA)