diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,68 @@
   return nullptr;
 }
 
+// Fold the bit_floor idiom.
+//
+// Transform:
+//
+//   X == 0 ? 0 : (1 << (BitWidth - ctlz(X >> 1)))
+//
+// into
+//
+//   X == 0 ? 0 : (SignBit >> ctlz(X))
+//
+// where BitWidth is the scalar bit width of the select's type and SignBit is
+// the constant of that width with only its most significant bit set.  The
+// swapped form "X != 0 ? ... : 0" is matched as well.
+//
+// This function only builds and returns SignBit >> ctlz(X), or returns nullptr
+// if the select does not match.  The caller is responsible for replacing the
+// non-zero select operand with the returned instruction.
+static Instruction *foldBitFloor(SelectInst &SI,
+                                 InstCombiner::BuilderTy &Builder) {
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  // The condition must be an equality comparison of some value with zero.
+  ICmpInst::Predicate Pred;
+  Value *Cond0;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_Zero())))
+    return nullptr;
+
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  // For "X != 0" the zero arm is the false value; swap so that the checks
+  // below can assume the "X == 0 ? 0 : V" orientation.
+  if (Pred == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  // Match V = 1 << (BitWidth - CTLZ), where the shl and the sub each have a
+  // single use.
+  Value *CTLZ;
+  if (!match(TrueVal, m_Zero()) ||
+      !match(FalseVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(CTLZ)))))))
+    return nullptr;
+
+  // CTLZ must be a ctlz call on (X >> 1) with a zero second argument, where X
+  // is the value compared with zero and the lshr has a single use.
+  if (!match(CTLZ, m_Intrinsic<Intrinsic::ctlz>(
+                       m_OneUse(m_LShr(m_Specific(Cond0), m_One())), m_Zero())))
+    return nullptr;
+
+  // Build SignBit >> ctlz(X) as a replacement for the non-zero select operand.
+  // The new ctlz reuses the second argument of the matched call.
+  Value *NewCTLZ =
+      Builder.CreateIntrinsic(Intrinsic::ctlz, {CTLZ->getType()},
+                              {Cond0, cast<Instruction>(CTLZ)->getOperand(1)});
+  return cast<Instruction>(Builder.CreateLShr(
+      ConstantInt::get(SelType, APInt::getSignedMinValue(BitWidth)), NewCTLZ));
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3652,8 @@
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;
 
+  if (Instruction *I = foldBitFloor(SI, Builder))
+    return replaceOperand(SI, match(SI.getTrueValue(), m_Zero()) ? 2 : 1, I);
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/bit_floor.ll b/llvm/test/Transforms/InstCombine/bit_floor.ll
--- a/llvm/test/Transforms/InstCombine/bit_floor.ll
+++ b/llvm/test/Transforms/InstCombine/bit_floor.ll
@@ -4,11 +4,9 @@
 define i32 @bit_floor_32(i32 %x) {
 ; CHECK-LABEL: @bit_floor_32(
 ; CHECK-NEXT:    [[EQ0:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 1
-; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[LSHR]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[EQ0]], i32 0, i32 [[SHL]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.ctlz.i32(i32 [[X]], i1 false), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 -2147483648, [[TMP1]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[EQ0]], i32 0, i32 [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %eq0 = icmp eq i32 %x, 0
@@ -23,11 +21,9 @@
 define i64 @bit_floor_64(i64 %x) {
 ; CHECK-LABEL: @bit_floor_64(
 ; CHECK-NEXT:    [[EQ0:%.*]] = icmp eq i64 [[X:%.*]], 0
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[X]], 1
-; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[LSHR]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[EQ0]], i64 0, i64 [[SHL]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.ctlz.i64(i64 [[X]], i1 false), !range [[RNG1:![0-9]+]]
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 -9223372036854775808, [[TMP1]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[EQ0]], i64 0, i64 [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[SEL]]
 ;
   %eq0 = icmp eq i64 %x, 0