diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,134 @@
   return nullptr;
 }
 
+// Return true if we can safely remove the select instruction for the
+// std::bit_ceil pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+                                        const APInt *Cond1, Value *CtlzOp,
+                                        unsigned BitWidth) {
+  // The challenge in recognizing std::bit_ceil(X) is that the operand is used
+  // both for the CTLZ proper and for the select condition, each possibly with
+  // some operation like add and sub.
+  //
+  // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 whenever the
+  // select would choose 1, so that the shift already produces 1 and the
+  // select instruction can be removed.
+  //
+  // To see if we can do so, we do some symbolic execution with ConstantRange.
+  // Specifically, we compute the range of values that Cond0 could take when
+  // the select condition is false.  Then we successively transform the range
+  // until we obtain the range of values that CtlzOp could take.
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of
+  // Cond0 and CtlzOp.  Then we follow the def-use chain forward until we
+  // obtain the range for CtlzOp.  That said, we follow at most one ancestor
+  // from Cond0.  Likewise, we follow at most one ancestor from CtlzOp.
+
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(
+      CmpInst::getInversePredicate(Pred), *Cond1);
+
+  // Match the operation that's used to compute CtlzOp from CommonAncestor.
+  // If a match is found, apply the operation to CR and return true.
+  // Otherwise, return false.
+  auto MatchForward = [&](Value *CommonAncestor) {
+    const APInt *C = nullptr;
+    if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+      CR = CR.add(*C);
+      return true;
+    }
+    if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+      CR = ConstantRange(*C).sub(CR);
+      return true;
+    }
+    if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+      CR = CR.binaryNot();
+      return true;
+    }
+    return false;
+  };
+
+  const APInt *C = nullptr;
+  Value *CommonAncestor;
+  if (match(Cond0, m_Add(m_Specific(CtlzOp), m_APInt(C)))) {
+    // We have Cond0's parent == CtlzOp.
+    CR = CR.sub(*C);
+  } else if (MatchForward(Cond0)) {
+    // We have Cond0 == CtlzOp's parent.  CR has been updated.
+  } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+    CR = CR.sub(*C);
+    if (!MatchForward(CommonAncestor))
+      return false;
+    // We have Cond0's parent == CtlzOp's parent.  CR has been updated.
+  } else {
+    return false;
+  }
+
+  // Return true if all the values in the range are either 0 or negative (if
+  // treated as signed).  For such CtlzOp, CTLZ returns either BitWidth (for
+  // 0, since is_zero_poison is false) or 0 (when the sign bit is set), and
+  // -ctlz & (BitWidth - 1) == 0 in both cases.  We check the property by
+  // evaluating:
+  //
+  //   CR - 1 u>= (1 << (BitWidth - 1)) - 1.
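+  //
+  // A worked instance, added here for exposition (BitWidth == 32): the values
+  // that are 0 or negative as signed are {0} and [0x80000000, 0xffffffff].
+  // Subtracting 1 maps 0 to 0xffffffff and the negative values to
+  // [0x7fffffff, 0xfffffffe], all of which are u>= 0x7fffffff, while any
+  // positive value v maps to v - 1 u< 0x7fffffff and is rejected.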
+  APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+  CR = CR.sub(APInt(BitWidth, 1));
+  return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %neg, 31
+//   %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away while the shift count is masked with
+// BitWidth - 1 (31 in the example above).  We handle some variations of the
+// input operand, such as std::bit_ceil(X + 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  Value *FalseVal = SI.getFalseValue();
+  Value *TrueVal = SI.getTrueValue();
+  ICmpInst::Predicate Pred;
+  const APInt *Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+    return nullptr;
+
+  if (match(TrueVal, m_One())) {
+    std::swap(FalseVal, TrueVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  if (!match(FalseVal, m_One()) ||
+      !match(TrueVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth-1)).  The negation likely corresponds to a
+  // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+  // is an integer constant.  Masking with BitWidth-1 comes free on some
+  // hardware as part of the shift instruction.
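+  //
+  // For a concrete illustration (numbers added for exposition): with X == 5,
+  // %dec == 4 and CTLZ == 29, so -29 & 31 == 3 and the result is
+  // 1 << 3 == 8 == std::bit_ceil(5).  With X == 1, %dec == 0 and CTLZ == 32
+  // (is_zero_poison is false), so -32 & 31 == 0 and the result is
+  // 1 << 0 == 1, the value the select used to supply.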
+  Value *Neg = Builder.CreateNeg(Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3718,8 @@
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;
 
+  if (Instruction *I = foldBitCeil(SI, Builder))
+    return I;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -6,10 +6,9 @@
 ; CHECK-LABEL: @bit_ceil_32(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %dec = add i32 %x, -1
@@ -26,10 +25,9 @@
 ; CHECK-LABEL: @bit_ceil_64(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i64 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i64 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[TMP1]], 63
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[SEL]]
 ;
   %dec = add i64 %x, -1
@@ -47,11 +45,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], -2
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[X]], -3
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[ADD]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -69,11 +65,9 @@
 define i32 @bit_ceil_32_plus_1(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_32_plus_1(
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[X:%.*]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[DEC]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 false)
@@ -90,10 +84,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -112,11 +105,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[NOTSUB]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -136,10 +127,9 @@
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = sub i32 -2, [[X:%.*]]
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry: