diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -491,8 +491,7 @@ SDValue visitSTRICT_FADD(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); - template - SDValue visitFMA(SDNode *N); + template SDValue visitFMA(SDNode *N); SDValue visitFDIV(SDNode *N); SDValue visitFREM(SDNode *N); SDValue visitFSQRT(SDNode *N); @@ -16317,8 +16316,7 @@ return SDValue(); } -template -SDValue DAGCombiner::visitFMA(SDNode *N) { +template SDValue DAGCombiner::visitFMA(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); @@ -16365,7 +16363,7 @@ return N2; } - // FIXME: Support vector constant patterns in the function. + // FIXME: Support splat of constant. if (N0CFP && N0CFP->isExactlyValue(1.0)) return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2); if (N1CFP && N1CFP->isExactlyValue(1.0)) @@ -16381,21 +16379,23 @@ if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) && DAG.isConstantFPBuildVectorOrConstantFP(N1) && DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) { - return matcher.getNode(ISD::FMUL, DL, VT, N0, - matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1))); + return matcher.getNode( + ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1))); } // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) if (matcher.match(N0, ISD::FMUL) && DAG.isConstantFPBuildVectorOrConstantFP(N1) && DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { - return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0), - matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), - N2); + return matcher.getNode( + ISD::FMA, DL, VT, N0.getOperand(0), + matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2); } } // (fma x, -1, y) -> (fadd (fneg x), y) + // FIXME: Support splat of constant. if (N1CFP) { if (N1CFP->isExactlyValue(1.0)) return matcher.getNode(ISD::FADD, DL, VT, N0, N2); @@ -16410,26 +16410,27 @@ // fma (fneg x), K, y -> fma x -K, y if (matcher.match(N0, ISD::FNEG) && (TLI.isOperationLegal(ISD::ConstantFP, VT) || - (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, - ForCodeSize)))) { + (N1.hasOneUse() && + !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) { return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0), - matcher.getNode(ISD::FNEG, DL, VT, N1), N2); + matcher.getNode(ISD::FNEG, DL, VT, N1), N2); } } + // FIXME: Support splat of constant. if (CanReassociate) { // (fma x, c, x) -> (fmul x, (c+1)) if (N1CFP && N0 == N2) { - return matcher.getNode( - ISD::FMUL, DL, VT, N0, - matcher.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT))); + return matcher.getNode(ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, + DAG.getConstantFP(1.0, DL, VT))); } // (fma x, c, (fneg x)) -> (fmul x, (c-1)) if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) { - return matcher.getNode( - ISD::FMUL, DL, VT, N0, - matcher.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT))); + return matcher.getNode(ISD::FMUL, DL, VT, N0, + matcher.getNode(ISD::FADD, DL, VT, N1, + DAG.getConstantFP(-1.0, DL, VT))); } } diff --git a/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll @@ -0,0 +1,70 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=ilp32d \ +; RUN: -verify-machineinstrs < %s | FileCheck %s +; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=lp64d \ +; RUN: -verify-machineinstrs < %s | FileCheck %s + +declare @llvm.vp.fma.nxv1f64(, , , , i32) +declare @llvm.vp.fneg.nxv1f64(, , i32) +declare @llvm.vp.fmul.nxv1f64(, , , i32) + +; (-N0 * -N1) + N2 --> (N0 * N1) + N2 +define @test1( %a, %b, %c, %m, i32 zeroext %evl) { +; CHECK-LABEL: test1: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma +; CHECK-NEXT: vfmadd.vv v9, v8, v10, v0.t +; CHECK-NEXT: vmv.v.v v8, v9 +; CHECK-NEXT: ret + %nega = call @llvm.vp.fneg.nxv1f64( %a, %m, i32 %evl) + %negb = call @llvm.vp.fneg.nxv1f64( %b, %m, i32 %evl) + %v = call @llvm.vp.fma.nxv1f64( %nega, %negb, %c, %m, i32 %evl) + ret %v +} + +; (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) +define @test2( %a, %m, i32 zeroext %evl) { +; CHECK-LABEL: test2: +; CHECK: # %bb.0: +; CHECK-NEXT: lui a1, %hi(.LCPI1_0) +; CHECK-NEXT: addi a1, a1, %lo(.LCPI1_0) +; CHECK-NEXT: vsetvli a2, zero, e64, m1, ta, ma +; CHECK-NEXT: vlse64.v v9, (a1), zero +; CHECK-NEXT: lui a1, %hi(.LCPI1_1) +; CHECK-NEXT: fld fa5, %lo(.LCPI1_1)(a1) +; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma +; CHECK-NEXT: vfadd.vf v9, v9, fa5, v0.t +; CHECK-NEXT: vfmul.vv v8, v8, v9, v0.t +; CHECK-NEXT: ret + %elt.head1 = insertelement poison, double 2.0, i32 0 + %c1 = shufflevector %elt.head1, poison, zeroinitializer + %t = call @llvm.vp.fmul.nxv1f64( %a, %c1, %m, i32 %evl) + %elt.head2 = insertelement poison, double 4.0, i32 0 + %c2 = shufflevector %elt.head2, poison, zeroinitializer + %v = call fast @llvm.vp.fma.nxv1f64( %a, %c2, %t, %m, i32 %evl) + ret %v +} + +; (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) +define @test3( %a, %b, %m, i32 zeroext %evl) { +; CHECK-LABEL: test3: +; CHECK: # %bb.0: +; CHECK-NEXT: lui a1, %hi(.LCPI2_0) +; CHECK-NEXT: addi a1, a1, %lo(.LCPI2_0) +; CHECK-NEXT: vsetvli a2, zero, e64, m1, ta, ma +; CHECK-NEXT: vlse64.v v10, (a1), zero +; CHECK-NEXT: lui a1, %hi(.LCPI2_1) +; CHECK-NEXT: fld fa5, %lo(.LCPI2_1)(a1) +; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma +; CHECK-NEXT: vfmul.vf v10, v10, fa5, v0.t +; CHECK-NEXT: vfmadd.vv v10, v8, v9, v0.t +; CHECK-NEXT: vmv.v.v v8, v10 +; CHECK-NEXT: ret + %elt.head1 = insertelement poison, double 2.0, i32 0 + %c1 = shufflevector %elt.head1, poison, zeroinitializer + %t = call @llvm.vp.fmul.nxv1f64( %a, %c1, %m, i32 %evl) + %elt.head2 = insertelement poison, double 4.0, i32 0 + %c2 = shufflevector %elt.head2, poison, zeroinitializer + %v = call fast @llvm.vp.fma.nxv1f64( %t, %c2, %b, %m, i32 %evl) + ret %v +}