diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1163,15 +1163,32 @@
     }
   }
 
-  // Look for a "splat" mul pattern - it replicates bits across each half of
-  // a value, so a right shift is just a mask of the low bits:
-  // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-  // TODO: Generalize to allow more than just half-width shifts?
   const APInt *MulC;
-  if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
-      BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
-      MulC->logBase2() == ShAmtC)
-    return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+  if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
+    // Look for a "splat" mul pattern - it replicates bits across each half of
+    // a value, so a right shift is just a mask of the low bits:
+    // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
+    // TODO: Generalize to allow more than just half-width shifts?
+    if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC)
+      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+
+    // The one-use check is not strictly necessary, but codegen may not be
+    // able to invert the transform and perf may suffer with an extra mul
+    // instruction.
+    if (Op0->hasOneUse()) {
+      APInt NewMulC = MulC->lshr(ShAmtC);
+      // If MulC is divisible by (1 << ShAmtC):
+      // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC)
+      if (MulC->eq(NewMulC.shl(ShAmtC))) {
+        auto *NewMul =
+            BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
+        BinaryOperator *OrigMul = cast<BinaryOperator>(Op0);
+        NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap());
+        return NewMul;
+      }
+    }
+  }
 
   // Try to narrow a bswap:
   // (bswap (zext X)) >> C --> zext (bswap X >> C')
diff --git a/llvm/test/Transforms/InstCombine/shift-logic.ll b/llvm/test/Transforms/InstCombine/shift-logic.ll
--- a/llvm/test/Transforms/InstCombine/shift-logic.ll
+++ b/llvm/test/Transforms/InstCombine/shift-logic.ll
@@ -259,9 +259,8 @@
 
 define i64 @lshr_mul(i64 %0) {
 ; CHECK-LABEL: @lshr_mul(
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0:%.*]], 52
-; CHECK-NEXT:    [[TMP3:%.*]] = lshr exact i64 [[TMP2]], 2
-; CHECK-NEXT:    ret i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw i64 [[TMP0:%.*]], 13
+; CHECK-NEXT:    ret i64 [[TMP2]]
 ;
   %2 = mul nuw i64 %0, 52
   %3 = lshr i64 %2, 2
@@ -270,9 +269,8 @@
 
 define i64 @lshr_mul_nuw_nsw(i64 %0) {
 ; CHECK-LABEL: @lshr_mul_nuw_nsw(
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i64 [[TMP0:%.*]], 52
-; CHECK-NEXT:    [[TMP3:%.*]] = lshr exact i64 [[TMP2]], 2
-; CHECK-NEXT:    ret i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i64 [[TMP0:%.*]], 13
+; CHECK-NEXT:    ret i64 [[TMP2]]
 ;
   %2 = mul nuw nsw i64 %0, 52
   %3 = lshr i64 %2, 2
@@ -281,9 +279,8 @@
 
 define <4 x i32> @lshr_mul_vector(<4 x i32> %0) {
 ; CHECK-LABEL: @lshr_mul_vector(
-; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw <4 x i32> [[TMP0:%.*]], <i32 52, i32 52, i32 52, i32 52>
-; CHECK-NEXT:    [[TMP3:%.*]] = lshr exact <4 x i32> [[TMP2]], <i32 2, i32 2, i32 2, i32 2>
-; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw <4 x i32> [[TMP0:%.*]], <i32 13, i32 13, i32 13, i32 13>
+; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %2 = mul nuw <4 x i32> %0, <i32 52, i32 52, i32 52, i32 52>
   %3 = lshr <4 x i32> %2, <i32 2, i32 2, i32 2, i32 2>
@@ -324,3 +321,14 @@
   %3 = lshr i64 %2, 2
   ret i64 %3
 }
+
+define i64 @lshr_mul_negative_nsw(i64 %0) {
+; CHECK-LABEL: @lshr_mul_negative_nsw(
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw i64 [[TMP0:%.*]], 52
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr exact i64 [[TMP2]], 2
+; CHECK-NEXT:    ret i64 [[TMP3]]
+;
+  %2 = mul nsw i64 %0, 52
+  %3 = lshr i64 %2, 2
+  ret i64 %3
+}
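
Worked example (illustrative only, not taken from the patch; the function names @src and @tgt are made up): because 52 is divisible by 1 << 2, the nuw multiply absorbs the shift.

  ; before the fold
  define i64 @src(i64 %x) {
    %m = mul nuw i64 %x, 52     ; 52 == 13 * (1 << 2)
    %r = lshr i64 %m, 2
    ret i64 %r
  }

  ; after the fold: lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC)
  define i64 @tgt(i64 %x) {
    %r = mul nuw i64 %x, 13     ; 52 >> 2
    ret i64 %r
  }

The nuw flag is what makes this valid: %m is exactly %x * 52 with no wrapping, so its low two bits are zero and the lshr recovers exactly %x * 13. Without nuw the fold would be wrong; for example in i8, 128 * 4 wraps to 0 and 0 lshr 2 is 0, yet 128 * 1 is 128. That is why @lshr_mul_negative_nsw above must keep both the mul and the lshr.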