diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -491,7 +491,7 @@
     SDValue visitSTRICT_FADD(SDNode *N);
     SDValue visitFSUB(SDNode *N);
     SDValue visitFMUL(SDNode *N);
-    SDValue visitFMA(SDNode *N);
+    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
     SDValue visitFDIV(SDNode *N);
     SDValue visitFREM(SDNode *N);
     SDValue visitFSQRT(SDNode *N);
@@ -1961,7 +1961,7 @@
   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
   case ISD::FSUB:               return visitFSUB(N);
   case ISD::FMUL:               return visitFMUL(N);
-  case ISD::FMA:                return visitFMA(N);
+  case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
   case ISD::FDIV:               return visitFDIV(N);
   case ISD::FREM:               return visitFREM(N);
   case ISD::FSQRT:              return visitFSQRT(N);
@@ -16320,7 +16320,7 @@
   return SDValue();
 }
 
-SDValue DAGCombiner::visitFMA(SDNode *N) {
+template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
@@ -16331,6 +16331,7 @@
   const TargetOptions &Options = DAG.getTarget().Options;
   // FMA nodes have flags that propagate to the created nodes.
   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+  MatchContextClass matcher(DAG, TLI, N);
 
   bool CanReassociate =
       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
@@ -16339,7 +16340,7 @@
   if (isa<ConstantFPSDNode>(N0) &&
       isa<ConstantFPSDNode>(N1) &&
       isa<ConstantFPSDNode>(N2)) {
-    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
+    return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
   }
 
   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
@@ -16355,7 +16356,7 @@
         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
-      return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
+      return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
   }
 
   // FIXME: use fast math flags instead of Options.UnsafeFPMath
@@ -16366,70 +16367,74 @@
       return N2;
   }
 
+  // FIXME: Support splat of constant.
   if (N0CFP && N0CFP->isExactlyValue(1.0))
-    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
+    return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
   if (N1CFP && N1CFP->isExactlyValue(1.0))
-    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
+    return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
 
   // Canonicalize (fma c, x, y) -> (fma x, c, y)
   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
       !DAG.isConstantFPBuildVectorOrConstantFP(N1))
-    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
+    return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
 
   if (CanReassociate) {
     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
-    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
+    if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
-      return DAG.getNode(ISD::FMUL, DL, VT, N0,
-                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
+      return matcher.getNode(
+          ISD::FMUL, DL, VT, N0,
+          matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
     }
 
     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
-    if (N0.getOpcode() == ISD::FMUL &&
+    if (matcher.match(N0, ISD::FMUL) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
-      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
-                         N2);
+      return matcher.getNode(
+          ISD::FMA, DL, VT, N0.getOperand(0),
+          matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
    }
   }
 
   // (fma x, -1, y) -> (fadd (fneg x), y)
+  // FIXME: Support splat of constant.
   if (N1CFP) {
     if (N1CFP->isExactlyValue(1.0))
-      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
+      return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
     if (N1CFP->isExactlyValue(-1.0) &&
         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
-      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
+      SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
       AddToWorklist(RHSNeg.getNode());
-      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
+      return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
     }
 
     // fma (fneg x), K, y -> fma x -K, y
-    if (N0.getOpcode() == ISD::FNEG &&
+    if (matcher.match(N0, ISD::FNEG) &&
         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
-         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
-                                              ForCodeSize)))) {
-      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                         DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
+         (N1.hasOneUse() &&
+          !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+      return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+                             matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
     }
   }
 
+  // FIXME: Support splat of constant.
   if (CanReassociate) {
     // (fma x, c, x) -> (fmul x, (c+1))
     if (N1CFP && N0 == N2) {
-      return DAG.getNode(
-          ISD::FMUL, DL, VT, N0,
-          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
+      return matcher.getNode(ISD::FMUL, DL, VT, N0,
+                             matcher.getNode(ISD::FADD, DL, VT, N1,
+                                             DAG.getConstantFP(1.0, DL, VT)));
     }
 
     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
-    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
-      return DAG.getNode(
-          ISD::FMUL, DL, VT, N0,
-          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
+    if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
+      return matcher.getNode(ISD::FMUL, DL, VT, N0,
+                             matcher.getNode(ISD::FADD, DL, VT, N1,
+                                             DAG.getConstantFP(-1.0, DL, VT)));
     }
   }
 
@@ -16438,7 +16443,7 @@
   if (!TLI.isFNegFree(VT))
     if (SDValue Neg = TLI.getCheaperNegatedExpression(
             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
-      return DAG.getNode(ISD::FNEG, DL, VT, Neg);
+      return matcher.getNode(ISD::FNEG, DL, VT, Neg);
 
   return SDValue();
 }
@@ -25695,6 +25700,8 @@
     return visitVP_FADD(N);
   case ISD::VP_FSUB:
     return visitVP_FSUB(N);
+  case ISD::VP_FMA:
+    return visitFMA<VPMatchContext>(N);
   }
   return SDValue();
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6816,7 +6816,7 @@
                                              NegatibleCost &Cost,
                                              unsigned Depth) const {
   // fneg is removable even if it has multiple uses.
-  if (Op.getOpcode() == ISD::FNEG) {
+  if (Op.getOpcode() == ISD::FNEG || Op.getOpcode() == ISD::VP_FNEG) {
     Cost = NegatibleCost::Cheaper;
     return Op.getOperand(0);
   }
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=ilp32d \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=lp64d \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+
+; (-N0 * -N1) + N2 --> (N0 * N1) + N2
+define <vscale x 1 x double> @test1(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %nega = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 %evl)
+  %negb = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+  %v = call <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %nega, <vscale x 1 x double> %negb, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}
+
+; (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
+define <vscale x 1 x double> @test2(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, %hi(.LCPI1_0)
+; CHECK-NEXT:    addi a1, a1, %lo(.LCPI1_0)
+; CHECK-NEXT:    vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v9, (a1), zero
+; CHECK-NEXT:    lui a1, %hi(.LCPI1_1)
+; CHECK-NEXT:    fld fa5, %lo(.LCPI1_1)(a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfadd.vf v9, v9, fa5, v0.t
+; CHECK-NEXT:    vfmul.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+  %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+  %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+  %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c2, <vscale x 1 x double> %t, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}
+
+; (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
+define <vscale x 1 x double> @test3(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, %hi(.LCPI2_0)
+; CHECK-NEXT:    addi a1, a1, %lo(.LCPI2_0)
+; CHECK-NEXT:    vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v10, (a1), zero
+; CHECK-NEXT:    lui a1, %hi(.LCPI2_1)
+; CHECK-NEXT:    fld fa5, %lo(.LCPI2_1)(a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfmul.vf v10, v10, fa5, v0.t
+; CHECK-NEXT:    vfmadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v10
+; CHECK-NEXT:    ret
+  %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+  %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+  %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+  %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %t, <vscale x 1 x double> %c2, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}
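
Note on the mechanics: visitFMA is now templated over a "match context", so one body serves both ISD::FMA and ISD::VP_FMA. Below is a minimal sketch of the interface the template assumes, inferred only from the matcher.match(...) / matcher.getNode(...) calls and the (DAG, TLI, N) constructor in the patch; the real EmptyMatchContext and VPMatchContext in DAGCombiner.cpp differ in detail (in particular, VP matching must also check that the candidate's mask and EVL operands agree with the root's).

    #include "llvm/CodeGen/SelectionDAG.h"
    #include "llvm/CodeGen/TargetLowering.h"
    #include <utility>

    using namespace llvm;

    // Sketch only -- not the verbatim upstream definition. With this
    // context, visitFMA<EmptyMatchContext> behaves exactly like the old
    // non-template visitFMA.
    class EmptyMatchContext {
      SelectionDAG &DAG;

    public:
      EmptyMatchContext(SelectionDAG &DAG, const TargetLowering &TLI,
                        SDNode *Root)
          : DAG(DAG) {}

      // On ordinary nodes, a "match" is just an opcode comparison.
      bool match(SDValue OpVal, unsigned Opcode) const {
        return OpVal->getOpcode() == Opcode;
      }

      // Node creation forwards unchanged to SelectionDAG::getNode.
      template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
        return DAG.getNode(std::forward<ArgT>(Args)...);
      }
    };

    // A VP context instead accepts a VP node whose functional opcode
    // matches the requested one (e.g. ISD::VP_FMUL for ISD::FMUL) and
    // emits VP nodes reusing the root's mask and EVL, which is what lets
    // the scalar-style combines above fire on ISD::VP_FMA.

This split is also why TargetLowering::getNegatedExpression learns about ISD::VP_FNEG: the (-N0 * -N1) + N2 fold exercised by test1 goes through the negated-expression machinery rather than through the matcher, so looking through VP_FNEG has to happen there.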