diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6707,13 +6707,26 @@ assert((LHS.getValueType() == RHS.getValueType()) && (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64)); + ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal); + ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal); + ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS); + // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform + // into (OR (ASR lhs, N-1), 1), which requires less instructions for the + // supported types. + if (CC == ISD::SETGT && RHSC && RHSC->isAllOnesValue() && CTVal && CFVal && + CTVal->isOne() && CFVal->isAllOnesValue() && + LHS.getValueType() == TVal.getValueType()) { + EVT VT = LHS.getValueType(); + SDValue Shift = + DAG.getNode(ISD::SRA, dl, VT, LHS, + DAG.getConstant(VT.getSizeInBits() - 1, dl, VT)); + return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT)); + } + unsigned Opcode = AArch64ISD::CSEL; // If both the TVal and the FVal are constants, see if we can swap them in // order to for a CSINV or CSINC out of them. - ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal); - ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal); - if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); @@ -6916,7 +6929,7 @@ if (CCVal.getOpcode() == ISD::SETCC) { LHS = CCVal.getOperand(0); RHS = CCVal.getOperand(1); - CC = cast<CondCodeSDNode>(CCVal->getOperand(2))->get(); + CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get(); } else { LHS = CCVal; RHS = DAG.getConstant(0, DL, CCVal.getValueType()); @@ -14970,6 +14983,39 @@ SDValue N0 = N->getOperand(0); EVT CCVT = N0.getValueType(); + // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform + // into (OR (ASR lhs, N-1), 1), which requires less instructions for the + // supported types. 
+ SDValue SetCC = N->getOperand(0); + if (SetCC.getOpcode() == ISD::SETCC && + SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) { + SDValue CmpLHS = SetCC.getOperand(0); + EVT VT = CmpLHS.getValueType(); + SDNode *CmpRHS = SetCC.getOperand(1).getNode(); + SDNode *SplatLHS = N->getOperand(1).getNode(); + SDNode *SplatRHS = N->getOperand(2).getNode(); + APInt SplatLHSVal; + if (CmpLHS.getValueType() == N->getOperand(1).getValueType() && + VT.isSimple() && + is_contained( + makeArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, + MVT::v2i32, MVT::v4i32, MVT::v2i64}), + VT.getSimpleVT().SimpleTy) && + ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) && + SplatLHSVal.isOneValue() && ISD::isConstantSplatVectorAllOnes(CmpRHS) && + ISD::isConstantSplatVectorAllOnes(SplatRHS)) { + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<SDValue> Ops( + NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N), + VT.getScalarType())); + SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops); + + auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val); + auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1)); + return Or; + } + } + if (N0.getOpcode() != ISD::SETCC || CCVT.getVectorNumElements() != 1 || CCVT.getVectorElementType() != MVT::i1) return SDValue(); @@ -14983,10 +15029,9 @@ SDValue IfTrue = N->getOperand(1); SDValue IfFalse = N->getOperand(2); - SDValue SetCC = - DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(), - N0.getOperand(0), N0.getOperand(1), - cast<CondCodeSDNode>(N0.getOperand(2))->get()); + SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(), + N0.getOperand(0), N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC, IfTrue, IfFalse); } diff --git a/llvm/test/CodeGen/AArch64/cmp-select-sign.ll b/llvm/test/CodeGen/AArch64/cmp-select-sign.ll --- a/llvm/test/CodeGen/AArch64/cmp-select-sign.ll +++ b/llvm/test/CodeGen/AArch64/cmp-select-sign.ll @@ -4,10 +4,8 @@ 
define i3 @sign_i3(i3 %a) { ; CHECK-LABEL: sign_i3: ; CHECK: // %bb.0: -; CHECK-NEXT: sbfx w8, w0, #0, #3 -; CHECK-NEXT: cmp w8, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg w0, w8, lt +; CHECK-NEXT: sbfx w8, w0, #2, #1 +; CHECK-NEXT: orr w0, w8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt i3 %a, -1 %res = select i1 %c, i3 1, i3 -1 @@ -17,10 +15,8 @@ define i4 @sign_i4(i4 %a) { ; CHECK-LABEL: sign_i4: ; CHECK: // %bb.0: -; CHECK-NEXT: sbfx w8, w0, #0, #4 -; CHECK-NEXT: cmp w8, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg w0, w8, lt +; CHECK-NEXT: sbfx w8, w0, #3, #1 +; CHECK-NEXT: orr w0, w8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt i4 %a, -1 %res = select i1 %c, i4 1, i4 -1 @@ -30,10 +26,8 @@ define i8 @sign_i8(i8 %a) { ; CHECK-LABEL: sign_i8: ; CHECK: // %bb.0: -; CHECK-NEXT: sxtb w8, w0 -; CHECK-NEXT: cmp w8, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg w0, w8, lt +; CHECK-NEXT: sbfx w8, w0, #7, #1 +; CHECK-NEXT: orr w0, w8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt i8 %a, -1 %res = select i1 %c, i8 1, i8 -1 @@ -43,10 +37,8 @@ define i16 @sign_i16(i16 %a) { ; CHECK-LABEL: sign_i16: ; CHECK: // %bb.0: -; CHECK-NEXT: sxth w8, w0 -; CHECK-NEXT: cmp w8, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg w0, w8, lt +; CHECK-NEXT: sbfx w8, w0, #15, #1 +; CHECK-NEXT: orr w0, w8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt i16 %a, -1 %res = select i1 %c, i16 1, i16 -1 @@ -56,9 +48,8 @@ define i32 @sign_i32(i32 %a) { ; CHECK-LABEL: sign_i32: ; CHECK: // %bb.0: -; CHECK-NEXT: cmp w0, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg w0, w8, lt +; CHECK-NEXT: asr w8, w0, #31 +; CHECK-NEXT: orr w0, w8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt i32 %a, -1 %res = select i1 %c, i32 1, i32 -1 @@ -68,9 +59,8 @@ define i64 @sign_i64(i64 %a) { ; CHECK-LABEL: sign_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: cmp x0, #0 // =0 -; CHECK-NEXT: mov w8, #1 -; CHECK-NEXT: cneg x0, x8, lt +; CHECK-NEXT: asr x8, x0, #63 +; CHECK-NEXT: orr x0, x8, #0x1 ; CHECK-NEXT: ret %c = icmp sgt 
i64 %a, -1 %res = select i1 %c, i64 1, i64 -1 @@ -124,11 +114,9 @@ define <7 x i8> @sign_7xi8(<7 x i8> %a) { ; CHECK-LABEL: sign_7xi8: ; CHECK: // %bb.0: -; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: cmgt v0.8b, v0.8b, v1.8b +; CHECK-NEXT: sshr v0.8b, v0.8b, #7 ; CHECK-NEXT: movi v1.8b, #1 -; CHECK-NEXT: and v1.8b, v0.8b, v1.8b -; CHECK-NEXT: orn v0.8b, v1.8b, v0.8b +; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b ; CHECK-NEXT: ret %c = icmp sgt <7 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> %res = select <7 x i1> %c, <7 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <7 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> @@ -138,11 +126,9 @@ define <8 x i8> @sign_8xi8(<8 x i8> %a) { ; CHECK-LABEL: sign_8xi8: ; CHECK: // %bb.0: -; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: cmgt v0.8b, v0.8b, v1.8b +; CHECK-NEXT: sshr v0.8b, v0.8b, #7 ; CHECK-NEXT: movi v1.8b, #1 -; CHECK-NEXT: and v1.8b, v0.8b, v1.8b -; CHECK-NEXT: orn v0.8b, v1.8b, v0.8b +; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b ; CHECK-NEXT: ret %c = icmp sgt <8 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> %res = select <8 x i1> %c, <8 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <8 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> @@ -152,11 +138,9 @@ define <16 x i8> @sign_16xi8(<16 x i8> %a) { ; CHECK-LABEL: sign_16xi8: ; CHECK: // %bb.0: -; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: cmgt v0.16b, v0.16b, v1.16b +; CHECK-NEXT: sshr v0.16b, v0.16b, #7 ; CHECK-NEXT: movi v1.16b, #1 -; CHECK-NEXT: and v1.16b, v0.16b, v1.16b -; CHECK-NEXT: orn v0.16b, v1.16b, v0.16b +; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b ; CHECK-NEXT: ret %c = icmp sgt <16 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> %res = select <16 x i1> %c, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <16 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1> @@ -166,11 +150,8 @@ define <3 x i32> @sign_3xi32(<3 x i32> %a) { ; CHECK-LABEL: sign_3xi32: ; CHECK: // %bb.0: -; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: cmgt v0.4s, v0.4s, v1.4s -; CHECK-NEXT: movi v1.4s, #1 -; CHECK-NEXT: and v1.16b, v0.16b, v1.16b -; CHECK-NEXT: orn v0.16b, v1.16b, v0.16b +; CHECK-NEXT: sshr v0.4s, v0.4s, #31 +; CHECK-NEXT: orr v0.4s, #1 ; CHECK-NEXT: ret %c = icmp sgt <3 x i32> %a, <i32 -1, i32 -1, i32 -1> %res = select <3 x i1> %c, <3 x i32> <i32 1, i32 1, i32 1>, <3 x i32> <i32 -1, i32 -1, i32 -1> @@ -180,11 +161,8 @@ define <4 x i32> 
@sign_4xi32(<4 x i32> %a) { ; CHECK-LABEL: sign_4xi32: ; CHECK: // %bb.0: -; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: cmgt v0.4s, v0.4s, v1.4s -; CHECK-NEXT: movi v1.4s, #1 -; CHECK-NEXT: and v1.16b, v0.16b, v1.16b -; CHECK-NEXT: orn v0.16b, v1.16b, v0.16b +; CHECK-NEXT: sshr v0.4s, v0.4s, #31 +; CHECK-NEXT: orr v0.4s, #1 ; CHECK-NEXT: ret %c = icmp sgt <4 x i32> %a, <i32 -1, i32 -1, i32 -1, i32 -1> %res = select <4 x i1> %c, <4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1> @@ -199,12 +177,11 @@ ; CHECK-NEXT: .cfi_def_cfa_offset 32 ; CHECK-NEXT: .cfi_offset w30, -16 ; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff -; CHECK-NEXT: movi v2.4s, #1 +; CHECK-NEXT: sshr v2.4s, v0.4s, #31 ; CHECK-NEXT: cmgt v0.4s, v0.4s, v1.4s -; CHECK-NEXT: and v1.16b, v0.16b, v2.16b -; CHECK-NEXT: orn v1.16b, v1.16b, v0.16b +; CHECK-NEXT: orr v2.4s, #1 ; CHECK-NEXT: xtn v0.4h, v0.4s -; CHECK-NEXT: str q1, [sp] // 16-byte Folded Spill +; CHECK-NEXT: str q2, [sp] // 16-byte Folded Spill ; CHECK-NEXT: bl use_4xi1 ; CHECK-NEXT: ldr q0, [sp] // 16-byte Folded Reload ; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload @@ -268,25 +245,20 @@ define <4 x i65> @sign_4xi65(<4 x i65> %a) { ; CHECK-LABEL: sign_4xi65: ; CHECK: // %bb.0: -; CHECK-NEXT: sbfx x11, x3, #0, #1 -; CHECK-NEXT: sbfx x10, x5, #0, #1 -; CHECK-NEXT: mov w12, #1 -; CHECK-NEXT: cmp x11, #0 // =0 -; CHECK-NEXT: sbfx x9, x7, #0, #1 -; CHECK-NEXT: cneg x2, x12, lt -; CHECK-NEXT: cmp x10, #0 // =0 ; CHECK-NEXT: sbfx x8, x1, #0, #1 -; CHECK-NEXT: cneg x4, x12, lt -; CHECK-NEXT: cmp x9, #0 // =0 -; CHECK-NEXT: cneg x6, x12, lt -; CHECK-NEXT: cmp x8, #0 // =0 -; CHECK-NEXT: lsr x5, x10, #63 -; CHECK-NEXT: cneg x10, x12, lt +; CHECK-NEXT: sbfx x9, x7, #0, #1 +; CHECK-NEXT: orr x6, x9, #0x1 +; CHECK-NEXT: lsr x7, x9, #63 +; CHECK-NEXT: orr x9, x8, #0x1 ; CHECK-NEXT: lsr x1, x8, #63 -; CHECK-NEXT: fmov d0, x10 +; CHECK-NEXT: fmov d0, x9 +; CHECK-NEXT: sbfx x10, x5, #0, #1 +; CHECK-NEXT: sbfx x11, x3, #0, #1 ; CHECK-NEXT: mov v0.d[1], x1 +; CHECK-NEXT: orr x2, x11, #0x1 ; 
CHECK-NEXT: lsr x3, x11, #63 -; CHECK-NEXT: lsr x7, x9, #63 +; CHECK-NEXT: orr x4, x10, #0x1 +; CHECK-NEXT: lsr x5, x10, #63 ; CHECK-NEXT: fmov x0, d0 ; CHECK-NEXT: ret %c = icmp sgt <4 x i65> %a, <i65 -1, i65 -1, i65 -1, i65 -1>