diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -146,7 +146,7 @@
     v64f16  =  86,   //  64 x f16
     v128f16 =  87,   // 128 x f16
     v256f16 =  88,   // 256 x f16
-    v512f16 =  89,   // 256 x f16
+    v512f16 =  89,   // 512 x f16
 
     v2bf16  =  90,   //   2 x bf16
     v3bf16  =  91,   //   3 x bf16
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1380,6 +1380,10 @@
     setOperationAction(ISD::MUL, MVT::v1i64, Custom);
     setOperationAction(ISD::MUL, MVT::v2i64, Custom);
 
+    // NEON doesn't support across-vector reductions, but SVE does.
+    for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v2f64})
+      setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+
     // NOTE: Currently this has to happen after computeRegisterProperties rather
     // than the preferred option of combining it with the addRegisterClass call.
     if (Subtarget->useSVEForFixedLengthVectors()) {
@@ -1433,10 +1437,6 @@
       setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
     }
 
-    // FP operations with no NEON support.
-    for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32,
-                    MVT::v1f64, MVT::v2f64})
-      setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
     // Use SVE for vectors with more than 2 elements.
     for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
@@ -13,14 +13,11 @@
 define half @fadda_v4f16(half %start, <4 x half> %a) vscale_range(1,0) #0 {
 ; CHECK-LABEL: fadda_v4f16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov h2, v1.h[1]
-; CHECK-NEXT:    fadd h0, h0, h1
-; CHECK-NEXT:    mov h3, v1.h[2]
-; CHECK-NEXT:    mov h1, v1.h[3]
-; CHECK-NEXT:    fadd h0, h0, h2
-; CHECK-NEXT:    fadd h0, h0, h3
-; CHECK-NEXT:    fadd h0, h0, h1
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ptrue p0.h, vl4
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    fadda h0, p0, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call half @llvm.vector.reduce.fadd.v4f16(half %start, <4 x half> %a)
   ret half %res
@@ -30,21 +27,11 @@
 define half @fadda_v8f16(half %start, <8 x half> %a) vscale_range(1,0) #0 {
 ; CHECK-LABEL: fadda_v8f16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov h2, v1.h[1]
-; CHECK-NEXT:    fadd h0, h0, h1
-; CHECK-NEXT:    mov h3, v1.h[2]
-; CHECK-NEXT:    fadd h0, h0, h2
-; CHECK-NEXT:    mov h2, v1.h[3]
-; CHECK-NEXT:    fadd h0, h0, h3
-; CHECK-NEXT:    mov h3, v1.h[4]
-; CHECK-NEXT:    fadd h0, h0, h2
-; CHECK-NEXT:    mov h2, v1.h[5]
-; CHECK-NEXT:    fadd h0, h0, h3
-; CHECK-NEXT:    mov h3, v1.h[6]
-; CHECK-NEXT:    mov h1, v1.h[7]
-; CHECK-NEXT:    fadd h0, h0, h2
-; CHECK-NEXT:    fadd h0, h0, h3
-; CHECK-NEXT:    fadd h0, h0, h1
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ptrue p0.h, vl8
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fadda h0, p0, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call half @llvm.vector.reduce.fadd.v8f16(half %start, <8 x half> %a)
   ret half %res
@@ -122,10 +109,11 @@
 define float @fadda_v2f32(float %start, <2 x float> %a) vscale_range(1,0) #0 {
 ; CHECK-LABEL: fadda_v2f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    fadd s0, s0, s2
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    ptrue p0.s, vl2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call float @llvm.vector.reduce.fadd.v2f32(float %start, <2 x float> %a)
   ret float %res
@@ -135,13 +123,11 @@
 define float @fadda_v4f32(float %start, <4 x float> %a) vscale_range(1,0) #0 {
 ; CHECK-LABEL: fadda_v4f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov s2, v1.s[1]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s3, v1.s[2]
-; CHECK-NEXT:    mov s1, v1.s[3]
-; CHECK-NEXT:    fadd s0, s0, s2
-; CHECK-NEXT:    fadd s0, s0, s3
-; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %a)
   ret float %res
@@ -229,9 +215,11 @@
 define double @fadda_v2f64(double %start, <2 x double> %a) vscale_range(1,0) #0 {
 ; CHECK-LABEL: fadda_v2f64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov d2, v1.d[1]
-; CHECK-NEXT:    fadd d0, d0, d1
-; CHECK-NEXT:    fadd d0, d0, d2
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    fadda d0, p0, d0, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %res = call double @llvm.vector.reduce.fadd.v2f64(double %start, <2 x double> %a)
   ret double %res
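
Note for reviewers: below is a minimal standalone reproducer for the new lowering; it is not part of the patch, the file and function names are illustrative, and the llc invocation is just one plausible way to run it. Without fast-math flags, llvm.vector.reduce.fadd is a strictly ordered reduction (%start + %a[0] + %a[1] + ...), which is what ISD::VECREDUCE_SEQ_FADD models and what SVE's FADDA computes across the active lanes, so the previous extract/fadd chain collapses to a single instruction.

; repro.ll -- try: llc -mtriple=aarch64 -mattr=+sve -o - repro.ll
define float @fadda_repro(float %start, <4 x float> %a) vscale_range(1,0) {
  ; No reassoc flag, so this is an ordered reduction: it selects
  ; ISD::VECREDUCE_SEQ_FADD and, with this patch, lowers to FADDA
  ; rather than a scalar fadd chain.
  %res = call float @llvm.vector.reduce.fadd.v4f32(float %start, <4 x float> %a)
  ret float %res
}
declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)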