diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1380,6 +1380,11 @@ setOperationAction(ISD::MUL, MVT::v1i64, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); + // NEON doesn't support ordered (sequential) FP reductions, but SVE's FADDA does. + for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64, + MVT::v2f64}) + setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + // NOTE: Currently this has to happen after computeRegisterProperties rather // than the preferred option of combining it with the addRegisterClass call. if (Subtarget->useSVEForFixedLengthVectors()) { @@ -1433,10 +1438,6 @@ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); } - // FP operations with no NEON support. - for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, - MVT::v1f64, MVT::v2f64}) - setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); // Use SVE for vectors with more than 2 elements. 
for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32}) diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fadd-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fadd-reduce.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fadd-reduce.ll @@ -0,0 +1,88 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc -mattr=+sve -mtriple=aarch64 %s -o - | FileCheck %s --check-prefixes=CHECK + +define half @pred_fadda_v4f16(half %x, <4 x half> %y) { +; CHECK-LABEL: pred_fadda_v4f16: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0 +; CHECK-NEXT: ptrue p0.h, vl4 +; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1 +; CHECK-NEXT: fadda h0, p0, h0, z1.h +; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0 +; CHECK-NEXT: ret + %fadda = call half @llvm.vector.reduce.fadd.v4f16(half %x, <4 x half> %y) + ret half %fadda +} + +define half @pred_fadda_v8f16(half %x, <8 x half> %y) { +; CHECK-LABEL: pred_fadda_v8f16: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0 +; CHECK-NEXT: ptrue p0.h, vl8 +; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1 +; CHECK-NEXT: fadda h0, p0, h0, z1.h +; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0 +; CHECK-NEXT: ret + %fadda = call half @llvm.vector.reduce.fadd.v8f16(half %x, <8 x half> %y) + ret half %fadda +} + +define float @pred_fadda_v2f32(float %x, <2 x float> %y) { +; CHECK-LABEL: pred_fadda_v2f32: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0 +; CHECK-NEXT: ptrue p0.s, vl2 +; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1 +; CHECK-NEXT: fadda s0, p0, s0, z1.s +; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0 +; CHECK-NEXT: ret + %fadda = call float @llvm.vector.reduce.fadd.v2f32(float %x, <2 x float> %y) + ret float %fadda +} + +define float @pred_fadda_v4f32(float %x, <4 x float> %y) { +; CHECK-LABEL: pred_fadda_v4f32: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $s0 killed $s0 def 
$z0 +; CHECK-NEXT: ptrue p0.s, vl4 +; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1 +; CHECK-NEXT: fadda s0, p0, s0, z1.s +; CHECK-NEXT: // kill: def $s0 killed $s0 killed $z0 +; CHECK-NEXT: ret + %fadda = call float @llvm.vector.reduce.fadd.v4f32(float %x, <4 x float> %y) + ret float %fadda +} + +define double @pred_fadda_v1f64(double %x, <1 x double> %y) { +; CHECK-LABEL: pred_fadda_v1f64: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0 +; CHECK-NEXT: ptrue p0.d, vl1 +; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1 +; CHECK-NEXT: fadda d0, p0, d0, z1.d +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0 +; CHECK-NEXT: ret + %fadda = call double @llvm.vector.reduce.fadd.v1f64(double %x, <1 x double> %y) + ret double %fadda +} + +define double @pred_fadda_v2f64(double %x, <2 x double> %y) { +; CHECK-LABEL: pred_fadda_v2f64: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0 +; CHECK-NEXT: ptrue p0.d, vl2 +; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1 +; CHECK-NEXT: fadda d0, p0, d0, z1.d +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0 +; CHECK-NEXT: ret + %fadda = call double @llvm.vector.reduce.fadd.v2f64(double %x, <2 x double> %y) + ret double %fadda +} + + +declare half @llvm.vector.reduce.fadd.v4f16(half, <4 x half>) +declare half @llvm.vector.reduce.fadd.v8f16(half, <8 x half>) +declare float @llvm.vector.reduce.fadd.v2f32(float, <2 x float>) +declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>) +declare double @llvm.vector.reduce.fadd.v1f64(double, <1 x double>) +declare double @llvm.vector.reduce.fadd.v2f64(double, <2 x double>)