Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -21674,11 +21674,11 @@
   if (VT.getScalarSizeInBits() < 16)
     return false;
 
-  if (VT.is512BitVector() &&
+  if (VT.is512BitVector() && Subtarget.hasAVX512() &&
       (VT.getScalarSizeInBits() > 16 || Subtarget.hasBWI()))
     return true;
 
-  bool LShift = VT.is128BitVector() ||
+  bool LShift = (VT.is128BitVector() && Subtarget.hasSSE2()) ||
     (VT.is256BitVector() && Subtarget.hasInt256());
 
   bool AShift = LShift && (Subtarget.hasAVX512() ||
@@ -26645,6 +26645,16 @@
     return Tmp;
   }
 
+  case X86ISD::VSRAI: {
+    // Sign bits of the source plus the shift amount, clamped to VTBits.
+    SDValue Src = Op.getOperand(0);
+    unsigned Tmp = DAG.ComputeNumSignBits(Src, Depth + 1);
+    unsigned VTBits = Op.getValueType().getScalarSizeInBits();
+    APInt ShiftVal = cast<ConstantSDNode>(Op.getOperand(1))->getAPIntValue();
+    ShiftVal += Tmp;
+    return ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
+  }
+
   case X86ISD::PCMPGT:
   case X86ISD::PCMPEQ:
   case X86ISD::CMPP:
@@ -31050,13 +31060,14 @@
   bool LogicalShift = X86ISD::VSHLI == Opcode || X86ISD::VSRLI == Opcode;
   EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
   unsigned NumBitsPerElt = VT.getScalarSizeInBits();
   assert(VT == N0.getValueType() && (NumBitsPerElt % 8) == 0 &&
          "Unexpected value type");
 
   // Out of range logical bit shifts are guaranteed to be zero.
   // Out of range arithmetic bit shifts splat the sign bit.
-  APInt ShiftVal = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
+  APInt ShiftVal = cast<ConstantSDNode>(N1)->getAPIntValue();
   if (ShiftVal.zextOrTrunc(8).uge(NumBitsPerElt)) {
     if (LogicalShift)
       return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N));
@@ -31072,6 +31083,13 @@
   if (ISD::isBuildVectorAllZeros(N0.getNode()))
     return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, SDLoc(N));
 
+  // fold (VSRLI (VSRAI X, Y), 31) -> (VSRLI X, 31).
+  // This VSRLI only looks at the sign bit, which is unmodified by VSRAI.
+  // TODO - support other srl opcodes.
+  if (Opcode == X86ISD::VSRLI && (ShiftVal + 1) == NumBitsPerElt &&
+      N0.getOpcode() == X86ISD::VSRAI)
+    return DAG.getNode(X86ISD::VSRLI, SDLoc(N), VT, N0.getOperand(0), N1);
+
   // We can decode 'whole byte' logical bit shifts as shuffles.
   if (LogicalShift && (ShiftVal.getZExtValue() % 8) == 0) {
     SDValue Op(N, 0);
@@ -31367,38 +31385,34 @@
   return SDValue();
 }
 
-/// If this is a PCMPEQ or PCMPGT result that is bitwise-anded with 1 (this is
-/// the x86 lowering of a SETCC + ZEXT), replace the 'and' with a shift-right to
-/// eliminate loading the vector constant mask value. This relies on the fact
-/// that a PCMP always creates an all-ones or all-zeros bitmask per element.
-static SDValue combinePCMPAnd1(SDNode *N, SelectionDAG &DAG) {
+/// If this is a zero/all-bits result that is bitwise-anded with a low bits mask
+/// (Mask == 1 for the x86 lowering of a SETCC + ZEXT), replace the 'and' with a
+/// shift-right to eliminate loading the vector constant mask value.
+static SDValue combineAndMaskToShift(SDNode *N, SelectionDAG &DAG,
+                                     const X86Subtarget &Subtarget) {
   SDValue Op0 = peekThroughBitcasts(N->getOperand(0));
   SDValue Op1 = peekThroughBitcasts(N->getOperand(1));
+  EVT VT0 = Op0.getValueType();
+  EVT VT1 = Op1.getValueType();
 
-  // TODO: Use AssertSext to mark any nodes that have the property of producing
-  // all-ones or all-zeros. Then check for that node rather than particular
-  // opcodes.
-  if (Op0.getOpcode() != X86ISD::PCMPEQ && Op0.getOpcode() != X86ISD::PCMPGT)
+  if (VT0 != VT1 || !VT0.isSimple() || !VT0.isInteger())
     return SDValue();
 
-  // The existence of the PCMP node guarantees that we have the required SSE2 or
-  // AVX2 for a shift of this vector type, but there is no vector shift by
-  // immediate for a vector with byte elements (PSRLB). 512-bit vectors use the
-  // masked compare nodes, so they should not make it here.
-  EVT VT0 = Op0.getValueType();
-  EVT VT1 = Op1.getValueType();
-  unsigned EltBitWidth = VT0.getScalarSizeInBits();
-  if (VT0 != VT1 || EltBitWidth == 8)
+  APInt SplatVal;
+  if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) ||
+      !APIntOps::isMask(SplatVal))
     return SDValue();
 
-  assert(VT0.getSizeInBits() == 128 || VT0.getSizeInBits() == 256);
+  if (!SupportedVectorShiftWithImm(VT0.getSimpleVT(), Subtarget, ISD::SRL))
+    return SDValue();
 
-  APInt SplatVal;
-  if (!ISD::isConstantSplatVector(Op1.getNode(), SplatVal) || SplatVal != 1)
+  unsigned EltBitWidth = VT0.getScalarSizeInBits();
+  if (EltBitWidth == 1 || EltBitWidth != DAG.ComputeNumSignBits(Op0))
     return SDValue();
 
   SDLoc DL(N);
-  SDValue ShAmt = DAG.getConstant(EltBitWidth - 1, DL, MVT::i8);
+  unsigned ShiftVal = SplatVal.countTrailingOnes();
+  SDValue ShAmt = DAG.getConstant(EltBitWidth - ShiftVal, DL, MVT::i8);
   SDValue Shift = DAG.getNode(X86ISD::VSRLI, DL, VT0, Op0, ShAmt);
   return DAG.getBitcast(N->getValueType(0), Shift);
 }
@@ -31418,7 +31432,7 @@
   if (SDValue R = combineANDXORWithAllOnesIntoANDNP(N, DAG))
     return R;
 
-  if (SDValue ShiftRight = combinePCMPAnd1(N, DAG))
+  if (SDValue ShiftRight = combineAndMaskToShift(N, DAG, Subtarget))
     return ShiftRight;
 
   EVT VT = N->getValueType(0);
Index: test/CodeGen/X86/combine-and.ll
===================================================================
--- test/CodeGen/X86/combine-and.ll
+++ test/CodeGen/X86/combine-and.ll
@@ -253,8 +253,7 @@
 define <8 x i16> @ashr_mask1_v8i16(<8 x i16> %a0) {
 ; CHECK-LABEL: ashr_mask1_v8i16:
 ; CHECK:       # BB#0:
-; CHECK-NEXT:    psraw $15, %xmm0
-; CHECK-NEXT:    pand {{.*}}(%rip), %xmm0
+; CHECK-NEXT:    psrlw $15, %xmm0
 ; CHECK-NEXT:    retq
   %1 = ashr <8 x i16> %a0, <i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15, i16 15>
   %2 = and <8 x i16> %1, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
@@ -265,7 +264,7 @@
 ; CHECK-LABEL: ashr_mask7_v4i32:
 ; CHECK:       # BB#0:
 ; CHECK-NEXT:    psrad $31, %xmm0
-; CHECK-NEXT:    pand {{.*}}(%rip), %xmm0
+; CHECK-NEXT:    psrld $29, %xmm0
 ; CHECK-NEXT:    retq
   %1 = ashr <4 x i32> %a0, <i32 31, i32 31, i32 31, i32 31>
   %2 = and <4 x i32> %1, <i32 7, i32 7, i32 7, i32 7>