diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -17196,12 +17196,42 @@ DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3)); } +// (vselect (p) (a) (op (a) (b))) => (vselect (!p) (op (a) (b)) (a)) +static Optional<SDValue> tryInvertVSelectWithSetCC(SDNode *N, +                                                   SelectionDAG &DAG) { + SDValue SetCC = N->getOperand(0); + if (SetCC.getOpcode() != ISD::SETCC || + SetCC.getOperand(0) != N->getOperand(1)) + return None; + + auto Opcode = N->getOperand(2).getOpcode(); + switch (Opcode) { + default: + return None; + case ISD::FMUL: + case ISD::FSUB: + case ISD::FADD: + break; + } + + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get(); + auto InverseSetCC = DAG.getSetCC( + SDLoc(SetCC), SetCC.getValueType(), SetCC.getOperand(0), + SetCC.getOperand(1), ISD::getSetCCInverse(CC, SetCC.getValueType())); + + return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0), + {InverseSetCC, N->getOperand(2), N->getOperand(1)}); +} + // vselect (v1i1 setcc) -> // vselect (v1iXX setcc) (XX is the size of the compared operand type) // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine // such VSELECT. 
static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) { + if (auto InvertResult = tryInvertVSelectWithSetCC(N, DAG)) + return InvertResult.getValue(); + SDValue N0 = N->getOperand(0); EVT CCVT = N0.getValueType(); diff --git a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll --- a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll +++ b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll @@ -95,14 +95,13 @@ ; CHECK-NEXT: frsqrte z1.h, z0.h ; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: fmul z2.h, z1.h, z1.h -; CHECK-NEXT: fcmeq p0.h, p0/z, z0.h, #0.0 +; CHECK-NEXT: fcmne p0.h, p0/z, z0.h, #0.0 ; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h ; CHECK-NEXT: fmul z1.h, z1.h, z2.h ; CHECK-NEXT: fmul z2.h, z1.h, z1.h ; CHECK-NEXT: frsqrts z2.h, z0.h, z2.h ; CHECK-NEXT: fmul z1.h, z1.h, z2.h -; CHECK-NEXT: fmul z1.h, z0.h, z1.h -; CHECK-NEXT: sel z0.h, p0, z0.h, z1.h +; CHECK-NEXT: fmul z0.h, p0/m, z0.h, z1.h ; CHECK-NEXT: ret %fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a) ret <vscale x 8 x half> %fsqrt @@ -124,14 +123,13 @@ ; CHECK-NEXT: frsqrte z1.s, z0.s ; CHECK-NEXT: ptrue p0.s ; CHECK-NEXT: fmul z2.s, z1.s, z1.s -; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fcmne p0.s, p0/z, z0.s, #0.0 ; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s ; CHECK-NEXT: fmul z1.s, z1.s, z2.s ; CHECK-NEXT: fmul z2.s, z1.s, z1.s ; CHECK-NEXT: frsqrts z2.s, z0.s, z2.s ; CHECK-NEXT: fmul z1.s, z1.s, z2.s -; CHECK-NEXT: fmul z1.s, z0.s, z1.s -; CHECK-NEXT: sel z0.s, p0, z0.s, z1.s +; CHECK-NEXT: fmul z0.s, p0/m, z0.s, z1.s ; CHECK-NEXT: ret %fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a) ret <vscale x 4 x float> %fsqrt @@ -153,7 +151,7 @@ ; CHECK-NEXT: frsqrte z1.d, z0.d ; CHECK-NEXT: ptrue p0.d ; CHECK-NEXT: fmul z2.d, z1.d, z1.d -; CHECK-NEXT: fcmeq p0.d, p0/z, z0.d, #0.0 +; CHECK-NEXT: fcmne p0.d, p0/z, z0.d, #0.0 ; CHECK-NEXT: frsqrts z2.d, z0.d, z2.d ; CHECK-NEXT: fmul z1.d, z1.d, z2.d ; CHECK-NEXT: fmul z2.d, z1.d, z1.d @@ -162,8 +160,7 @@ ; CHECK-NEXT: fmul z2.d, z1.d, z1.d ; CHECK-NEXT: frsqrts 
z2.d, z0.d, z2.d ; CHECK-NEXT: fmul z1.d, z1.d, z2.d -; CHECK-NEXT: fmul z1.d, z0.d, z1.d -; CHECK-NEXT: sel z0.d, p0, z0.d, z1.d +; CHECK-NEXT: fmul z0.d, p0/m, z0.d, z1.d ; CHECK-NEXT: ret %fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a) ret <vscale x 2 x double> %fsqrt diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll --- a/llvm/test/CodeGen/AArch64/sve-select.ll +++ b/llvm/test/CodeGen/AArch64/sve-select.ll @@ -542,3 +542,71 @@ %sel = select i1 %mask, %a, %b ret %sel } + +define <vscale x 4 x float> @fcmp_select_f32_invert_fmul(<vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 { +; CHECK-LABEL: fcmp_select_f32_invert_fmul: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fcmne p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fmul z0.s, p0/m, z0.s, z1.s +; CHECK-NEXT: ret + %fcmp = fcmp oeq <vscale x 4 x float> %a, zeroinitializer + %fmul = fmul <vscale x 4 x float> %a, %b + %sel = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %a, <vscale x 4 x float> %fmul + ret <vscale x 4 x float> %sel +} + +define <vscale x 4 x float> @fcmp_select_f32_invert_fadd(<vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 { +; CHECK-LABEL: fcmp_select_f32_invert_fadd: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fcmne p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fadd z0.s, p0/m, z0.s, z1.s +; CHECK-NEXT: ret + %fcmp = fcmp oeq <vscale x 4 x float> %a, zeroinitializer + %fadd = fadd <vscale x 4 x float> %a, %b + %sel = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %a, <vscale x 4 x float> %fadd + ret <vscale x 4 x float> %sel +} + +define <vscale x 4 x float> @fcmp_select_f32_invert_fsub(<vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 { +; CHECK-LABEL: fcmp_select_f32_invert_fsub: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fcmne p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fsub z0.s, p0/m, z0.s, z1.s +; CHECK-NEXT: ret + %fcmp = fcmp oeq <vscale x 4 x float> %a, zeroinitializer + %fsub = fsub <vscale x 4 x float> %a, %b + %sel = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %a, <vscale x 4 x float> %fsub + ret <vscale x 4 x float> %sel +} + +define <vscale x 4 x float> @fcmp_select_f32_no_invert(<vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 { +; CHECK-LABEL: fcmp_select_f32_no_invert: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fmul z0.s, p0/m, z0.s, z1.s +; CHECK-NEXT: ret + %fcmp = fcmp oeq <vscale x 4 x float> %a, zeroinitializer + %fmul = fmul <vscale x 4 x float> %a, %b + %sel = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %fmul, <vscale x 4 x float> %a + ret <vscale x 4 x float> %sel +} + +define <vscale x 4 x float> @fcmp_select_f32_no_invert_double_op(<vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %b, 
<vscale x 4 x float> %c, <vscale x 4 x float> %d) #0 { +; CHECK-LABEL: fcmp_select_f32_no_invert_double_op: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fmul z2.s, z2.s, z3.s +; CHECK-NEXT: fcmeq p0.s, p0/z, z0.s, #0.0 +; CHECK-NEXT: fmul z0.s, z0.s, z1.s +; CHECK-NEXT: sel z0.s, p0, z0.s, z2.s +; CHECK-NEXT: ret + %fcmp = fcmp oeq <vscale x 4 x float> %a, zeroinitializer + %fmul1 = fmul <vscale x 4 x float> %a, %b + %fmul2 = fmul <vscale x 4 x float> %c, %d + %sel = select <vscale x 4 x i1> %fcmp, <vscale x 4 x float> %fmul1, <vscale x 4 x float> %fmul2 + ret <vscale x 4 x float> %sel +}