diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17196,12 +17196,49 @@
                      DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
 }
 
+// (vselect (setcc (a) (n)) (a) (op (a) (b)))
+// => (vselect (not setcc (a) (n)) (op (a) (b)) (a))
+static Optional<SDValue> tryInvertVSelectWithSetCC(SDNode *N,
+                                                   SelectionDAG &DAG) {
+  auto SelectA = N->getOperand(1);
+  auto SelectB = N->getOperand(2);
+  auto NTy = N->getValueType(0);
+
+  SDValue SetCC = N->getOperand(0);
+  if (SetCC.getOpcode() != ISD::SETCC || !NTy.isScalableVector() ||
+      !SetCC.hasOneUse())
+    return None;
+  auto SetCCOp0 = SetCC.getOperand(0);
+
+  switch (SelectB.getOpcode()) {
+  default:
+    return None;
+  case ISD::FMUL:
+  case ISD::FSUB:
+  case ISD::FADD:
+    break;
+  }
+  if (SelectA != SelectB.getOperand(0))
+    return None;
+
+  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
+  auto InverseSetCC = DAG.getSetCC(
+      SDLoc(SetCC), SetCC.getValueType(), SetCCOp0, SetCC.getOperand(1),
+      ISD::getSetCCInverse(CC, SetCC.getValueType()));
+
+  return DAG.getNode(ISD::VSELECT, SDLoc(N), NTy,
+                     {InverseSetCC, SelectB, SelectA});
+}
+
 // vselect (v1i1 setcc) ->
 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
 // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
 // such VSELECT.
 static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
+  if (auto InvertResult = tryInvertVSelectWithSetCC(N, DAG))
+    return InvertResult.getValue();
+
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
--- a/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fp-reciprocal.ll
@@ -95,14 +95,13 @@
 ; CHECK-NEXT:    frsqrte z1.h, z0.h
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
-; CHECK-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
+; CHECK-NEXT:    fcmne p0.h, p0/z, z0.h, #0.0
 ; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
 ; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
 ; CHECK-NEXT:    fmul z2.h, z1.h, z1.h
 ; CHECK-NEXT:    frsqrts z2.h, z0.h, z2.h
 ; CHECK-NEXT:    fmul z1.h, z1.h, z2.h
-; CHECK-NEXT:    fmul z1.h, z0.h, z1.h
-; CHECK-NEXT:    sel z0.h, p0, z0.h, z1.h
+; CHECK-NEXT:    fmul z0.h, p0/m, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 8 x half> @llvm.sqrt.nxv8f16(<vscale x 8 x half> %a)
   ret <vscale x 8 x half> %fsqrt
@@ -124,14 +123,13 @@
 ; CHECK-NEXT:    frsqrte z1.s, z0.s
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
-; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
 ; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
 ; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
 ; CHECK-NEXT:    fmul z2.s, z1.s, z1.s
 ; CHECK-NEXT:    frsqrts z2.s, z0.s, z2.s
 ; CHECK-NEXT:    fmul z1.s, z1.s, z2.s
-; CHECK-NEXT:    fmul z1.s, z0.s, z1.s
-; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 4 x float> @llvm.sqrt.nxv4f32(<vscale x 4 x float> %a)
   ret <vscale x 4 x float> %fsqrt
@@ -153,7 +151,7 @@
 ; CHECK-NEXT:    frsqrte z1.d, z0.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
-; CHECK-NEXT:    fcmeq p0.d, p0/z, z0.d, #0.0
+; CHECK-NEXT:    fcmne p0.d, p0/z, z0.d, #0.0
 ; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
 ; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
@@ -162,8 +160,7 @@
 ; CHECK-NEXT:    fmul z2.d, z1.d, z1.d
 ; CHECK-NEXT:    frsqrts z2.d, z0.d, z2.d
 ; CHECK-NEXT:    fmul z1.d, z1.d, z2.d
-; CHECK-NEXT:    fmul z1.d, z0.d, z1.d
-; CHECK-NEXT:    sel z0.d, p0, z0.d, z1.d
+; CHECK-NEXT:    fmul z0.d, p0/m, z0.d, z1.d
 ; CHECK-NEXT:    ret
   %fsqrt = call fast <vscale x 2 x double> @llvm.sqrt.nxv2f64(<vscale x 2 x double> %a)
   ret <vscale x 2 x double> %fsqrt
diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll
--- a/llvm/test/CodeGen/AArch64/sve-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-select.ll
@@ -542,3 +542,111 @@
   %sel = select i1 %mask, <vscale x 2 x i64> %a, <vscale x 2 x i64> %b
   ret <vscale x 2 x i64> %sel
 }
+
+define <vscale x 4 x float> @select_f32_invert_fmul(<vscale x 4 x float> %a, <vscale x 4 x float> %b) #0 {
+; CHECK-LABEL: select_f32_invert_fmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul = fmul <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fmul
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_invert_fadd(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: select_f32_invert_fadd:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fadd = fadd <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fadd
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_invert_fsub(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: select_f32_invert_fsub:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmne p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fsub z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fsub = fsub <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fsub
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_no_invert_op_lhs(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: select_f32_no_invert_op_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul = fmul <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %fmul, <vscale x 4 x float> %a
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_no_invert_2_op(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, <vscale x 4 x float> %d) {
+; CHECK-LABEL: select_f32_no_invert_2_op:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmul z2.s, z2.s, z3.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fmul1 = fmul <vscale x 4 x float> %a, %b
+  %fmul2 = fmul <vscale x 4 x float> %c, %d
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %fmul1, <vscale x 4 x float> %fmul2
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_no_invert_equal_ops(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: select_f32_no_invert_equal_ops:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %m = fmul <vscale x 4 x float> %a, %b
+  %p = fcmp oeq <vscale x 4 x float> %m, zeroinitializer
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %m, <vscale x 4 x float> %m
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @select_f32_no_invert_fmul_two_setcc_uses(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x float> %c, i32 %len) #0 {
+; CHECK-LABEL: select_f32_no_invert_fmul_two_setcc_uses:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fadd z1.s, z0.s, z1.s
+; CHECK-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    mov z0.s, p0/m, z2.s
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <vscale x 4 x float> %a, zeroinitializer
+  %fadd = fadd <vscale x 4 x float> %a, %b
+  %sel = select <vscale x 4 x i1> %p, <vscale x 4 x float> %a, <vscale x 4 x float> %fadd
+  %sel2 = select <vscale x 4 x i1> %p, <vscale x 4 x float> %c, <vscale x 4 x float> %sel
+  ret <vscale x 4 x float> %sel2
+}
+
+define <4 x float> @select_f32_no_invert_not_scalable(<4 x float> %a, <4 x float> %b) #0 {
+; CHECK-LABEL: select_f32_no_invert_not_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fcmeq v2.4s, v0.4s, #0.0
+; CHECK-NEXT:    fmul v1.4s, v0.4s, v1.4s
+; CHECK-NEXT:    bif v0.16b, v1.16b, v2.16b
+; CHECK-NEXT:    ret
+  %p = fcmp oeq <4 x float> %a, zeroinitializer
+  %fmul = fmul <4 x float> %a, %b
+  %sel = select <4 x i1> %p, <4 x float> %a, <4 x float> %fmul
+  ret <4 x float> %sel
+}