diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -517,6 +517,8 @@
     SDValue visitFSUBForFMACombine(SDNode *N);
     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
 
+    SDValue visitVPFADDForVPFMACombine(SDNode *N);
+
     SDValue XformToShuffleWithZero(SDNode *N);
     bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                     const SDLoc &DL,
@@ -23030,6 +23032,62 @@
   return SDValue();
 }
 
+/// Try to perform VP_FMA combining on a given VP_FADD node.
+///
+/// Folds (vp_fadd (vp_fmul x, y), z), or its commuted form, into
+/// (vp_fma x, y, z) using the VP_FADD's mask and EVL when the VP_FMUL has a
+/// single use. Both nodes must permit contraction, either globally
+/// (AllowFPOpFusion == Fast or UnsafeFPMath) or via the 'contract' flag.
+SDValue DAGCombiner::visitVPFADDForVPFMACombine(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+
+  const TargetOptions &Options = DAG.getTarget().Options;
+
+  bool HasFMA =
+      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
+      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::VP_FMA, VT));
+
+  if (!HasFMA)
+    return SDValue();
+
+  bool AllowFusionGlobally =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
+
+  // Contraction needs permission on both halves: if the addition itself is
+  // not contractable, do not combine even when the multiply is.
+  if (!AllowFusionGlobally && !N->getFlags().hasAllowContract())
+    return SDValue();
+
+  // Any node created below inherits the VP_FADD's flags.
+  SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+
+  // Is the node a VP_FMUL and contractable either due to global flags or
+  // SDNodeFlags.
+  auto isContractableVPFMUL = [AllowFusionGlobally](SDValue Op) {
+    if (Op.getOpcode() != ISD::VP_FMUL)
+      return false;
+    return AllowFusionGlobally || Op->getFlags().hasAllowContract();
+  };
+
+  // fold (vp_fadd (vp_fmul x, y), z) -> (vp_fma x, y, z)
+  if (isContractableVPFMUL(N0) && N0->hasOneUse())
+    return DAG.getNode(ISD::VP_FMA, SL, VT, N0.getOperand(0), N0.getOperand(1),
+                       N1, Mask, VL);
+
+  // fold (vp_fadd x, (vp_fmul y, z)) -> (vp_fma y, z, x)
+  // Note: Commutes VP_FADD operands.
+  if (isContractableVPFMUL(N1) && N1->hasOneUse())
+    return DAG.getNode(ISD::VP_FMA, SL, VT, N1.getOperand(0), N1.getOperand(1),
+                       N0, Mask, VL);
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVPOp(SDNode *N) {
   // VP operations in which all vector elements are disabled - either by
   // determining that the mask is all false or that the EVL is 0 - can be
@@ -23042,8 +23100,13 @@
         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
 
-  // This is the only generic VP combine we support for now.
-  if (!AreAllEltsDisabled)
+  // Opcode-specific combines for VP nodes that may have enabled lanes.
+  if (!AreAllEltsDisabled) {
+    switch (N->getOpcode()) {
+    case ISD::VP_FADD:
+      return visitVPFADDForVPFMACombine(N);
+    }
     return SDValue();
+  }
 
   // Binary operations can be replaced by UNDEF.
   if (ISD::isVPBinaryOp(N->getOpcode()))
diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-fadd-and-fmul.ll b/llvm/test/CodeGen/RISCV/rvv/fold-fadd-and-fmul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-fadd-and-fmul.ll
@@ -0,0 +1,33 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+declare <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+
+define <vscale x 1 x double> @test1(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl) {
+; CHECK-LABEL: test1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, tu, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+define <vscale x 1 x double> @test2(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl) {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 32
+; CHECK-NEXT:    srli a0, a0, 32
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, tu, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %z, <vscale x 1 x double> %1, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}