diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,180 @@
   return nullptr;
 }
 
+// Return true if we can safely remove the select instruction for the
+// std::bit_ceil pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+                                        uint64_t Cond1, Value *CtlzOp,
+                                        unsigned BitWidth) {
+  // Return true if CTLZ returns 0 or the operand size (but no value in
+  // between) for the entire range [LowerBound, UpperBound], which may wrap
+  // around due to unsigned overflow.
+  auto DoesCTLZProduceZeroOrOpSize = [](uint64_t LowerBound,
+                                        uint64_t UpperBound) {
+    uint64_t SignBit = uint64_t(1) << 63;
+
+    // An unsigned overflow?
+    if (UpperBound < LowerBound)
+      // If an unsigned overflow occurs, then LowerBound must have the MSB set
+      // and UpperBound must be 0.
+      return (LowerBound & SignBit) != 0 && UpperBound == 0;
+
+    // No unsigned overflow.
+    return (LowerBound == 0 && UpperBound == 0) || (LowerBound & SignBit) != 0;
+  };
+
+  // The challenge in recognizing std::bit_ceil(X) is that the operand is
+  // used both for the CTLZ proper and for the select condition, each
+  // possibly with some operation like add and sub.
+  //
+  // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
+  // select instruction would select 1, which allows us to get rid of the
+  // select instruction.
+  //
+  // To see if we can do so, we do some symbolic execution. Specifically, we
+  // compute the range of values that Cond0 could take when Cond == false.
+  // Then we successively transform the range until we obtain the range of
+  // values that CtlzOp could take.
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of
+  // Cond0 and CtlzOp. Then we follow the def-use chain forward until we
+  // obtain the range for CtlzOp. That said, we follow at most one ancestor
+  // from Cond0. Likewise, we follow at most one ancestor from CtlzOp.
+
+  Value *X = nullptr;
+  if (match(CtlzOp, m_Add(m_Value(X), m_ConstantInt())) ||
+      match(CtlzOp, m_Sub(m_ConstantInt(), m_Value(X))) ||
+      match(CtlzOp, m_Not(m_Value(X)))) {
+    // We'll stop following the def-use chain when we encounter X.
+  }
+
+  // The Value that LowerBound and UpperBound below pertain to.
+  Value *RangeOp = Cond0;
+
+  // The lower and upper bounds of RangeOp. Note that the range may wrap
+  // around like [-1, 0].
+  uint64_t LowerBound, UpperBound;
+
+  switch (Pred) {
+  case ICmpInst::ICMP_UGT:
+    // Cond == false means Cond0 u<= Cond1.
+    LowerBound = 0;
+    UpperBound = Cond1;
+    break;
+  case ICmpInst::ICMP_ULT:
+    // Cond == false means Cond0 u>= Cond1.
+    LowerBound = Cond1;
+    UpperBound = ~uint64_t(0);
+    break;
+  default:
+    return false;
+  }
+
+  // Move up to the parent of RangeOp if RangeOp doesn't show up in the
+  // ancestor chain of CtlzOp.
+  if (RangeOp != CtlzOp && RangeOp != X) {
+    Value *Y;
+    uint64_t C;
+    if (match(Cond0, m_Add(m_Value(Y), m_ConstantInt(C)))) {
+      LowerBound -= C;
+      UpperBound -= C;
+      RangeOp = Y;
+    }
+  }
+
+  if (RangeOp == CtlzOp) {
+    // Good. We already know the range for CtlzOp.
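+    // For example, with the bit_ceil_32_plus_1 pattern tested below (Pred ==
+    // ICMP_ULT, Cond1 == -2, Cond0 == CtlzOp + (-1)), Cond == false
+    // constrains Cond0 to [-2, -1], and the parent step above turns that
+    // into the wrapped range [-1, 0] for CtlzOp itself, over which CTLZ
+    // yields only 0 or BitWidth.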
+  } else if (RangeOp == X) {
+    uint64_t C;
+    if (match(CtlzOp, m_Add(m_Specific(X), m_ConstantInt(C)))) {
+      LowerBound += C;
+      UpperBound += C;
+      RangeOp = CtlzOp;
+    } else if (match(CtlzOp, m_Sub(m_ConstantInt(C), m_Specific(X)))) {
+      // The Sub reverses the two end points of the range.
+      std::tie(LowerBound, UpperBound) =
+          std::make_pair(C - UpperBound, C - LowerBound);
+      RangeOp = CtlzOp;
+    } else if (match(CtlzOp, m_Not(m_Specific(X)))) {
+      // The Not reverses the two end points of the range.
+      std::tie(LowerBound, UpperBound) =
+          std::make_pair(~UpperBound, ~LowerBound);
+      RangeOp = CtlzOp;
+    } else {
+      return false;
+    }
+  } else {
+    return false;
+  }
+
+  assert(RangeOp == CtlzOp);
+  return DoesCTLZProduceZeroOrOpSize(SignExtend64(LowerBound, BitWidth),
+                                     SignExtend64(UpperBound, BitWidth));
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %neg, 31
+//   %shl = shl i32 1, %masked
+//
+// We handle some variations of the input operand like std::bit_ceil(X + 1).
+static Instruction *foldBitCeil(SelectInst &SI,
+                                InstCombiner::BuilderTy &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  // We use uint64_t below, so don't accept anything wider.
+  if (BitWidth > 64)
+    return nullptr;
+
+  ICmpInst::Predicate Pred;
+  uint64_t Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(),
+             m_ICmp(Pred, m_Value(Cond0), m_ConstantInt(Cond1))) ||
+      !match(SI.getFalseValue(), m_One()) ||
+      !match(SI.getTrueValue(),
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz,
+             m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_SpecificInt(0))) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth - 1)). The negation likely corresponds to a
+  // single hardware instruction, as opposed to BitWidth - CTLZ, where BitWidth
+  // is an integer constant. Masking with BitWidth - 1 comes free on some
+  // hardware as part of the shift instruction.
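+  // The select is removable because isSafeToRemoveBitCeilSelect guarantees
+  // that CTLZ is 0 or BitWidth whenever the select would pick 1. Since
+  // -BitWidth == ~(BitWidth - 1), both -0 & (BitWidth - 1) and
+  // -BitWidth & (BitWidth - 1) are 0, so the shift amount below is 0 in
+  // exactly those cases and the shift itself produces the constant 1.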
+  Value *Neg = Builder.CreateNeg(Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3764,8 @@
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;
 
+  if (Instruction *I = foldBitCeil(SI, Builder))
+    return I;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -6,10 +6,9 @@
 ; CHECK-LABEL: @bit_ceil_32(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %dec = add i32 %x, -1
@@ -26,10 +26,9 @@
 ; CHECK-LABEL: @bit_ceil_64(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i64 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i64 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[TMP1]], 63
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[SEL]]
 ;
   %dec = add i64 %x, -1
@@ -47,11 +45,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], -2
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[X]], -3
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[ADD]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -69,11 +65,9 @@
 define i32 @bit_ceil_32_plus_1(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_32_plus_1(
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[X:%.*]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[DEC]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 false)
@@ -90,10 +84,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -112,11 +105,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[NOTSUB]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -136,10 +127,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = sub i32 -2, [[X:%.*]]
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry: