Index: llvm/lib/Target/AArch64/AArch64InstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -1234,6 +1234,10 @@
   return Imm.isExactlyValue(+0.0);
 }]>;
 
+def fpimm_minus0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(-0.0);
+}]>;
+
 def fpimm_half : FPImmLeaf<fAny, [{
   return Imm.isExactlyValue(+0.5);
 }]>;
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -466,6 +466,13 @@
             (inst $Op1, $Op2, vt3:$Op3)>;
 }
 
+// Matches op(ptrue, x, select(pg, y, splat(-0.0))) and folds the select away
+// by using pg as the operation's governing predicate (inactive lanes of a
+// strictly-ordered FP reduction contribute -0.0, the FADD identity).
+class SVE_3_Op_Pat_Pred_FPSelMinusZero<ValueType vtd, SDPatternOperator op, ValueType vt1,
+                                       ValueType vt2, ValueType vt3,
+                                       ValueType scalarty, Instruction inst>
+: Pat<(vtd (op (vt1 (SVEAllActive)), vt2:$Op2, (vselect vt1:$Pg, vt3:$Op3,
+               (vt3 (splat_vector (scalarty fpimm_minus0)))))),
+      (inst $Pg, $Op2, $Op3)>;
+
 //
 // Common but less generic patterns.
 //
@@ -5118,6 +5125,14 @@
   def : SVE_3_Op_Pat<f32, op, nxv2i1, f32, nxv2f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_3_Op_Pat<f32, op, nxv4i1, f32, nxv4f32, !cast<Instruction>(NAME # _S)>;
   def : SVE_3_Op_Pat<f64, op, nxv2i1, f64, nxv2f64, !cast<Instruction>(NAME # _D)>;
+
+  def : SVE_3_Op_Pat_Pred_FPSelMinusZero<f32, op, nxv2i1, f32, nxv2f32, f32,
+                                         !cast<Instruction>(NAME # _S)>;
+  def : SVE_3_Op_Pat_Pred_FPSelMinusZero<f32, op, nxv4i1, f32, nxv4f32, f32,
+                                         !cast<Instruction>(NAME # _S)>;
+  def : SVE_3_Op_Pat_Pred_FPSelMinusZero<f64, op, nxv2i1, f64, nxv2f64, f64,
+                                         !cast<Instruction>(NAME # _D)>;
+
 }
 
 //===----------------------------------------------------------------------===//
Index: llvm/test/CodeGen/AArch64/sve-fadda-select.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-fadda-select.ll
@@ -0,0 +1,112 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)
+
+define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
+  ret float %fadda
+}
+
+define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
+  ret float %fadda
+}
+
+define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    fadda d0, p0, d0, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
+  %fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
+  ret double %fadda
+}
+
+; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.
+
+define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI3_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI3_0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.d }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI4_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI4_0
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.s }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI5_0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.h }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.h, p0, z1.h, z2.h
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
+  %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
+  ret half %fadda
+}
+
+declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
+declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
+declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
+declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
+declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
+declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)