diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h --- a/llvm/include/llvm/IR/Operator.h +++ b/llvm/include/llvm/IR/Operator.h @@ -20,6 +20,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/FMF.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" @@ -300,6 +301,11 @@ else return false; + // Classify a VP intrinsic by its functional opcode, if it has one. + if (auto *VPI = dyn_cast<VPIntrinsic>(V)) + if (Optional<unsigned> OpcodeOpt = VPI->getFunctionalOpcode()) + Opcode = OpcodeOpt.getValue(); + switch (Opcode) { case Instruction::FNeg: case Instruction::FAdd: diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8808,8 +8808,10 @@ AddNodeIDNode(ID, Opcode, VTs, Ops); void *IP = nullptr; - if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + E->intersectFlagsWith(Flags); return SDValue(E, 0); + } N = newSDNode(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs); createOperands(N, Ops); @@ -8843,7 +8845,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList, ArrayRef<SDValue> Ops, const SDNodeFlags Flags) { if (VTList.NumVTs == 1) - return getNode(Opcode, DL, VTList.VTs[0], Ops); + return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags); #ifndef NDEBUG for (auto &Op : Ops) @@ -8912,8 +8914,10 @@ FoldingSetNodeID ID; AddNodeIDNode(ID, Opcode, VTList, Ops); void *IP = nullptr; - if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + E->intersectFlagsWith(Flags); return SDValue(E, 0); + } N = newSDNode(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList); createOperands(N, Ops); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp 
b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -7521,7 +7521,13 @@ switch (Opcode) { default: { - SDValue Result = DAG.getNode(Opcode, DL, VTs, OpValues); + // Transfer FMF and exception flags. + SDNodeFlags NodeFlags; + if (auto *FPIntrin = dyn_cast<FPMathOperator>(&VPIntrin)) { + NodeFlags.copyFMF(*FPIntrin); + NodeFlags.setNoFPExcept(true); // TODO: Constrained VP + } + SDValue Result = DAG.getNode(Opcode, DL, VTs, OpValues, NodeFlags); setValue(&VPIntrin, Result); break; } diff --git a/llvm/lib/Target/VE/CMakeLists.txt b/llvm/lib/Target/VE/CMakeLists.txt --- a/llvm/lib/Target/VE/CMakeLists.txt +++ b/llvm/lib/Target/VE/CMakeLists.txt @@ -26,6 +26,7 @@ VERegisterInfo.cpp VESubtarget.cpp VETargetMachine.cpp + VVPCombine.cpp VVPISelLowering.cpp LINK_COMPONENTS diff --git a/llvm/lib/Target/VE/VEISelLowering.h b/llvm/lib/Target/VE/VEISelLowering.h --- a/llvm/lib/Target/VE/VEISelLowering.h +++ b/llvm/lib/Target/VE/VEISelLowering.h @@ -199,6 +199,7 @@ /// Custom DAGCombine { SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override; + SDValue combineVVP(SDNode *N, DAGCombinerInfo &DCI) const; SDValue combineTRUNCATE(SDNode *N, DAGCombinerInfo &DCI) const; /// } Custom DAGCombine diff --git a/llvm/lib/Target/VE/VEISelLowering.cpp b/llvm/lib/Target/VE/VEISelLowering.cpp --- a/llvm/lib/Target/VE/VEISelLowering.cpp +++ b/llvm/lib/Target/VE/VEISelLowering.cpp @@ -2663,6 +2663,8 @@ DAGCombinerInfo &DCI) const { switch (N->getOpcode()) { default: + if (isVVPOrVEC(N->getOpcode())) + return combineVVP(N, DCI); break; case ISD::TRUNCATE: return combineTRUNCATE(N, DCI); diff --git a/llvm/lib/Target/VE/VVPCombine.cpp b/llvm/lib/Target/VE/VVPCombine.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/VE/VVPCombine.cpp @@ -0,0 +1,60 @@ +#include "VECustomDAG.h" + +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE +#endif +#define DEBUG_TYPE "vvp-combine" + 
+using namespace llvm; + +using Matcher = std::function<bool(SDValue)>; + +static Optional<unsigned> match_SomeOperand(SDNode *Op, unsigned VVPOpcode, + Matcher M) { + for (unsigned i = 0; i < 2; ++i) { + if ((Op->getOperand(i)->getOpcode() == VVPOpcode) && M(Op->getOperand(i))) + return i; + } + return None; +} + +// Match (vz * vw) + vy, with the contractable FMUL in either FADD operand. +static bool match_FFMA(SDNode *Root, SDValue &VY, SDValue &VZ, SDValue &VW, + SDValue &Mask, SDValue &AVL) { + if (Root->getOpcode() != VEISD::VVP_FADD || + !Root->getFlags().hasAllowContract()) + return false; + // The FADD root allows contraction; now detect a contractable FMUL leaf. + auto MulIdx = match_SomeOperand(Root, VEISD::VVP_FMUL, [](SDValue Op) { + return Op->hasOneUse() && Op->getFlags().hasAllowContract(); + }); + if (!MulIdx.hasValue()) + return false; + assert(MulIdx.getValue() < 2); + const int LeafIdx = 1 - MulIdx.getValue(); + + // Take apart; Mask and AVL are taken from the FADD root. + SDValue MulV = Root->getOperand(*MulIdx); + VY = Root->getOperand(LeafIdx); + VZ = MulV->getOperand(0); + VW = MulV->getOperand(1); + Mask = Root->getOperand(2); + AVL = Root->getOperand(3); + return true; +} + +SDValue VETargetLowering::combineVVP(SDNode *N, DAGCombinerInfo &DCI) const { + VECustomDAG CDAG(DCI.DAG, N); + SDNodeFlags Flags = N->getFlags(); + MVT ResVT = N->getSimpleValueType(0); + switch (N->getOpcode()) { + // Fuse FMA, FMSB, FNMA, FNMSB, .. 
+ case VEISD::VVP_FADD: { + SDValue VY, VZ, VW, Mask, AVL; + if (match_FFMA(N, VY, VZ, VW, Mask, AVL)) + return CDAG.getNode(VEISD::VVP_FFMA, ResVT, {VY, VZ, VW, Mask, AVL}, + Flags); + } break; + } + return SDValue(); +} diff --git a/llvm/lib/Target/VE/VVPISelLowering.cpp b/llvm/lib/Target/VE/VVPISelLowering.cpp --- a/llvm/lib/Target/VE/VVPISelLowering.cpp +++ b/llvm/lib/Target/VE/VVPISelLowering.cpp @@ -84,7 +84,8 @@ return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL}); if (isVVPBinaryOp(VVPOpcode)) return CDAG.getNode(VVPOpcode, LegalVecVT, - {Op->getOperand(0), Op->getOperand(1), Mask, AVL}); + {Op->getOperand(0), Op->getOperand(1), Mask, AVL}, + Op->getFlags()); if (isVVPReductionOp(VVPOpcode)) { auto SrcHasStart = hasReductionStartParam(Op->getOpcode()); SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue(); diff --git a/llvm/test/CodeGen/VE/Vector/fuse_vp_fma.ll b/llvm/test/CodeGen/VE/Vector/fuse_vp_fma.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/VE/Vector/fuse_vp_fma.ll @@ -0,0 +1,118 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -march=ve -mattr=+vpu | FileCheck %s + +define fastcc <256 x float> @test_vp_ffma_vvv_256f32(<256 x float> %i0, <256 x float> %i1, <256 x float> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vvv_256f32: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s0, %s0, (32)0 +; CHECK-NEXT: lvl %s0 +; CHECK-NEXT: vfmad.s %v0, %v2, %v0, %v1, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %mul = call contract <256 x float> @llvm.vp.fmul.v256f32(<256 x float> %i0, <256 x float> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x float> @llvm.vp.fadd.v256f32(<256 x float> %mul, <256 x float> %i2, <256 x i1> %m, i32 %n) + ret <256 x float> %fma +} + +define fastcc <256 x float> @test_vp_ffma_rvv_256f32(float %s0, <256 x float> %i1, <256 x float> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_rvv_256f32: +; CHECK: # %bb.0: +; CHECK-NEXT: 
and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.s %v0, %v1, %s0, %v0, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b0 = insertelement <256 x float> undef, float %s0, i32 0 + %i0 = shufflevector <256 x float> %b0, <256 x float> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x float> @llvm.vp.fmul.v256f32(<256 x float> %i0, <256 x float> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x float> @llvm.vp.fadd.v256f32(<256 x float> %mul, <256 x float> %i2, <256 x i1> %m, i32 %n) + ret <256 x float> %fma +} + +define fastcc <256 x float> @test_vp_ffma_vrv_256f32(<256 x float> %i0, float %s1, <256 x float> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vrv_256f32: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.s %v0, %v1, %s0, %v0, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b1 = insertelement <256 x float> undef, float %s1, i32 0 + %i1 = shufflevector <256 x float> %b1, <256 x float> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x float> @llvm.vp.fmul.v256f32(<256 x float> %i0, <256 x float> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x float> @llvm.vp.fadd.v256f32(<256 x float> %mul, <256 x float> %i2, <256 x i1> %m, i32 %n) + ret <256 x float> %fma +} + +define fastcc <256 x float> @test_vp_ffma_vvr_256f32(<256 x float> %i0, <256 x float> %i1, float %s2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vvr_256f32: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.s %v0, %s0, %v0, %v1, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b2 = insertelement <256 x float> undef, float %s2, i32 0 + %i2 = shufflevector <256 x float> %b2, <256 x float> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x float> @llvm.vp.fmul.v256f32(<256 x float> %i0, <256 x float> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x float> @llvm.vp.fadd.v256f32(<256 x float> %mul, <256 x float> %i2, <256 x i1> %m, 
i32 %n) + ret <256 x float> %fma +} + +declare <256 x float> @llvm.vp.fadd.v256f32(<256 x float>, <256 x float>, <256 x i1>, i32) +declare <256 x float> @llvm.vp.fmul.v256f32(<256 x float>, <256 x float>, <256 x i1>, i32) + +;;; 256 x double + +define fastcc <256 x double> @test_vp_ffma_vvv_256f64(<256 x double> %i0, <256 x double> %i1, <256 x double> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vvv_256f64: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s0, %s0, (32)0 +; CHECK-NEXT: lvl %s0 +; CHECK-NEXT: vfmad.d %v0, %v2, %v0, %v1, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %mul = call contract <256 x double> @llvm.vp.fmul.v256f64(<256 x double> %i0, <256 x double> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x double> @llvm.vp.fadd.v256f64(<256 x double> %mul, <256 x double> %i2, <256 x i1> %m, i32 %n) + ret <256 x double> %fma +} + +define fastcc <256 x double> @test_vp_ffma_rvv_256f64(double %s0, <256 x double> %i1, <256 x double> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_rvv_256f64: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.d %v0, %v1, %s0, %v0, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b0 = insertelement <256 x double> undef, double %s0, i32 0 + %i0 = shufflevector <256 x double> %b0, <256 x double> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x double> @llvm.vp.fmul.v256f64(<256 x double> %i0, <256 x double> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x double> @llvm.vp.fadd.v256f64(<256 x double> %mul, <256 x double> %i2, <256 x i1> %m, i32 %n) + ret <256 x double> %fma +} + +define fastcc <256 x double> @test_vp_ffma_vrv_256f64(<256 x double> %i0, double %s1, <256 x double> %i2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vrv_256f64: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.d %v0, %v1, %s0, %v0, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b1 = insertelement <256 x double> undef, double 
%s1, i32 0 + %i1 = shufflevector <256 x double> %b1, <256 x double> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x double> @llvm.vp.fmul.v256f64(<256 x double> %i0, <256 x double> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x double> @llvm.vp.fadd.v256f64(<256 x double> %mul, <256 x double> %i2, <256 x i1> %m, i32 %n) + ret <256 x double> %fma +} + +define fastcc <256 x double> @test_vp_ffma_vvr_256f64(<256 x double> %i0, <256 x double> %i1, double %s2, <256 x i1> %m, i32 %n) { +; CHECK-LABEL: test_vp_ffma_vvr_256f64: +; CHECK: # %bb.0: +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lvl %s1 +; CHECK-NEXT: vfmad.d %v0, %s0, %v0, %v1, %vm1 +; CHECK-NEXT: b.l.t (, %s10) + %b2 = insertelement <256 x double> undef, double %s2, i32 0 + %i2 = shufflevector <256 x double> %b2, <256 x double> poison, <256 x i32> zeroinitializer + %mul = call contract <256 x double> @llvm.vp.fmul.v256f64(<256 x double> %i0, <256 x double> %i1, <256 x i1> %m, i32 %n) + %fma = call contract <256 x double> @llvm.vp.fadd.v256f64(<256 x double> %mul, <256 x double> %i2, <256 x i1> %m, i32 %n) + ret <256 x double> %fma +} + +declare <256 x double> @llvm.vp.fadd.v256f64(<256 x double>, <256 x double>, <256 x i1>, i32) +declare <256 x double> @llvm.vp.fmul.v256f64(<256 x double>, <256 x double>, <256 x i1>, i32)