diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17196,12 +17196,60 @@
                      DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
 }
 
+// Try to swap the arms of a VSELECT by inverting its SETCC condition:
+//   (vselect (setcc (a) (x)) (a) (op ...))
+//     => (vselect (not setcc (a) (x)) (op ...) (a))
+// where op is FADD, FSUB or FMUL. Only the opcode of the false arm is
+// checked; its operands and the second SETCC operand are not inspected.
+// Returns None when the pattern does not match.
+static Optional<SDValue> tryInvertVSelectWithSetCC(SDNode *N,
+                                                   SelectionDAG &DAG) {
+  SDValue SetCC = N->getOperand(0);
+  // Check the opcode before reading any operand of the condition: a
+  // non-SETCC node is not guaranteed to have operands.
+  if (SetCC.getOpcode() != ISD::SETCC)
+    return None;
+
+  SDValue SetCCOp0 = SetCC.getOperand(0);
+  SDValue NOp1 = N->getOperand(1);
+  SDValue NOp2 = N->getOperand(2);
+
+  // The true arm must be the compared value; reject the degenerate case
+  // where the false arm is that same value as well.
+  if (SetCCOp0 != NOp1 || SetCCOp0 == NOp2)
+    return None;
+
+  // Only invert when the false arm is one of the supported FP operations.
+  switch (NOp2.getOpcode()) {
+  default:
+    return None;
+  case ISD::FMUL:
+  case ISD::FSUB:
+  case ISD::FADD:
+    break;
+  }
+
+  // getSetCCInverse picks the integer or floating-point inverse from the
+  // type of the compared operands, so pass the operand type rather than the
+  // predicate type produced by the SETCC.
+  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
+  SDValue InverseSetCC = DAG.getSetCC(
+      SDLoc(SetCC), SetCC.getValueType(), SetCCOp0, SetCC.getOperand(1),
+      ISD::getSetCCInverse(CC, SetCCOp0.getValueType()));
+
+  return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
+                     {InverseSetCC, NOp2, NOp1});
+}
+
 // vselect (v1i1 setcc) ->
 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
 // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
 // such VSELECT.
 static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
+  if (auto InvertResult = tryInvertVSelectWithSetCC(N, DAG))
+    return InvertResult.getValue();
+
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
--- a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
@@ -95,14 +95,13 @@
 ; CHECK-NEXT:    frsqrte z1.h, z0.h
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
-; CHECK-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
+; CHECK-NEXT:    fcmne p0.h, p0/z, z0.h, #0.0
 ; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
 ; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
 ; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
 ; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
 ; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    fmul z1.h, z0.h, z1.h
-; CHECK-NEXT:    sel z0.h, p0, z0.h, z1.h
+; CHECK-NEXT:    fmul z0.h, p0/m, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
   ret <vscale x 8 x half> %fsqrt
@@ -124,14 +123,13 @@
 ; CHECK-NEXT:    frsqrte z1.s, z0.s
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
-; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
 ; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
 ; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
 ; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
 ; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
 ; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
-; CHECK-NEXT:    fmul z1.s, z0.s, z1.s
-; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
   ret <vscale x 4 x float> %fsqrt
@@ -153,7 +151,7 @@
 ; CHECK-NEXT:    frsqrte z1.d, z0.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
-; CHECK-NEXT:    fcmeq p0.d, p0/z, z0.d, #0.0
+; CHECK-NEXT:    fcmne p0.d, p0/z, z0.d, #0.0
 ; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
 ; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
@@ -162,8 +160,7 @@
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
 ; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
 ; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
-; CHECK-NEXT:    fmul z1.d, z0.d, z1.d
-; CHECK-NEXT:    sel z0.d, p0, z0.d, z1.d
+; CHECK-NEXT:    fmul z0.d, p0/m, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
   ret <vscale x 2 x double> %fsqrt
diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll
--- a/llvm/test/CodeGen/AArch64/sve-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-select.ll
@@ -542,3 +542,88 @@
   %sel = select i1 %mask, %a, %b
   ret %sel
 }
+
+; select(a == 0, a, a * b): the compare is inverted so the fmul is predicated.
+define <vscale x 4 x float> @fcmp_select_f32_invert_fmul(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; CHECK-LABEL: fcmp_select_f32_invert_fmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul = fmul <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fmul
+  ret <vscale x 4 x float> %sel
+}
+
+; select(a == 0, a, a + b): the compare is inverted so the fadd is predicated.
+define <vscale x 4 x float> @fcmp_select_f32_invert_fadd(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fcmp_select_f32_invert_fadd:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fadd = fadd <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fadd
+  ret <vscale x 4 x float> %sel
+}
+
+; select(a == 0, a, a - b): the compare is inverted so the fsub is predicated.
+define <vscale x 4 x float> @fcmp_select_f32_invert_fsub(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fcmp_select_f32_invert_fsub:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fsub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fsub = fsub <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fsub
+  ret <vscale x 4 x float> %sel
+}
+
+; The fmul is already the true arm, so the compare stays fcmeq.
+define <vscale x 4 x float> @fcmp_select_f32_no_invert(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fcmp_select_f32_no_invert:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul = fmul <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %fmul, <vscale x 4 x float> %a
+  ret <vscale x 4 x float> %sel
+}
+
+; Neither arm is the compared value, so the select is kept as is.
+define <vscale x 4 x float> @fcmp_select_f32_no_invert_double_op(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x float> %d) {
+; CHECK-LABEL: fcmp_select_f32_no_invert_double_op:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmul z2.s, z2.s, z3.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul1 = fmul <vscale x 4 x float> %a, %b
+  %fmul2 = fmul <vscale x 4 x float> %c, %d
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %fmul1, <vscale x 4 x float> %fmul2
+  ret <vscale x 4 x float> %sel
+}
+
+; Both arms are the same value, so only the fmul remains.
+define <vscale x 4 x float> @fcmp_select_f32_double_op(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: fcmp_select_f32_double_op:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %m = fmul <vscale x 4 x float> %a, %b
+  %p = fcmp oeq <vscale x 4 x float> %m, zeroinitializer
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %m, <vscale x 4 x float> %m
+  ret <vscale x 4 x float> %sel
+}