diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6593,6 +6593,11 @@
   unsigned NewOp = Node->getOpcode() == ISD::FMINNUM ? ISD::FMINNUM_IEEE
                                                      : ISD::FMAXNUM_IEEE;
   EVT VT = Node->getValueType(0);
+
+  if (VT.isScalableVector())
+    report_fatal_error(
+        "Expanding fminnum/fmaxnum for scalable vectors is undefined.");
+
   if (isOperationLegalOrCustom(NewOp, VT)) {
     SDValue Quiet0 = Node->getOperand(0);
     SDValue Quiet1 = Node->getOperand(1);
@@ -8142,6 +8147,11 @@
   EVT VT = VecOp.getValueType();
   EVT EltVT = VT.getVectorElementType();
+
+  if (VT.isScalableVector())
+    report_fatal_error(
+        "Expanding reductions for scalable vectors is undefined.");
+
   unsigned NumElts = VT.getVectorNumElements();
 
   SmallVector<SDValue, 8> Ops;
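Both guards protect the same invariant: the generic expansion paths scalarize through VT.getVectorNumElements(), which has no meaning when the element count is a multiple of vscale, so a target that lacks custom lowering now gets a clear diagnostic instead of a crash further downstream. A minimal sketch of IR that would otherwise reach the reduction expansion (function name hypothetical, assuming no custom VECREDUCE lowering is registered for the type):

  ; Without custom lowering, the generic expander would try to extract a
  ; compile-time-known number of lanes from <vscale x 4 x float>; the new
  ; guard turns that into an explicit fatal error instead.
  define float @reduce_would_have_expanded(<vscale x 4 x float> %v) {
    %r = call fast float @llvm.vector.reduce.fadd.nxv4f32(float -0.0, <vscale x 4 x float> %v)
    ret float %r
  }
  declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)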
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1147,6 +1147,10 @@
       setOperationAction(ISD::FABS, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::FP_ROUND, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
     }
 
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16})
@@ -16765,18 +16769,18 @@
   EVT SrcVT = VecOp.getValueType();
   EVT ResVT = SrcVT.getVectorElementType();
 
-  // Only fixed length FADDA handled for now.
-  if (!useSVEForFixedLengthVectorVT(SrcVT, /*OverrideNEON=*/true))
-    return SDValue();
+  EVT ContainerVT = SrcVT;
+  if (SrcVT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
+    VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+  }
 
   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
-  EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
 
   // Convert operands to Scalable.
   AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT,
                       DAG.getUNDEF(ContainerVT), AccOp, Zero);
-  VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
 
   // Perform reduction.
   SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
@@ -16833,9 +16837,12 @@
   // UADDV always returns an i64 result.
   EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64
                                                  : SrcVT.getVectorElementType();
+  EVT RdxVT = SrcVT;
+  if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
+    RdxVT = getPackedSVEVectorVT(ResVT);
 
   SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
-  SDValue Rdx = DAG.getNode(Opcode, DL, getPackedSVEVectorVT(ResVT), Pg, VecOp);
+  SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
   SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
                             Rdx, DAG.getConstant(0, DL, MVT::i64));
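Two details of the lowering changes are easy to miss. The FADDA path no longer bails out for scalable inputs: fixed-length vectors are first wrapped into an SVE container and then share the scalable code path. And because FADDA's accumulator is a scalar, it travels in lane 0 of a vector. The shape of the sequence, written as pseudo-IR (FADDA_PRED is a SelectionDAG target node, not real IR):

  ; FADDA lowering for <vscale x 2 x double>, conceptually:
  ;   %acc = insertelement <vscale x 2 x double> undef, double %init, i64 0
  ;   %rdx = FADDA_PRED(%pg, %acc, %v)    ; strictly in-order accumulation
  ;   %res = extractelement <vscale x 2 x double> %rdx, i64 0

The second hunk keeps UADDV on a packed result vector (its result is always i64) while letting the floating-point reductions over scalable inputs produce a result of the unpacked source type, which is what the new TableGen patterns below match.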
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4551,7 +4551,10 @@
   def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32asZPR>;
   def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64asZPR>;
 
+  def : SVE_2_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_2_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_2_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_2_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_2_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
@@ -4587,7 +4590,10 @@
   def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32asZPR>;
   def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64asZPR>;
 
+  def : SVE_3_Op_Pat<nxv2f16, op, nxv2i1, nxv2f16, nxv2f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_3_Op_Pat<nxv4f16, op, nxv4i1, nxv4f16, nxv4f16, !cast<Instruction>(NAME # _H)>;
   def : SVE_3_Op_Pat<nxv8f16, op, nxv8i1, nxv8f16, nxv8f16, !cast<Instruction>(NAME # _H)>;
+  def : SVE_3_Op_Pat<nxv2f32, op, nxv2i1, nxv2f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_3_Op_Pat<nxv4f32, op, nxv4i1, nxv4f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_3_Op_Pat<nxv2f64, op, nxv2i1, nxv2f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
 }
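The added patterns give the unpacked types (nxv2f16, nxv4f16, nxv2f32) a direct mapping onto the same reduction instructions as the packed types; only the governing predicate narrows the set of active lanes. For example, the nxv4f16 pattern makes the following select to a .s-sized ptrue while the reduction still writes an h register (illustrative, mirroring the tests below):

  define half @faddv_unpacked(half %init, <vscale x 4 x half> %a) {
    ; Expected selection: ptrue p0.s ; faddv h1, p0, z1.h ; fadd h0, h0, h1
    %r = call fast half @llvm.vector.reduce.fadd.nxv4f16(half %init, <vscale x 4 x half> %a)
    ret half %r
  }
  declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)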
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fp-reduce.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fp-reduce.ll
@@ -0,0 +1,256 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -asm-verbose=0 < %s 2>%t | FileCheck %s
+; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t
+
+; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it.
+; WARN-NOT: warning
+
+; FADD
+
+define half @fadda_nxv2f16(half %init, <vscale x 2 x half> %a) {
+; CHECK-LABEL: fadda_nxv2f16:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fadda h0, p0, h0, z1.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fadd.nxv2f16(half %init, <vscale x 2 x half> %a)
+  ret half %res
+}
+
+define half @fadda_nxv4f16(half %init, <vscale x 4 x half> %a) {
+; CHECK-LABEL: fadda_nxv4f16:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fadda h0, p0, h0, z1.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fadd.nxv4f16(half %init, <vscale x 4 x half> %a)
+  ret half %res
+}
+
+define half @fadda_nxv8f16(half %init, <vscale x 8 x half> %a) {
+; CHECK-LABEL: fadda_nxv8f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: fadda h0, p0, h0, z1.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fadd.nxv8f16(half %init, <vscale x 8 x half> %a)
+  ret half %res
+}
+
+define float @fadda_nxv2f32(float %init, <vscale x 2 x float> %a) {
+; CHECK-LABEL: fadda_nxv2f32:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fadda s0, p0, s0, z1.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fadd.nxv2f32(float %init, <vscale x 2 x float> %a)
+  ret float %res
+}
+
+define float @fadda_nxv4f32(float %init, <vscale x 4 x float> %a) {
+; CHECK-LABEL: fadda_nxv4f32:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fadda s0, p0, s0, z1.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fadd.nxv4f32(float %init, <vscale x 4 x float> %a)
+  ret float %res
+}
+
+define double @fadda_nxv2f64(double %init, <vscale x 2 x double> %a) {
+; CHECK-LABEL: fadda_nxv2f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fadda d0, p0, d0, z1.d
+; CHECK-NEXT: ret
+  %res = call double @llvm.vector.reduce.fadd.nxv2f64(double %init, <vscale x 2 x double> %a)
+  ret double %res
+}
+
+; FADDV
+
+define half @faddv_nxv2f16(half %init, <vscale x 2 x half> %a) {
+; CHECK-LABEL: faddv_nxv2f16:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: faddv h1, p0, z1.h
+; CHECK-NEXT: fadd h0, h0, h1
+; CHECK-NEXT: ret
+  %res = call fast half @llvm.vector.reduce.fadd.nxv2f16(half %init, <vscale x 2 x half> %a)
+  ret half %res
+}
+
+define half @faddv_nxv4f16(half %init, <vscale x 4 x half> %a) {
+; CHECK-LABEL: faddv_nxv4f16:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: faddv h1, p0, z1.h
+; CHECK-NEXT: fadd h0, h0, h1
+; CHECK-NEXT: ret
+  %res = call fast half @llvm.vector.reduce.fadd.nxv4f16(half %init, <vscale x 4 x half> %a)
+  ret half %res
+}
+
+define half @faddv_nxv8f16(half %init, <vscale x 8 x half> %a) {
+; CHECK-LABEL: faddv_nxv8f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: faddv h1, p0, z1.h
+; CHECK-NEXT: fadd h0, h0, h1
+; CHECK-NEXT: ret
+  %res = call fast half @llvm.vector.reduce.fadd.nxv8f16(half %init, <vscale x 8 x half> %a)
+  ret half %res
+}
+
+define float @faddv_nxv2f32(float %init, <vscale x 2 x float> %a) {
+; CHECK-LABEL: faddv_nxv2f32:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: faddv s1, p0, z1.s
+; CHECK-NEXT: fadd s0, s0, s1
+; CHECK-NEXT: ret
+  %res = call fast float @llvm.vector.reduce.fadd.nxv2f32(float %init, <vscale x 2 x float> %a)
+  ret float %res
+}
+
+define float @faddv_nxv4f32(float %init, <vscale x 4 x float> %a) {
+; CHECK-LABEL: faddv_nxv4f32:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: faddv s1, p0, z1.s
+; CHECK-NEXT: fadd s0, s0, s1
+; CHECK-NEXT: ret
+  %res = call fast float @llvm.vector.reduce.fadd.nxv4f32(float %init, <vscale x 4 x float> %a)
+  ret float %res
+}
+
+define double @faddv_nxv2f64(double %init, <vscale x 2 x double> %a) {
+; CHECK-LABEL: faddv_nxv2f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: faddv d1, p0, z1.d
+; CHECK-NEXT: fadd d0, d0, d1
+; CHECK-NEXT: ret
+  %res = call fast double @llvm.vector.reduce.fadd.nxv2f64(double %init, <vscale x 2 x double> %a)
+  ret double %res
+}
+
+; FMAXV
+
+define half @fmaxv_nxv2f16(<vscale x 2 x half> %a) {
+; CHECK-LABEL: fmaxv_nxv2f16:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fmaxnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmax.nxv2f16(<vscale x 2 x half> %a)
+  ret half %res
+}
+
+define half @fmaxv_nxv4f16(<vscale x 4 x half> %a) {
+; CHECK-LABEL: fmaxv_nxv4f16:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fmaxnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmax.nxv4f16(<vscale x 4 x half> %a)
+  ret half %res
+}
+
+define half @fmaxv_nxv8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: fmaxv_nxv8f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: fmaxnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmax.nxv8f16(<vscale x 8 x half> %a)
+  ret half %res
+}
+
+define float @fmaxv_nxv2f32(<vscale x 2 x float> %a) {
+; CHECK-LABEL: fmaxv_nxv2f32:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fmax.nxv2f32(<vscale x 2 x float> %a)
+  ret float %res
+}
+
+define float @fmaxv_nxv4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: fmaxv_nxv4f32:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fmax.nxv4f32(<vscale x 4 x float> %a)
+  ret float %res
+}
+
+define double @fmaxv_nxv2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: fmaxv_nxv2f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fmaxnmv d0, p0, z0.d
+; CHECK-NEXT: ret
+  %res = call double @llvm.vector.reduce.fmax.nxv2f64(<vscale x 2 x double> %a)
+  ret double %res
+}
+
+; FMINV
+
+define half @fminv_nxv2f16(<vscale x 2 x half> %a) {
+; CHECK-LABEL: fminv_nxv2f16:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fminnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmin.nxv2f16(<vscale x 2 x half> %a)
+  ret half %res
+}
+
+define half @fminv_nxv4f16(<vscale x 4 x half> %a) {
+; CHECK-LABEL: fminv_nxv4f16:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fminnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmin.nxv4f16(<vscale x 4 x half> %a)
+  ret half %res
+}
+
+define half @fminv_nxv8f16(<vscale x 8 x half> %a) {
+; CHECK-LABEL: fminv_nxv8f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: fminnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmin.nxv8f16(<vscale x 8 x half> %a)
+  ret half %res
+}
+
+define float @fminv_nxv2f32(<vscale x 2 x float> %a) {
+; CHECK-LABEL: fminv_nxv2f32:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fmin.nxv2f32(<vscale x 2 x float> %a)
+  ret float %res
+}
+
+define float @fminv_nxv4f32(<vscale x 4 x float> %a) {
+; CHECK-LABEL: fminv_nxv4f32:
+; CHECK: ptrue p0.s
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: ret
+  %res = call float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> %a)
+  ret float %res
+}
+
+define double @fminv_nxv2f64(<vscale x 2 x double> %a) {
+; CHECK-LABEL: fminv_nxv2f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fminnmv d0, p0, z0.d
+; CHECK-NEXT: ret
+  %res = call double @llvm.vector.reduce.fmin.nxv2f64(<vscale x 2 x double> %a)
+  ret double %res
+}
+
+declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
+declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
+declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)
+declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
+declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
+declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
+
+declare half @llvm.vector.reduce.fmax.nxv2f16(<vscale x 2 x half>)
+declare half @llvm.vector.reduce.fmax.nxv4f16(<vscale x 4 x half>)
+declare half @llvm.vector.reduce.fmax.nxv8f16(<vscale x 8 x half>)
+declare float @llvm.vector.reduce.fmax.nxv2f32(<vscale x 2 x float>)
+declare float @llvm.vector.reduce.fmax.nxv4f32(<vscale x 4 x float>)
+declare double @llvm.vector.reduce.fmax.nxv2f64(<vscale x 2 x double>)
+
+declare half @llvm.vector.reduce.fmin.nxv2f16(<vscale x 2 x half>)
+declare half @llvm.vector.reduce.fmin.nxv4f16(<vscale x 4 x half>)
+declare half @llvm.vector.reduce.fmin.nxv8f16(<vscale x 8 x half>)
+declare float @llvm.vector.reduce.fmin.nxv2f32(<vscale x 2 x float>)
+declare float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float>)
+declare double @llvm.vector.reduce.fmin.nxv2f64(<vscale x 2 x double>)
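Note why the FADD tests above come in two flavours: llvm.vector.reduce.fadd is strictly ordered unless the call carries reassociation flags, so only the fast variants may use the tree-reducing FADDV; the ordered form must use FADDA, which accumulates one lane at a time. A minimal side-by-side sketch (hypothetical function, not part of the tests):

  define float @ordered_vs_fast(float %init, <vscale x 4 x float> %v) {
    ; No flags: lane-ordered accumulation, selects fadda.
    %ord = call float @llvm.vector.reduce.fadd.nxv4f32(float %init, <vscale x 4 x float> %v)
    ; 'fast' permits reassociation: selects faddv plus one scalar fadd.
    %fst = call fast float @llvm.vector.reduce.fadd.nxv4f32(float %init, <vscale x 4 x float> %v)
    %sum = fadd float %ord, %fst
    ret float %sum
  }
  declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)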
diff --git a/llvm/test/CodeGen/AArch64/sve-split-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-split-fp-reduce.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-split-fp-reduce.ll
@@ -0,0 +1,65 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -asm-verbose=0 < %s 2>%t | FileCheck %s
+; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t
+
+; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it.
+; WARN-NOT: warning
+
+; FADD
+
+define double @fadda_nxv8f64(double %init, <vscale x 8 x double> %a) {
+; CHECK-LABEL: fadda_nxv8f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fadda d0, p0, d0, z1.d
+; CHECK-NEXT: fadda d0, p0, d0, z2.d
+; CHECK-NEXT: fadda d0, p0, d0, z3.d
+; CHECK-NEXT: fadda d0, p0, d0, z4.d
+; CHECK-NEXT: ret
+  %res = call double @llvm.vector.reduce.fadd.nxv8f64(double %init, <vscale x 8 x double> %a)
+  ret double %res
+}
+
+; FADDV
+
+define float @faddv_nxv8f32(float %init, <vscale x 8 x float> %a) {
+; CHECK-LABEL: faddv_nxv8f32:
+; CHECK: fadd z1.s, z1.s, z2.s
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: faddv s1, p0, z1.s
+; CHECK-NEXT: fadd s0, s0, s1
+; CHECK-NEXT: ret
+  %res = call fast float @llvm.vector.reduce.fadd.nxv8f32(float %init, <vscale x 8 x float> %a)
+  ret float %res
+}
+
+; FMAXV
+
+define double @fmaxv_nxv8f64(<vscale x 8 x double> %a) {
+; CHECK-LABEL: fmaxv_nxv8f64:
+; CHECK: ptrue p0.d
+; CHECK-NEXT: fmaxnm z1.d, p0/m, z1.d, z3.d
+; CHECK-NEXT: fmaxnm z0.d, p0/m, z0.d, z2.d
+; CHECK-NEXT: fmaxnm z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT: fmaxnmv d0, p0, z0.d
+; CHECK-NEXT: ret
+  %res = call double @llvm.vector.reduce.fmax.nxv8f64(<vscale x 8 x double> %a)
+  ret double %res
+}
+
+; FMINV
+
+define half @fminv_nxv16f16(<vscale x 16 x half> %a) {
+; CHECK-LABEL: fminv_nxv16f16:
+; CHECK: ptrue p0.h
+; CHECK-NEXT: fminnm z0.h, p0/m, z0.h, z1.h
+; CHECK-NEXT: fminnmv h0, p0, z0.h
+; CHECK-NEXT: ret
+  %res = call half @llvm.vector.reduce.fmin.nxv16f16(<vscale x 16 x half> %a)
+  ret half %res
+}
+
+declare double @llvm.vector.reduce.fadd.nxv8f64(double, <vscale x 8 x double>)
+declare float @llvm.vector.reduce.fadd.nxv8f32(float, <vscale x 8 x float>)
+
+declare double @llvm.vector.reduce.fmax.nxv8f64(<vscale x 8 x double>)
+
+declare half @llvm.vector.reduce.fmin.nxv16f16(<vscale x 16 x half>)
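The split tests confirm that type legalization composes with the new lowering: nxv8f64 is four times the widest legal f64 vector (nxv2f64), so the operand arrives in z1-z4 and the ordered fadda must chain through all four parts in order, while the reassociable fmax/fmin cases combine the halves pairwise with fmaxnm/fminnm before one final reduction. The same mechanism should cover other wide types; for instance, a two-way split (illustrative, not part of the committed tests) would be expected to produce two chained fadda instructions:

  define double @fadda_nxv4f64(double %init, <vscale x 4 x double> %a) {
    ; Expected after legalization: fadda over the low nxv2f64 half, then
    ; a second fadda over the high half, both governed by ptrue p0.d.
    %res = call double @llvm.vector.reduce.fadd.nxv4f64(double %init, <vscale x 4 x double> %a)
    ret double %res
  }
  declare double @llvm.vector.reduce.fadd.nxv4f64(double, <vscale x 4 x double>)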