Index: llvm/lib/Target/AArch64/AArch64InstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -1234,6 +1234,10 @@
   return Imm.isExactlyValue(+0.0);
 }]>;
 
+def fpimm_minus0 : FPImmLeaf<fAny, [{
+  return Imm.isExactlyValue(-0.0);
+}]>;
+
 def fpimm_half : FPImmLeaf<fAny, [{
   return Imm.isExactlyValue(+0.5);
 }]>;
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -279,9 +279,18 @@
 def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
 
 def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>;
-def AArch64clasta_n   : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
-def AArch64clastb_n   : SDNode<"AArch64ISD::CLASTB_N",   SDT_AArch64ReduceWithInit>;
-def AArch64fadda_p    : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
+def AArch64clasta_n     : SDNode<"AArch64ISD::CLASTA_N",   SDT_AArch64ReduceWithInit>;
+def AArch64clastb_n     : SDNode<"AArch64ISD::CLASTB_N",   SDT_AArch64ReduceWithInit>;
+def AArch64fadda_p_node : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>;
+
+def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
+    [(AArch64fadda_p_node node:$op1, node:$op2, node:$op3),
+     (AArch64fadda_p_node (nxv2i1 (SVEAllActive)), node:$op2,
+             (vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
+     (AArch64fadda_p_node (nxv4i1 (SVEAllActive)), node:$op2,
+             (vselect node:$op1, node:$op3, (splat_vector (f32 fpimm_minus0)))),
+     (AArch64fadda_p_node (nxv2i1 (SVEAllActive)), node:$op2,
+             (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
 
 def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
 def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
Index: llvm/test/CodeGen/AArch64/sve-fadda-select.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-fadda-select.ll
@@ -0,0 +1,112 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Fold fadda(ptrue, x, select(mask, y, -0.0)) -> fadda(mask, x, y)
+
+define float @pred_fadda_nxv2f32(float %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x float> %i, <vscale x 2 x float> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %y, <vscale x 2 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv2f32(float %x, <vscale x 2 x float> %sel)
+  ret float %fadda
+}
+
+define float @pred_fadda_nxv4f32(float %x, <vscale x 4 x float> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $s0 killed $s0 def $z0
+; CHECK-NEXT:    fadda s0, p0, s0, z1.s
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x float> poison, float -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x float> %i, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %y, <vscale x 4 x float> %minus0
+  %fadda = call float @llvm.vector.reduce.fadd.nxv4f32(float %x, <vscale x 4 x float> %sel)
+  ret float %fadda
+}
+
+define double @pred_fadda_nxv2f64(double %x, <vscale x 2 x double> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    fadda d0, p0, d0, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x double> poison, double -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x double> %i, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x double> %y, <vscale x 2 x double> %minus0
+  %fadda = call double @llvm.vector.reduce.fadd.nxv2f64(double %x, <vscale x 2 x double> %sel)
+  ret double %fadda
+}
+
+; Currently the folding doesn't work for f16 element types, since -0.0 is not treated as a legal f16 immediate.
+
+define half @pred_fadda_nxv2f16(half %x, <vscale x 2 x half> %y, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI3_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI3_0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.d }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 2 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 2 x half> %i, <vscale x 2 x half> poison, <vscale x 2 x i32> zeroinitializer
+  %sel = select <vscale x 2 x i1> %mask, <vscale x 2 x half> %y, <vscale x 2 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv2f16(half %x, <vscale x 2 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv4f16(half %x, <vscale x 4 x half> %y, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI4_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI4_0
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.s }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 4 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 4 x half> %i, <vscale x 4 x half> poison, <vscale x 4 x i32> zeroinitializer
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x half> %y, <vscale x 4 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv4f16(half %x, <vscale x 4 x half> %sel)
+  ret half %fadda
+}
+
+define half @pred_fadda_nxv8f16(half %x, <vscale x 8 x half> %y, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: pred_fadda_nxv8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI5_0
+; CHECK-NEXT:    add x8, x8, :lo12:.LCPI5_0
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 def $z0
+; CHECK-NEXT:    ld1rh { z2.h }, p1/z, [x8]
+; CHECK-NEXT:    sel z1.h, p0, z1.h, z2.h
+; CHECK-NEXT:    fadda h0, p1, h0, z1.h
+; CHECK-NEXT:    // kill: def $h0 killed $h0 killed $z0
+; CHECK-NEXT:    ret
+  %i = insertelement <vscale x 8 x half> poison, half -0.000000e+00, i32 0
+  %minus0 = shufflevector <vscale x 8 x half> %i, <vscale x 8 x half> poison, <vscale x 8 x i32> zeroinitializer
+  %sel = select <vscale x 8 x i1> %mask, <vscale x 8 x half> %y, <vscale x 8 x half> %minus0
+  %fadda = call half @llvm.vector.reduce.fadd.nxv8f16(half %x, <vscale x 8 x half> %sel)
+  ret half %fadda
+}
+
+declare float @llvm.vector.reduce.fadd.nxv2f32(float, <vscale x 2 x float>)
+declare float @llvm.vector.reduce.fadd.nxv4f32(float, <vscale x 4 x float>)
+declare double @llvm.vector.reduce.fadd.nxv2f64(double, <vscale x 2 x double>)
+declare half @llvm.vector.reduce.fadd.nxv2f16(half, <vscale x 2 x half>)
+declare half @llvm.vector.reduce.fadd.nxv4f16(half, <vscale x 4 x half>)
+declare half @llvm.vector.reduce.fadd.nxv8f16(half, <vscale x 8 x half>)
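
Note (not part of the patch): the reason the PatFrags alternatives match a -0.0 splat rather than a +0.0 splat is that -0.0 is the identity element of a strictly-ordered IEEE-754 fadd. A worked check, under round-to-nearest:

  x + (-0.0) == x    for every x; in particular (+0.0) + (-0.0) == +0.0
                     and (-0.0) + (-0.0) == -0.0
  x + (+0.0) != x    when x == -0.0, since (-0.0) + (+0.0) == +0.0

So substituting -0.0 into the lanes where the mask is false leaves the ordered reduction's result bit-for-bit unchanged, which makes fadda(ptrue, x, select(mask, y, splat(-0.0))) == fadda(mask, x, y) a sound rewrite. It also explains the f16 cases left unfolded above: the PatFrags only lists f32 and f64 fpimm_minus0 splats, so for f16 the -0.0 splat is materialized through a constant-pool load (ld1rh) and the select survives into the generated code.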