diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -485,6 +485,7 @@
     SDValue visitBUILD_PAIR(SDNode *N);
     SDValue visitFADD(SDNode *N);
     SDValue visitVP_FADD(SDNode *N);
+    SDValue visitVP_FSUB(SDNode *N);
     SDValue visitSTRICT_FADD(SDNode *N);
     SDValue visitFSUB(SDNode *N);
     SDValue visitFMUL(SDNode *N);
@@ -539,6 +540,7 @@
     template <class MatchContextClass> SDValue visitFADDForFMACombine(SDNode *N);
+    template <class MatchContextClass> SDValue visitFSUBForFMACombine(SDNode *N);
     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
@@ -15369,20 +15371,26 @@
 }
 
 /// Try to perform FMA combining on a given FSUB node.
+template <class MatchContextClass>
 SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N->getValueType(0);
   SDLoc SL(N);
-
+  MatchContextClass matcher(DAG, TLI, N);
   const TargetOptions &Options = DAG.getTarget().Options;
+
+  bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
+
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
+  // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
+  // FIXME: Add VP_FMAD opcode.
+  bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
-      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
+      (!LegalOperations || matcher.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
@@ -15406,8 +15414,8 @@
 
   // Is the node an FMUL and contractable either due to global flags or
   // SDNodeFlags.
-  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
-    if (N.getOpcode() != ISD::FMUL)
+  auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
+    if (!matcher.match(N, ISD::FMUL))
       return false;
     return AllowFusionGlobally || N->getFlags().hasAllowContract();
   };
@@ -15415,8 +15423,9 @@
   // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
   auto tryToFoldXYSubZ = [&](SDValue XY, SDValue Z) {
     if (isContractableFMUL(XY) && (Aggressive || XY->hasOneUse())) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
-                         XY.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, Z));
+      return matcher.getNode(PreferredFusedOpcode, SL, VT, XY.getOperand(0),
+                             XY.getOperand(1),
+                             matcher.getNode(ISD::FNEG, SL, VT, Z));
     }
     return SDValue();
   };
@@ -15425,9 +15434,10 @@
   // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
   // Note: Commutes FSUB operands.
   auto tryToFoldXSubYZ = [&](SDValue X, SDValue YZ) {
     if (isContractableFMUL(YZ) && (Aggressive || YZ->hasOneUse())) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         DAG.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
-                         YZ.getOperand(1), X);
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT,
+          matcher.getNode(ISD::FNEG, SL, VT, YZ.getOperand(0)),
+          YZ.getOperand(1), X);
     }
     return SDValue();
   };
@@ -15452,44 +15462,46 @@
   }
 
   // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
-  if (N0.getOpcode() == ISD::FNEG && isContractableFMUL(N0.getOperand(0)) &&
+  if (matcher.match(N0, ISD::FNEG) && isContractableFMUL(N0.getOperand(0)) &&
       (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) {
     SDValue N00 = N0.getOperand(0).getOperand(0);
     SDValue N01 = N0.getOperand(0).getOperand(1);
-    return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                       DAG.getNode(ISD::FNEG, SL, VT, N00), N01,
-                       DAG.getNode(ISD::FNEG, SL, VT, N1));
+    return matcher.getNode(PreferredFusedOpcode, SL, VT,
+                           matcher.getNode(ISD::FNEG, SL, VT, N00), N01,
+                           matcher.getNode(ISD::FNEG, SL, VT, N1));
   }
 
   // Look through FP_EXTEND nodes to do more combining.
 
   // fold (fsub (fpext (fmul x, y)), z)
   //   -> (fma (fpext x), (fpext y), (fneg z))
-  if (N0.getOpcode() == ISD::FP_EXTEND) {
+  if (matcher.match(N0, ISD::FP_EXTEND)) {
     SDValue N00 = N0.getOperand(0);
     if (isContractableFMUL(N00) &&
         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                             N00.getValueType())) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
-                         DAG.getNode(ISD::FNEG, SL, VT, N1));
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT,
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
+          matcher.getNode(ISD::FNEG, SL, VT, N1));
     }
   }
 
   // fold (fsub x, (fpext (fmul y, z)))
   //   -> (fma (fneg (fpext y)), (fpext z), x)
   // Note: Commutes FSUB operands.
-  if (N1.getOpcode() == ISD::FP_EXTEND) {
+  if (matcher.match(N1, ISD::FP_EXTEND)) {
     SDValue N10 = N1.getOperand(0);
     if (isContractableFMUL(N10) &&
         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                             N10.getValueType())) {
-      return DAG.getNode(
+      return matcher.getNode(
           PreferredFusedOpcode, SL, VT,
-          DAG.getNode(ISD::FNEG, SL, VT,
-                      DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
-          DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
+          matcher.getNode(
+              ISD::FNEG, SL, VT,
+              matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0))),
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
     }
   }
@@ -15499,19 +15511,20 @@
   // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
   // from implementing the canonicalization in visitFSUB.
-  if (N0.getOpcode() == ISD::FP_EXTEND) {
+  if (matcher.match(N0, ISD::FP_EXTEND)) {
     SDValue N00 = N0.getOperand(0);
-    if (N00.getOpcode() == ISD::FNEG) {
+    if (matcher.match(N00, ISD::FNEG)) {
       SDValue N000 = N00.getOperand(0);
       if (isContractableFMUL(N000) &&
           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                               N00.getValueType())) {
-        return DAG.getNode(
+        return matcher.getNode(
             ISD::FNEG, SL, VT,
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
-                        N1));
+            matcher.getNode(
+                PreferredFusedOpcode, SL, VT,
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
+                N1));
       }
     }
   }
@@ -15522,19 +15535,20 @@
   // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the
   // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent
   // from implementing the canonicalization in visitFSUB.
-  if (N0.getOpcode() == ISD::FNEG) {
+  if (matcher.match(N0, ISD::FNEG)) {
     SDValue N00 = N0.getOperand(0);
-    if (N00.getOpcode() == ISD::FP_EXTEND) {
+    if (matcher.match(N00, ISD::FP_EXTEND)) {
       SDValue N000 = N00.getOperand(0);
       if (isContractableFMUL(N000) &&
          TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                               N000.getValueType())) {
-        return DAG.getNode(
+        return matcher.getNode(
             ISD::FNEG, SL, VT,
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
-                        N1));
+            matcher.getNode(
+                PreferredFusedOpcode, SL, VT,
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(0)),
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N000.getOperand(1)),
+                N1));
      }
    }
  }
@@ -15549,8 +15563,7 @@
   };
 
   auto isFusedOp = [&](SDValue N) {
-    unsigned Opcode = N.getOpcode();
-    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+    return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
   };
 
   // More folding opportunities when target permits.
@@ -15561,12 +15574,12 @@
     if (CanFuse && isFusedOp(N0) &&
         isContractableAndReassociableFMUL(N0.getOperand(2)) &&
         N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
-                         N0.getOperand(1),
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
-                                     N0.getOperand(2).getOperand(0),
-                                     N0.getOperand(2).getOperand(1),
-                                     DAG.getNode(ISD::FNEG, SL, VT, N1)));
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
+          matcher.getNode(PreferredFusedOpcode, SL, VT,
+                          N0.getOperand(2).getOperand(0),
+                          N0.getOperand(2).getOperand(1),
+                          matcher.getNode(ISD::FNEG, SL, VT, N1)));
     }
 
     // fold (fsub x, (fma y, z, (fmul u, v)))
@@ -15576,29 +15589,30 @@
         N1->hasOneUse() && NoSignedZero) {
       SDValue N20 = N1.getOperand(2).getOperand(0);
       SDValue N21 = N1.getOperand(2).getOperand(1);
-      return DAG.getNode(
+      return matcher.getNode(
           PreferredFusedOpcode, SL, VT,
-          DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
-          DAG.getNode(PreferredFusedOpcode, SL, VT,
-                      DAG.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
+          matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
+          N1.getOperand(1),
+          matcher.getNode(PreferredFusedOpcode, SL, VT,
+                          matcher.getNode(ISD::FNEG, SL, VT, N20), N21, N0));
     }
 
     // fold (fsub (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y (fma (fpext u), (fpext v), (fneg z)))
     if (isFusedOp(N0) && N0->hasOneUse()) {
       SDValue N02 = N0.getOperand(2);
-      if (N02.getOpcode() == ISD::FP_EXTEND) {
+      if (matcher.match(N02, ISD::FP_EXTEND)) {
         SDValue N020 = N02.getOperand(0);
         if (isContractableAndReassociableFMUL(N020) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                 N020.getValueType())) {
-          return DAG.getNode(
+          return matcher.getNode(
               PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1),
-              DAG.getNode(
+              matcher.getNode(
                   PreferredFusedOpcode, SL, VT,
-                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
-                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
-                  DAG.getNode(ISD::FNEG, SL, VT, N1)));
+                  matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(0)),
+                  matcher.getNode(ISD::FP_EXTEND, SL, VT, N020.getOperand(1)),
+                  matcher.getNode(ISD::FNEG, SL, VT, N1)));
         }
       }
     }
@@ -15609,29 +15623,29 @@
     // FIXME: This turns two single-precision and one double-precision
     // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
-    if (N0.getOpcode() == ISD::FP_EXTEND) {
+    if (matcher.match(N0, ISD::FP_EXTEND)) {
       SDValue N00 = N0.getOperand(0);
       if (isFusedOp(N00)) {
         SDValue N002 = N00.getOperand(2);
         if (isContractableAndReassociableFMUL(N002) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                                 N00.getValueType())) {
-          return DAG.getNode(
+          return matcher.getNode(
               PreferredFusedOpcode, SL, VT,
-              DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
-              DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
-              DAG.getNode(
+              matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
+              matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
+              matcher.getNode(
                   PreferredFusedOpcode, SL, VT,
-                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
-                  DAG.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
-                  DAG.getNode(ISD::FNEG, SL, VT, N1)));
+                  matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(0)),
+                  matcher.getNode(ISD::FP_EXTEND, SL, VT, N002.getOperand(1)),
+                  matcher.getNode(ISD::FNEG, SL, VT, N1)));
         }
       }
     }
 
     // fold (fsub x, (fma y, z, (fpext (fmul u, v))))
     //   -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x))
-    if (isFusedOp(N1) && N1.getOperand(2).getOpcode() == ISD::FP_EXTEND &&
+    if (isFusedOp(N1) && matcher.match(N1.getOperand(2), ISD::FP_EXTEND) &&
         N1->hasOneUse()) {
       SDValue N120 = N1.getOperand(2).getOperand(0);
       if (isContractableAndReassociableFMUL(N120) &&
           TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -15639,13 +15653,15 @@
                               N120.getValueType())) {
         SDValue N1200 = N120.getOperand(0);
         SDValue N1201 = N120.getOperand(1);
-        return DAG.getNode(
+        return matcher.getNode(
             PreferredFusedOpcode, SL, VT,
-            DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1),
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
-                        DAG.getNode(ISD::FNEG, SL, VT,
-                                    DAG.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
+            matcher.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)),
+            N1.getOperand(1),
+            matcher.getNode(
+                PreferredFusedOpcode, SL, VT,
+                matcher.getNode(ISD::FNEG, SL, VT,
+                                matcher.getNode(ISD::FP_EXTEND, SL, VT, N1200)),
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N1201), N0));
       }
     }
 
@@ -15655,7 +15671,7 @@
     // FIXME: This turns two single-precision and one double-precision
     // operation into two double-precision operations, which might not be
    // interesting for all targets, especially GPUs.
-    if (N1.getOpcode() == ISD::FP_EXTEND && isFusedOp(N1.getOperand(0))) {
+    if (matcher.match(N1, ISD::FP_EXTEND) && isFusedOp(N1.getOperand(0))) {
       SDValue CvtSrc = N1.getOperand(0);
       SDValue N100 = CvtSrc.getOperand(0);
       SDValue N101 = CvtSrc.getOperand(1);
@@ -15665,15 +15681,16 @@
           CvtSrc.getValueType())) {
         SDValue N1020 = N102.getOperand(0);
         SDValue N1021 = N102.getOperand(1);
-        return DAG.getNode(
+        return matcher.getNode(
             PreferredFusedOpcode, SL, VT,
-            DAG.getNode(ISD::FNEG, SL, VT,
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N100)),
-            DAG.getNode(ISD::FP_EXTEND, SL, VT, N101),
-            DAG.getNode(PreferredFusedOpcode, SL, VT,
-                        DAG.getNode(ISD::FNEG, SL, VT,
-                                    DAG.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
-                        DAG.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
+            matcher.getNode(ISD::FNEG, SL, VT,
+                            matcher.getNode(ISD::FP_EXTEND, SL, VT, N100)),
+            matcher.getNode(ISD::FP_EXTEND, SL, VT, N101),
+            matcher.getNode(
+                PreferredFusedOpcode, SL, VT,
+                matcher.getNode(ISD::FNEG, SL, VT,
+                                matcher.getNode(ISD::FP_EXTEND, SL, VT, N1020)),
+                matcher.getNode(ISD::FP_EXTEND, SL, VT, N1021), N0));
       }
     }
   }
@@ -16082,7 +16099,7 @@
     return DAG.getNode(ISD::FADD, DL, VT, N0, NegN1);
 
   // FSUB -> FMA combines:
-  if (SDValue Fused = visitFSUBForFMACombine(N)) {
+  if (SDValue Fused = visitFSUBForFMACombine<EmptyMatchContext>(N)) {
     AddToWorklist(Fused.getNode());
     return Fused;
   }
@@ -25476,6 +25493,17 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
+  SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+
+  // FSUB -> FMA combines:
+  if (SDValue Fused = visitFSUBForFMACombine<VPMatchContext>(N)) {
+    AddToWorklist(Fused.getNode());
+    return Fused;
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVPOp(SDNode *N) {
 
   if (N->getOpcode() == ISD::VP_GATHER)
@@ -25501,6 +25529,8 @@
   switch (N->getOpcode()) {
   case ISD::VP_FADD:
     return visitVP_FADD(N);
+  case ISD::VP_FSUB:
+    return visitVP_FSUB(N);
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-vp-fsub-and-vp-fmul.ll b/llvm/test/CodeGen/RISCV/rvv/fold-vp-fsub-and-vp-fmul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-vp-fsub-and-vp-fmul.ll
@@ -0,0 +1,46 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+declare <vscale x 1 x double> @llvm.vp.fsub.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+declare <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x i1> %m, i32 %vl)
+
+; (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
+define <vscale x 1 x double> @test1(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: test1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfmsub.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fsub.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+; (fsub z, (fmul x, y)) -> (fma (fneg y), x, z)
+define <vscale x 1 x double> @test2(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfnmsub.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fsub.nxv1f64(<vscale x 1 x double> %z, <vscale x 1 x double> %1, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+; (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
+define <vscale x 1 x double> @test3(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: test3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfnmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x i1> %m, i32 %vl)
+  %3 = call fast <vscale x 1 x double> @llvm.vp.fsub.nxv1f64(<vscale x 1 x double> %2, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %3
+}