Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -583,6 +583,55 @@
   return new ZExtInst(ICmpNeZero, SelType);
 }
 
+/// We want to turn:
+///   (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2));
+///   iff C1 is a mask and the number of its leading zeros is equal to C2
+/// into:
+///   shl X, C2
+/// The 'ne' form, (select (icmp ne (and X, C1), 0), (shl X, C2), 0), is
+/// handled too, by swapping the select arms.
+///
+/// When (X & C1) == 0, every bit that survives the shift is zero, so the shl
+/// already yields 0 and the select is redundant. The nsw/nuw flags must be
+/// dropped, though: X may still have bits set above the mask in that case,
+/// which makes the flagged shl poison where the select produced 0.
+static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
+                                       Value *FVal) {
+  ICmpInst::Predicate Pred;
+  Value *AndVal;
+  if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
+    return nullptr;
+
+  // Canonicalize the 'ne' form to 'eq' by swapping the select arms.
+  if (Pred == ICmpInst::ICMP_NE) {
+    Pred = ICmpInst::ICMP_EQ;
+    std::swap(TVal, FVal);
+  }
+
+  // X must be bound by the 'and' match before m_Specific(X) is constructed,
+  // which the left-to-right evaluation of || guarantees.
+  Value *X;
+  const APInt *C1, *C2;
+  if (Pred != ICmpInst::ICMP_EQ ||
+      !match(AndVal, m_And(m_Value(X), m_APInt(C1))) ||
+      !match(TVal, m_Zero()) ||
+      !match(FVal, m_Shl(m_Specific(X), m_APInt(C2))))
+    return nullptr;
+
+  // APInt::operator!=(uint64_t) is safe even if C2 has more than 64 active
+  // bits, unlike getZExtValue()/getSExtValue().
+  if (!C1->isMask() || *C2 != C1->countLeadingZeros())
+    return nullptr;
+
+  auto *FI = dyn_cast<Instruction>(FVal);
+  if (!FI)
+    return nullptr;
+
+  FI->setHasNoSignedWrap(false);
+  FI->setHasNoUnsignedWrap(false);
+  return FVal;
+}
+
 /// We want to turn:
 ///   (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1
 ///   (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0
@@ -1806,10 +1855,13 @@
     }
   }
 
   if (Instruction *V =
           foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder))
     return V;
 
+  if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal))
+    return replaceInstUsesWith(SI, V);
+
   if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder))
     return V;
 