Index: include/llvm/CodeGen/TargetLowering.h
===================================================================
--- include/llvm/CodeGen/TargetLowering.h
+++ include/llvm/CodeGen/TargetLowering.h
@@ -509,6 +509,16 @@
     return hasAndNotCompare(X);
   }
 
+  /// There are two ways to clear extreme bits (either low or high):
+  /// Mask:    x &  (-1 << y)  (the instcombine canonical form)
+  /// Shifts:  x >> y << y
+  /// Different targets may have different preferences.
+  /// Returns true if the shift variant is preferred.
+  virtual bool preferShiftsToClearExtremeBits(SDValue X) const {
+    // By default, let's assume that everyone prefers masking.
+    return false;
+  }
+
   /// Return true if the target wants to use the optimization that
   /// turns ext(promotableInst1(...(promotableInstN(load)))) into
   /// promotedInst1(...(promotedInstN(ext(load)))).
Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -409,6 +409,7 @@
     SDValue foldLogicOfSetCCs(bool IsAnd, SDValue N0, SDValue N1,
                               const SDLoc &DL);
     SDValue unfoldMaskedMerge(SDNode *N);
+    SDValue unfoldExtremeBitClearingToShifts(SDNode *N);
     SDValue SimplifySetCC(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                           const SDLoc &DL, bool foldBooleans);
     SDValue rebuildSetCC(SDValue N);
@@ -4169,6 +4170,63 @@
   return false;
 }
 
+// Unfold
+//    x &  (-1 'logical shift' y)
+// To
+//    (x 'opposite logical shift' y) 'logical shift' y
+// if it is better for performance.
+SDValue DAGCombiner::unfoldExtremeBitClearingToShifts(SDNode *N) {
+  assert(N->getOpcode() == ISD::AND);
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // Do we actually prefer shifts over the mask?
+  if (!TLI.preferShiftsToClearExtremeBits(N0))
+    return SDValue();
+
+  // Try to match  (-1 '[outer] logical shift' y)
+  unsigned OuterShift;
+  unsigned InnerShift; // The opposite direction to the OuterShift.
+  SDValue Y;           // Shift amount.
+  auto matchMask = [&OuterShift, &InnerShift, &Y](SDValue M) -> bool {
+    if (!M.hasOneUse())
+      return false;
+    switch (OuterShift = M->getOpcode()) {
+    case ISD::SHL:
+      InnerShift = ISD::SRL;
+      break;
+    case ISD::SRL:
+      InnerShift = ISD::SHL;
+      break;
+    default:
+      return false;
+    }
+    if (!isAllOnesConstant(M->getOperand(0)))
+      return false;
+    Y = M->getOperand(1);
+    return true;
+  };
+
+  SDValue X;
+  if (matchMask(N1))
+    X = N0;
+  else if (matchMask(N0))
+    X = N1;
+  else
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // tmp = x 'opposite logical shift' y
+  SDValue T0 = DAG.getNode(InnerShift, DL, VT, X, Y);
+  // ret = tmp 'logical shift' y
+  SDValue T1 = DAG.getNode(OuterShift, DL, VT, T0, Y);
+
+  return T1;
+}
+
 SDValue DAGCombiner::visitAND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -4466,6 +4524,9 @@
     return BSwap;
   }
 
+  if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N))
+    return Shifts;
+
   return SDValue();
 }
Index: lib/Target/X86/X86ISelLowering.h
===================================================================
--- lib/Target/X86/X86ISelLowering.h
+++ lib/Target/X86/X86ISelLowering.h
@@ -831,6 +831,8 @@
 
     bool hasAndNot(SDValue Y) const override;
 
+    bool preferShiftsToClearExtremeBits(SDValue Y) const override;
+
     bool convertSetCCLogicToBitwiseLogic(EVT VT) const override {
       return VT.isScalarInteger();
     }
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -4785,6 +4785,18 @@
   return Subtarget.hasSSE2();
 }
 
+bool X86TargetLowering::preferShiftsToClearExtremeBits(SDValue Y) const {
+  EVT VT = Y.getValueType();
+
+  // For vectors, we don't have a preference, but we probably want a mask.
+  if (VT.isVector())
+    return false;
+
+  // If we have BMI2's SHLX/SHRX (Shifts Without Affecting Flags), we prefer
+  // them. There are only 32-bit and 64-bit forms of SHLX/SHRX.
+  return Subtarget.hasBMI2() && (VT == MVT::i32 || VT == MVT::i64);
+}
+
 MVT X86TargetLowering::hasFastEqualityCompare(unsigned NumBits) const {
   MVT VT = MVT::getIntegerVT(NumBits);
   if (isTypeLegal(VT))
Index: lib/Target/X86/X86InstrInfo.td
===================================================================
--- lib/Target/X86/X86InstrInfo.td
+++ lib/Target/X86/X86InstrInfo.td
@@ -2561,6 +2561,15 @@
                           (i8 (trunc (sub 64, GR32:$lz)))),
             (BZHI64rm addr:$src,
                       (INSERT_SUBREG (i64 (IMPLICIT_DEF)), GR32:$lz, sub_32bit))>;
+
+  // x << (64 - y) >> (64 - y)
+  def : Pat<(srl (shl GR64:$src, (i8 (trunc (sub 64, GR64:$lz)))),
+                 (i8 (trunc (sub 64, GR64:$lz)))),
+            (BZHI64rr GR64:$src, GR64:$lz)>;
+  def : Pat<(srl (shl (loadi64 addr:$src), (i8 (trunc (sub 64, GR64:$lz)))),
+                 (i8 (trunc (sub 64, GR64:$lz)))),
+            (BZHI64rm addr:$src, GR64:$lz)>;
+
 } // HasBMI2
 
 multiclass bmi_pdep_pext
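
Note (not part of the patch): the combine and the new BZHI patterns rest on
simple bit-twiddling identities. Below is a minimal standalone C++ sketch that
checks all three of them, assuming 64-bit unsigned values and shift amounts
restricted to [1, 63] so that no shift equals the bit width (which would be
undefined behavior in C++, and is also not what the DAG patterns match). The
names Samples, X, and Y are purely illustrative.

  #include <cassert>
  #include <cstdint>

  int main() {
    const uint64_t Samples[] = {0, 1, 0x00FF00FF00FF00FFULL,
                                0xDEADBEEFDEADBEEFULL, ~0ULL};
    for (uint64_t X : Samples) {
      for (unsigned Y = 1; Y < 64; ++Y) {
        // Unfolded by the combine when OuterShift == ISD::SHL:
        // clearing the low Y bits,  x & (-1 << y)  ==  (x >> y) << y.
        assert((X & (~0ULL << Y)) == ((X >> Y) << Y));
        // Unfolded when OuterShift == ISD::SRL:
        // clearing the high Y bits,  x & (-1 >> y)  ==  (x << y) >> y.
        assert((X & (~0ULL >> Y)) == ((X << Y) >> Y));
        // Matched by the new BZHI patterns: keeping only the low Y bits,
        // x << (64 - y) >> (64 - y)  ==  x & ((1 << y) - 1).
        assert(((X << (64 - Y)) >> (64 - Y)) == (X & ((1ULL << Y) - 1)));
      }
    }
    return 0;
  }

On the X86 side the preference is presumably a win because, with BMI2, each of
the two shifts lowers to a single SHLX/SHRX, which takes the shift amount in a
register and leaves EFLAGS untouched, whereas the mask form has to materialize
the -1 constant and shift it before the AND. Without BMI2 the default of
returning false keeps the existing mask-based lowering.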