Index: llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1908,25 +1908,43 @@
     Constant *C1, *C2;
     const APInt *C3 = C;
     Value *X;
-    if (C3->isPowerOf2() &&
-        match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)),
-                                   m_ImmConstant(C2)))) &&
-        match(C1, m_Power2())) {
-      Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
+    if (C3->isPowerOf2()) {
       Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros());
-      Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3);
-      KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr);
-      if (KnownLShrc.getMaxValue().ult(Width)) {
-        // iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth:
-        // ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? C3 : 0
-        Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1);
-        Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
-        return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
-                                  ConstantInt::getNullValue(Ty));
+      if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)),
+                                     m_ImmConstant(C2)))) &&
+          match(C1, m_Power2())) {
+        Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
+        Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3);
+        KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr);
+        if (KnownLShrc.getMaxValue().ult(Width)) {
+          // iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth:
+          // ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? C3 : 0
+          Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1);
+          Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
+          return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
+                                    ConstantInt::getNullValue(Ty));
+        }
+      }
+
+      if (match(Op0, m_OneUse(m_Shl(m_LShr(m_ImmConstant(C1), m_Value(X)),
+                                    m_ImmConstant(C2)))) &&
+          match(C1, m_Power2())) {
+        Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
+        // The set bit of C1 survives the lshr only if X <= cttz(C1). The one
+        // X that moves it onto cttz(C3) satisfies that iff cttz(C3) >= C2,
+        // so cttz(C3) u< C2 must fold to false (in every lane for vectors).
+        Constant *Log2C3LtC2 =
+            ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
+        if (Log2C3LtC2->isZeroValue()) {
+          // iff C1,C3 is pow2 and Log2(C3) >= C2:
+          // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
+          Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);
+          Constant *CmpC = ConstantExpr::getSub(ShlC, Log2C3);
+          Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
+          return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
+                                    ConstantInt::getNullValue(Ty));
+        }
       }
-      // TODO: Symmetrical case
-      // iff C1,C3 is pow2 and Log2(C3) >= C2:
-      // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
     }
   }
 