diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -154,6 +154,7 @@
   FMUL_VL,
   FDIV_VL,
   FNEG_VL,
+  FMA_VL,
 
   // Set mask vector to all zeros or ones.
   VMCLR_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -558,6 +558,7 @@
         setOperationAction(ISD::FMUL, VT, Custom);
         setOperationAction(ISD::FDIV, VT, Custom);
         setOperationAction(ISD::FNEG, VT, Custom);
+        setOperationAction(ISD::FMA, VT, Custom);
       }
     }
   }
@@ -1044,6 +1045,8 @@
     return lowerToScalableOp(Op, DAG, RISCVISD::FDIV_VL);
   case ISD::FNEG:
     return lowerToScalableOp(Op, DAG, RISCVISD::FNEG_VL);
+  case ISD::FMA:
+    return lowerToScalableOp(Op, DAG, RISCVISD::FMA_VL);
   }
 }
 
@@ -4575,6 +4578,7 @@
   NODE_NAME_CASE(FMUL_VL)
   NODE_NAME_CASE(FDIV_VL)
   NODE_NAME_CASE(FNEG_VL)
+  NODE_NAME_CASE(FMA_VL)
   NODE_NAME_CASE(VMCLR_VL)
   NODE_NAME_CASE(VMSET_VL)
   NODE_NAME_CASE(VLE_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -70,6 +70,15 @@
 def riscv_fdiv_vl : SDNode<"RISCVISD::FDIV_VL", SDT_RISCVFPBinOp_VL>;
 def riscv_fneg_vl : SDNode<"RISCVISD::FNEG_VL", SDT_RISCVFPUnOp_VL>;
 
+def SDT_RISCVVecFMA_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
+                                              SDTCisSameAs<0, 2>,
+                                              SDTCisSameAs<0, 3>,
+                                              SDTCisVec<0>, SDTCisFP<0>,
+                                              SDTCVecEltisVT<4, i1>,
+                                              SDTCisSameNumEltsAs<0, 4>,
+                                              SDTCisVT<5, XLenVT>]>;
+def riscv_fma_vl : SDNode<"RISCVISD::FMA_VL", SDT_RISCVVecFMA_VL>;
+
 def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCisVec<0>,
                                                 SDTCVecEltisVT<0, i1>,
                                                 SDTCisVT<1, XLenVT>]>;
@@ -178,7 +187,19 @@
 defm "" : VPatBinaryFPVL_VV_VF<riscv_fmul_vl, "PseudoVFMUL">;
 defm "" : VPatBinaryFPVL_VV_VF<riscv_fdiv_vl, "PseudoVFDIV">;
 
-// 14.10. Vector Floating-Point Sign-Injection Instructions
+// 14.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
+foreach vti = AllFloatVectors in {
+  // NOTE: We choose VFMADD because it has the most commuting freedom. So it
+  // works best with how TwoAddressInstructionPass tries commuting.
+  def : Pat<(vti.Vector (riscv_fma_vl vti.RegClass:$rd, vti.RegClass:$rs1,
+                                      vti.RegClass:$rs2, (vti.Mask true_mask),
+                                      (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVFMADD_VV_"# vti.LMul.MX)
+                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 GPR:$vl, vti.SEW)>;
+}
+
+// 14.12. Vector Floating-Point Sign-Injection Instructions
 // Handle fneg with VFSGNJN using the same input for both operands.
 foreach vti = AllFloatVectors in {
   def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask true_mask),
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll
@@ -253,6 +253,72 @@
   ret void
 }
 
+define void @fma_v8f16(<8 x half>* %x, <8 x half>* %y, <8 x half>* %z) {
+; CHECK-LABEL: fma_v8f16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 8
+; CHECK-NEXT:    vsetvli a4, a3, e16,m1,ta,mu
+; CHECK-NEXT:    vle16.v v25, (a0)
+; CHECK-NEXT:    vle16.v v26, (a1)
+; CHECK-NEXT:    vle16.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e16,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e16,m1,ta,mu
+; CHECK-NEXT:    vse16.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <8 x half>, <8 x half>* %x
+  %b = load <8 x half>, <8 x half>* %y
+  %c = load <8 x half>, <8 x half>* %z
+  %d = call <8 x half> @llvm.fma.v8f16(<8 x half> %a, <8 x half> %b, <8 x half> %c)
+  store <8 x half> %d, <8 x half>* %x
+  ret void
+}
+declare <8 x half> @llvm.fma.v8f16(<8 x half>, <8 x half>, <8 x half>)
+
+define void @fma_v4f32(<4 x float>* %x, <4 x float>* %y, <4 x float>* %z) {
+; CHECK-LABEL: fma_v4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 4
+; CHECK-NEXT:    vsetvli a4, a3, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vle32.v v26, (a1)
+; CHECK-NEXT:    vle32.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e32,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e32,m1,ta,mu
+; CHECK-NEXT:    vse32.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x float>, <4 x float>* %x
+  %b = load <4 x float>, <4 x float>* %y
+  %c = load <4 x float>, <4 x float>* %z
+  %d = call <4 x float> @llvm.fma.v4f32(<4 x float> %a, <4 x float> %b, <4 x float> %c)
+  store <4 x float> %d, <4 x float>* %x
+  ret void
+}
+declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>)
+
+define void @fma_v2f64(<2 x double>* %x, <2 x double>* %y, <2 x double>* %z) {
+; CHECK-LABEL: fma_v2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, zero, 2
+; CHECK-NEXT:    vsetvli a4, a3, e64,m1,ta,mu
+; CHECK-NEXT:    vle64.v v25, (a0)
+; CHECK-NEXT:    vle64.v v26, (a1)
+; CHECK-NEXT:    vle64.v v27, (a2)
+; CHECK-NEXT:    vsetvli a1, a3, e64,m1,tu,mu
+; CHECK-NEXT:    vfmadd.vv v25, v26, v27
+; CHECK-NEXT:    vsetvli a1, a3, e64,m1,ta,mu
+; CHECK-NEXT:    vse64.v v25, (a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x double>, <2 x double>* %x
+  %b = load <2 x double>, <2 x double>* %y
+  %c = load <2 x double>, <2 x double>* %z
+  %d = call <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
+  store <2 x double> %d, <2 x double>* %x
+  ret void
+}
+declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
+
 define void @fadd_v16f16(<16 x half>* %x, <16 x half>* %y) {
 ; LMULMAX2-LABEL: fadd_v16f16:
 ; LMULMAX2:       # %bb.0:
@@ -924,3 +990,132 @@
   store <4 x double> %b, <4 x double>* %x
   ret void
 }
+
+define void @fma_v16f16(<16 x half>* %x, <16 x half>* %y, <16 x half>* %z) {
+; LMULMAX2-LABEL: fma_v16f16:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 16
+; LMULMAX2-NEXT:    vsetvli a4, a3, e16,m2,ta,mu
+; LMULMAX2-NEXT:    vle16.v v26, (a0)
+; LMULMAX2-NEXT:    vle16.v v28, (a1)
+; LMULMAX2-NEXT:    vle16.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e16,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e16,m2,ta,mu
+; LMULMAX2-NEXT:    vse16.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v16f16:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 8
+; LMULMAX1-NEXT:    vsetvli a4, a3, e16,m1,ta,mu
+; LMULMAX1-NEXT:    vle16.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle16.v v26, (a4)
+; LMULMAX1-NEXT:    vle16.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle16.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle16.v v29, (a1)
+; LMULMAX1-NEXT:    vle16.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e16,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e16,m1,ta,mu
+; LMULMAX1-NEXT:    vse16.v v25, (a0)
+; LMULMAX1-NEXT:    vse16.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <16 x half>, <16 x half>* %x
+  %b = load <16 x half>, <16 x half>* %y
+  %c = load <16 x half>, <16 x half>* %z
+  %d = call <16 x half> @llvm.fma.v16f16(<16 x half> %a, <16 x half> %b, <16 x half> %c)
+  store <16 x half> %d, <16 x half>* %x
+  ret void
+}
+declare <16 x half> @llvm.fma.v16f16(<16 x half>, <16 x half>, <16 x half>)
+
+define void @fma_v8f32(<8 x float>* %x, <8 x float>* %y, <8 x float>* %z) {
+; LMULMAX2-LABEL: fma_v8f32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 8
+; LMULMAX2-NEXT:    vsetvli a4, a3, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vle32.v v28, (a1)
+; LMULMAX2-NEXT:    vle32.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e32,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vse32.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v8f32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 4
+; LMULMAX1-NEXT:    vsetvli a4, a3, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vle32.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle32.v v26, (a4)
+; LMULMAX1-NEXT:    vle32.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle32.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle32.v v29, (a1)
+; LMULMAX1-NEXT:    vle32.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e32,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e32,m1,ta,mu
+; LMULMAX1-NEXT:    vse32.v v25, (a0)
+; LMULMAX1-NEXT:    vse32.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x float>, <8 x float>* %x
+  %b = load <8 x float>, <8 x float>* %y
+  %c = load <8 x float>, <8 x float>* %z
+  %d = call <8 x float> @llvm.fma.v8f32(<8 x float> %a, <8 x float> %b, <8 x float> %c)
+  store <8 x float> %d, <8 x float>* %x
+  ret void
+}
+declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>)
+
+define void @fma_v4f64(<4 x double>* %x, <4 x double>* %y, <4 x double>* %z) {
+; LMULMAX2-LABEL: fma_v4f64:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    addi a3, zero, 4
+; LMULMAX2-NEXT:    vsetvli a4, a3, e64,m2,ta,mu
+; LMULMAX2-NEXT:    vle64.v v26, (a0)
+; LMULMAX2-NEXT:    vle64.v v28, (a1)
+; LMULMAX2-NEXT:    vle64.v v30, (a2)
+; LMULMAX2-NEXT:    vsetvli a1, a3, e64,m2,tu,mu
+; LMULMAX2-NEXT:    vfmadd.vv v26, v28, v30
+; LMULMAX2-NEXT:    vsetvli a1, a3, e64,m2,ta,mu
+; LMULMAX2-NEXT:    vse64.v v26, (a0)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: fma_v4f64:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi a3, zero, 2
+; LMULMAX1-NEXT:    vsetvli a4, a3, e64,m1,ta,mu
+; LMULMAX1-NEXT:    vle64.v v25, (a0)
+; LMULMAX1-NEXT:    addi a4, a0, 16
+; LMULMAX1-NEXT:    vle64.v v26, (a4)
+; LMULMAX1-NEXT:    vle64.v v27, (a1)
+; LMULMAX1-NEXT:    addi a1, a1, 16
+; LMULMAX1-NEXT:    vle64.v v28, (a1)
+; LMULMAX1-NEXT:    addi a1, a2, 16
+; LMULMAX1-NEXT:    vle64.v v29, (a1)
+; LMULMAX1-NEXT:    vle64.v v30, (a2)
+; LMULMAX1-NEXT:    vsetvli a1, a3, e64,m1,tu,mu
+; LMULMAX1-NEXT:    vfmadd.vv v26, v28, v29
+; LMULMAX1-NEXT:    vfmadd.vv v25, v27, v30
+; LMULMAX1-NEXT:    vsetvli a1, a3, e64,m1,ta,mu
+; LMULMAX1-NEXT:    vse64.v v25, (a0)
+; LMULMAX1-NEXT:    vse64.v v26, (a4)
+; LMULMAX1-NEXT:    ret
+  %a = load <4 x double>, <4 x double>* %x
+  %b = load <4 x double>, <4 x double>* %y
+  %c = load <4 x double>, <4 x double>* %z
+  %d = call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> %b, <4 x double> %c)
+  store <4 x double> %d, <4 x double>* %x
+  ret void
+}
+declare <4 x double> @llvm.fma.v4f64(<4 x double>, <4 x double>, <4 x double>)