Add a getShiftAmountThreshold(EVT) hook to TargetLowering.h and use it in
DAGCombiner.cpp and TargetLowering.cpp to skip folds that would create a
shift by more than the threshold. The hook is always queried with the type
of the value being shifted, not with the type of the shift amount.

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2554,6 +2554,12 @@
   // GEP to make the GEP fit into the addressing mode and can be sunk into the
   // same blocks of its users.
   virtual bool shouldConsiderGEPOffsetSplit() const { return false; }
+
+  // Return the largest profitable shift amount for a value of type VT.
+  // Transforms that would produce a larger shift amount are discarded.
+  virtual unsigned getShiftAmountThreshold(EVT VT) const {
+    return VT.getSizeInBits();
+  }
 
   //===--------------------------------------------------------------------===//
   // Runtime Library hooks
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9491,6 +9491,8 @@
       // sext i1 (setgt iN X, -1) --> sra (not X), (N - 1)
       // zext i1 (setgt iN X, -1) --> srl (not X), (N - 1)
       SDLoc DL(N);
+      // FIXME: Should we perform TLI.getShiftAmountThreshold checks for
+      // shift amounts of VT.getSizeInBits() - 1?
       SDValue NotX = DAG.getNOT(DL, X, VT);
       SDValue ShiftAmount = DAG.getConstant(VT.getSizeInBits() - 1, DL, VT);
       auto ShiftOpcode = N->getOpcode() == ISD::SIGN_EXTEND ? ISD::SRA : ISD::SRL;
@@ -19891,22 +19893,28 @@
   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
-    SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
-    SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
-    AddToWorklist(Shift.getNode());
-
-    if (XType.bitsGT(AType)) {
-      Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
+    if (ShCt <= TLI.getShiftAmountThreshold(XType)) {
+      SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+      SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
       AddToWorklist(Shift.getNode());
-    }
 
-    if (CC == ISD::SETGT)
-      Shift = DAG.getNOT(DL, Shift, AType);
+      if (XType.bitsGT(AType)) {
+        Shift = DAG.getNode(ISD::TRUNCATE, DL, AType, Shift);
+        AddToWorklist(Shift.getNode());
+      }
+
+      if (CC == ISD::SETGT)
+        Shift = DAG.getNOT(DL, Shift, AType);
 
-    return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
+      return DAG.getNode(ISD::AND, DL, AType, Shift, N2);
+    }
   }
 
-  SDValue ShiftAmt = DAG.getConstant(XType.getSizeInBits() - 1, DL, ShiftAmtTy);
+  unsigned ShCt = XType.getSizeInBits() - 1;
+  if (ShCt > TLI.getShiftAmountThreshold(XType))
+    return SDValue();
+
+  SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
   AddToWorklist(Shift.getNode());
 
@@ -20023,19 +20031,21 @@
       if (ConstAndRHS && ConstAndRHS->getAPIntValue().countPopulation() == 1) {
         // Shift the tested bit over the sign bit.
         const APInt &AndMask = ConstAndRHS->getAPIntValue();
-        SDValue ShlAmt =
-          DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
+        if (AndMask.getBitWidth() - 1 <= TLI.getShiftAmountThreshold(VT)) {
+          SDValue ShlAmt =
+            DAG.getConstant(AndMask.countLeadingZeros(), SDLoc(AndLHS),
                           getShiftAmountTy(AndLHS.getValueType()));
-        SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
+          SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
 
-        // Now arithmetic right shift it all the way over, so the result is either
-        // all-ones, or zero.
-        SDValue ShrAmt =
-          DAG.getConstant(AndMask.getBitWidth() - 1, SDLoc(Shl),
-                          getShiftAmountTy(Shl.getValueType()));
-        SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
+          // Now arithmetic right shift it all the way over, so the result is either
+          // all-ones, or zero.
+          SDValue ShrAmt =
+            DAG.getConstant(AndMask.getBitWidth() - 1, SDLoc(Shl),
+                            getShiftAmountTy(Shl.getValueType()));
+          SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
 
-        return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
+          return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
+        }
       }
     }
 
@@ -20046,7 +20056,8 @@
   if ((Fold || Swap) &&
       TLI.getBooleanContents(CmpOpVT) ==
           TargetLowering::ZeroOrOneBooleanContent &&
-      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT))) {
+      (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, CmpOpVT)) &&
+      TLI.isOperationLegal(ISD::SHL, N2.getValueType())) {
 
     if (Swap) {
       CC = ISD::getSetCCInverse(CC, CmpOpVT.isInteger());
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -3203,6 +3203,7 @@
   // Back to non-vector simplifications.
   // TODO: Can we do these for vector splats?
   if (auto *N1C = dyn_cast<ConstantSDNode>(N1.getNode())) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     const APInt &C1 = N1C->getAPIntValue();
 
     // Fold bit comparisons when we can.
@@ -3216,7 +3217,8 @@
                                        !DCI.isBeforeLegalize());
         if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0 --> (X & 8) >> 3
           // Perform the xform if the AND RHS is a single bit.
-          if (AndRHS->getAPIntValue().isPowerOf2()) {
+          if (AndRHS->getAPIntValue().isPowerOf2() &&
+              AndRHS->getAPIntValue().logBase2() <= TLI.getShiftAmountThreshold(N0.getValueType())) {
             return DAG.getNode(ISD::TRUNCATE, dl, VT,
                                DAG.getNode(ISD::SRL, dl, N0.getValueType(), N0,
                                            DAG.getConstant(AndRHS->getAPIntValue().logBase2(), dl,
@@ -3225,7 +3227,7 @@
         } else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
           // (X & 8) == 8 --> (X & 8) >> 3
           // Perform the xform if C1 is a single bit.
-          if (C1.isPowerOf2()) {
+          if (C1.isPowerOf2() && C1.logBase2() <= TLI.getShiftAmountThreshold(N0.getValueType())) {
             return DAG.getNode(ISD::TRUNCATE, dl, VT,
                                DAG.getNode(ISD::SRL, dl, N0.getValueType(), N0,
                                            DAG.getConstant(C1.logBase2(), dl,