diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,78 @@
   return nullptr;
 }
 
+static Instruction *foldBitCeil(SelectInst &SI,
+                                InstCombiner::BuilderTy &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  ICmpInst::Predicate Pred;
+  Value *Cond0;
+  uint64_t Cond1;
+  if (!match(SI.getCondition(),
+             m_ICmp(Pred, m_Value(Cond0), m_ConstantInt(Cond1))) ||
+      !match(SI.getFalseValue(), m_One()))
+    return nullptr;
+
+  Value *Ctlz;
+  if (!match(SI.getTrueValue(),
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))))
+    return nullptr;
+
+  Value *CtlzOp;
+  if (!match(Ctlz,
+             m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_SpecificInt(0))))
+    return nullptr;
+
+  // Return true if CTLZ yields either 0 or the operand's bit width (but no
+  // value in between) over the entire range [LowerBound, UpperBound], which
+  // may wrap around the unsigned range.
+  auto DoesCTLZProduceZeroOrOpSize = [](uint64_t LowerBound,
+                                        uint64_t UpperBound) {
+    uint64_t SignBit = uint64_t(1) << 63;
+
+    // An unsigned overflow?
+    if (UpperBound < LowerBound)
+      // If an unsigned overflow occurs, then LowerBound must have the MSB set
+      // and UpperBound must be 0.
+      return (LowerBound & SignBit) != 0 && UpperBound == 0;
+
+    // No unsigned overflow.
+    return (LowerBound == 0 && UpperBound == 0) || (LowerBound & SignBit) != 0;
+  };
+
+  // Check whether CTLZ evaluates to either 0 or BitWidth whenever the select
+  // condition is false.  If so, the shift amount -CTLZ & (BitWidth - 1)
+  // would be 0, so the shift result would be 1, which in turn allows us to
+  // get rid of the select altogether.
+  uint64_t Addend;
+  if (Pred != ICmpInst::ICMP_UGT ||
+      !match(CtlzOp, m_Add(m_Specific(Cond0), m_ConstantInt(Addend))))
+    return nullptr;
+
+  // Cond == false means Cond0 u<= Cond1, where CtlzOp == Cond0 + Addend.
+  // CtlzOp therefore ranges over [0 + Addend, Cond1 + Addend], where the
+  // range might wrap around.
+  uint64_t LowerBound = SignExtend64(0 + Addend, BitWidth);
+  uint64_t UpperBound = SignExtend64(Cond1 + Addend, BitWidth);
+  if (!DoesCTLZProduceZeroOrOpSize(LowerBound, UpperBound))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth - 1)).  The negation likely corresponds to
+  // a single hardware instruction, as opposed to BitWidth - CTLZ, where
+  // BitWidth is an integer constant.  Masking with BitWidth - 1 comes free on
+  // some hardware as part of the shift instruction.
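+  // E.g., with BitWidth == 32, a false condition guarantees that CTLZ is 0 or
+  // 32, and -0 & 31 == -32 & 31 == 0, so the shl below produces the 1 that
+  // the select's false arm used to supply.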
+  Value *Neg = Builder.CreateSub(ConstantInt::getNullValue(SelType), Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3662,8 @@
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;
 
+  if (Instruction *I = foldBitCeil(SI, Builder))
+    return I;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -5,10 +5,9 @@
 ; CHECK-LABEL: @bit_ceil_32(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %dec = add i32 %x, -1
@@ -24,10 +23,9 @@
 ; CHECK-LABEL: @bit_ceil_64(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i64 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i64 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[TMP1]], 63
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[SEL]]
 ;
   %dec = add i64 %x, -1
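For a quick sanity check of the arithmetic behind the fold, the i32 pattern from
@bit_ceil_32 can be brute-forced in standalone C++. This is illustrative only
and not part of the patch: ctlz32 is a hand-rolled stand-in for llvm.ctlz.i32
with is_zero_poison == false, and none of the names below are LLVM API.

#include <cassert>
#include <cstdint>

// Stand-in for llvm.ctlz.i32 with is_zero_poison == false: returns 32 for 0.
static unsigned ctlz32(uint32_t V) {
  unsigned N = 0;
  for (uint32_t Bit = 1u << 31; Bit != 0 && !(V & Bit); Bit >>= 1)
    ++N;
  return N;
}

int main() {
  // Sample the low end of the i32 range, where the true arm's shift amount
  // 32 - ctlz never reaches 32, so the original select is never poison.
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    uint32_t Dec = X - 1; // CtlzOp == Cond0 + Addend with Addend == -1
    unsigned Ctlz = ctlz32(Dec);
    // Original pattern: select (X u> 1), 1 << (32 - ctlz), 1.
    uint32_t Sel = X > 1 ? uint32_t(1) << (32 - Ctlz) : uint32_t(1);
    // Folded form: 1 << (-ctlz & 31), with no select.
    uint32_t Folded = uint32_t(1) << ((0u - Ctlz) & 31u);
    assert(Sel == Folded);
  }
}

Inputs above 2^31 are deliberately left out of the sample: there ctlz(x - 1)
is 0, the original select picks 1 << 32, i.e. poison, and replacing that with
1 is a legal refinement rather than a bitwise match.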