diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -205,6 +205,15 @@
                                       const Instruction *CxtI = nullptr,
                                       const DominatorTree *DT = nullptr);
 
+/// Get the upper bound on bit size for this Value \p Op as an unsigned integer.
+/// i.e. x == zext(trunc(x to MaxSignificantBits) to bitwidth(x)).
+unsigned ComputeMaxUnsignedSignificantBits(const Value *Op,
+                                           const DataLayout &DL,
+                                           unsigned Depth = 0,
+                                           AssumptionCache *AC = nullptr,
+                                           const Instruction *CxtI = nullptr,
+                                           const DominatorTree *DT = nullptr);
+
 /// Map a call instruction to an intrinsic ID. Libcalls which have equivalent
 /// intrinsics are treated as-if they were intrinsics.
 Intrinsic::ID getIntrinsicForCallSite(const CallBase &CB,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -489,6 +489,13 @@
     return llvm::ComputeMaxSignificantBits(Op, DL, Depth, &AC, CxtI, &DT);
   }
 
+  unsigned
+  ComputeMaxUnsignedSignificantBits(const Value *Op, unsigned Depth = 0,
+                                    const Instruction *CxtI = nullptr) const {
+    return llvm::ComputeMaxUnsignedSignificantBits(Op, DL, Depth, &AC, CxtI,
+                                                   &DT);
+  }
+
   OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -396,6 +396,13 @@
   return V->getType()->getScalarSizeInBits() - SignBits + 1;
 }
 
+unsigned llvm::ComputeMaxUnsignedSignificantBits(
+    const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC,
+    const Instruction *CxtI, const DominatorTree *DT) {
+  KnownBits KB = computeKnownBits(V, DL, Depth, AC, CxtI, DT);
+  return V->getType()->getScalarSizeInBits() - KB.countMinLeadingZeros();
+}
+
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                    bool NSW, const APInt &DemandedElts,
                                    KnownBits &KnownOut, KnownBits &Known2,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -840,13 +840,12 @@
 }
 
 // Tries to perform
-// (lshr (add (zext X), (zext Y)), K)
-// -> (icmp ult (add X, Y), X)
+// (lshr (add X, Y), K)
+// -> (icmp ult (add (trunc X), (trunc Y)), (trunc X))
 // where
-// - The add's operands are zexts from a K-bits integer to a bigger type.
+// - Only the K trailing bits of X and Y can be non-zero.
 // - The add is only used by the shr, or by iK (or narrower) truncates.
-// - The lshr type has more than 2 bits (other types are boolean math).
-// - K > 1
+// - K > 1 as we don't want to deal with boolean math here.
 // note that
 // - The resulting add cannot have nuw/nsw, else on overflow we get a
 //   poison value and the transform isn't legal anymore.
@@ -857,23 +856,18 @@
   Value *ShiftAmt = I.getOperand(1);
   Type *Ty = I.getType();
 
-  if (Ty->getScalarSizeInBits() < 3)
-    return nullptr;
-
   const APInt *ShAmtAPInt = nullptr;
   Value *X = nullptr, *Y = nullptr;
   if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
-      !match(Add,
-             m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
+      !match(Add, m_Add(m_OneUse(m_Value(X)), m_OneUse(m_Value(Y)))))
     return nullptr;
 
   const unsigned ShAmt = ShAmtAPInt->getZExtValue();
   if (ShAmt == 1)
     return nullptr;
 
-  // X/Y are zexts from `ShAmt`-sized ints.
-  if (X->getType()->getScalarSizeInBits() != ShAmt ||
-      Y->getType()->getScalarSizeInBits() != ShAmt)
+  if (ComputeMaxUnsignedSignificantBits(X, 0, &I) > ShAmt ||
+      ComputeMaxUnsignedSignificantBits(Y, 0, &I) > ShAmt)
     return nullptr;
 
   // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
@@ -893,6 +887,10 @@
   Instruction *AddInst = cast<Instruction>(Add);
   Builder.SetInsertPoint(AddInst);
 
+  Type *OpTy = Builder.getIntNTy(ShAmt);
+  X = Builder.CreateTrunc(X, OpTy);
+  Y = Builder.CreateTrunc(Y, OpTy);
+
   Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
   Value *Overflow =
       Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
diff --git a/llvm/test/Transforms/InstCombine/shift-add.ll b/llvm/test/Transforms/InstCombine/shift-add.ll
--- a/llvm/test/Transforms/InstCombine/shift-add.ll
+++ b/llvm/test/Transforms/InstCombine/shift-add.ll
@@ -493,14 +493,26 @@
 define i32 @lshr_16_add_known_16_leading_zeroes(i32 %a, i32 %b) {
 ; CHECK-LABEL: @lshr_16_add_known_16_leading_zeroes(
-; CHECK-NEXT:    [[A16:%.*]] = and i32 [[A:%.*]], 65535
-; CHECK-NEXT:    [[B16:%.*]] = and i32 [[B:%.*]], 65535
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[A16]], [[B16]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[ADD]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[A:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[B:%.*]] to i16
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i16 [[TMP1]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i16 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
-  %a16 = and i32 %a, 65535 ; 0x65535
-  %b16 = and i32 %b, 65535 ; 0x65535
+  %a16 = and i32 %a, 65535 ; 0xFFFF
+  %b16 = and i32 %b, 65535 ; 0xFFFF
   %add = add i32 %a16, %b16
   %lshr = lshr i32 %add, 16
   ret i32 %lshr
 }
+
+define i32 @lshr_16_add_known_17_leading_zeroes(i32 %a, i32 %b) {
+; CHECK-LABEL: @lshr_16_add_known_17_leading_zeroes(
+; CHECK-NEXT:    ret i32 0
+;
+  %a16 = and i32 %a, 4095 ; 0xFFF
+  %b16 = and i32 %b, 4095 ; 0xFFF
+  %add = add i32 %a16, %b16
+  %lshr = lshr i32 %add, 16
+  ret i32 %lshr
+}
@@ -594,10 +606,11 @@
 
 define i64 @lshr_32_add_known_32_leading_zeroes(i64 %a, i64 %b) {
 ; CHECK-LABEL: @lshr_32_add_known_32_leading_zeroes(
-; CHECK-NEXT:    [[A32:%.*]] = and i64 [[A:%.*]], 4294967295
-; CHECK-NEXT:    [[B32:%.*]] = and i64 [[B:%.*]], 4294967295
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[A32]], [[B32]]
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i64 [[ADD]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[A:%.*]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[B:%.*]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i32 [[TMP1]], -1
+; CHECK-NEXT:    [[ADD_NARROWED_OVERFLOW:%.*]] = icmp ult i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext i1 [[ADD_NARROWED_OVERFLOW]] to i64
 ; CHECK-NEXT:    ret i64 [[LSHR]]
 ;
   %a32 = and i64 %a, 4294967295 ; 0xFFFFFFFF
@@ -607,6 +620,17 @@
   ret i64 %lshr
 }
+define i64 @lshr_32_add_known_33_leading_zeroes(i64 %a, i64 %b) {
+; CHECK-LABEL: @lshr_32_add_known_33_leading_zeroes(
+; CHECK-NEXT:    ret i64 0
+;
+  %a32 = and i64 %a, 268435455 ; 0xFFFFFFF
+  %b32 = and i64 %b, 268435455 ; 0xFFFFFFF
+  %add = add i64 %a32, %b32
+  %lshr = lshr i64 %add, 32
+  ret i64 %lshr
+}
+
 
 define i64 @lshr_32_add_not_known_32_leading_zeroes(i64 %a, i64 %b) {
 ;
 ; CHECK-LABEL: @lshr_32_add_not_known_32_leading_zeroes(
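
Note (editor's illustration, not part of the patch or the test file): the relaxed pattern no longer requires the add operands to be literal zexts from an iK type; it only needs their upper bits to be known zero, which is what the new ComputeMaxUnsignedSignificantBits check verifies. Below is a minimal IR sketch of an input the new check would accept, assuming the fold fires as described in the InstCombineShifts.cpp comment above; the function name @carry_out_example and the %*.lo value names are hypothetical.

; Both operands have their upper 16 bits known zero via the masks below,
; even though neither is a zext from i16, so the significant-bits check
; (<= 16 for a shift amount of 16) is satisfied.
define i32 @carry_out_example(i32 %a, i32 %b) {
  %a.lo = and i32 %a, 65535      ; upper 16 bits known zero
  %b.lo = and i32 %b, 40000      ; upper 16 bits known zero as well
  %sum = add i32 %a.lo, %b.lo
  %carry = lshr i32 %sum, 16     ; extracts the carry of the 16-bit addition
  ret i32 %carry
}

; Roughly the shape the patch's builder calls produce (before further
; InstCombine cleanup into the xor form seen in the CHECK lines above):
;   %x = trunc i32 %a.lo to i16
;   %y = trunc i32 %b.lo to i16
;   %add.narrowed = add i16 %x, %y
;   %add.narrowed.overflow = icmp ult i16 %add.narrowed, %x
;   %carry = zext i1 %add.narrowed.overflow to i32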