diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -637,6 +637,14 @@
     break;
   }
   case Instruction::AShr: {
+    unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+
+    // If we only want bits that already match the signbit then we don't need
+    // to shift.
+    unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros();
+    if (SignBits >= NumHiDemandedBits)
+      return I->getOperand(0);
+
     // If this is an arithmetic shift right and only the low-bit is set, we can
     // always convert this into a logical shr, even if the shift amount is
     // variable. The low bit of the shift cannot be an input sign bit unless
@@ -648,11 +656,6 @@
       return InsertNewInstWith(NewVal, *I);
     }
 
-    // If the sign bit is the only bit demanded by this ashr, then there is no
-    // need to do it, the shift doesn't change the high bit.
-    if (DemandedMask.isSignMask())
-      return I->getOperand(0);
-
     const APInt *SA;
     if (match(I->getOperand(1), m_APInt(SA))) {
       uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
@@ -672,8 +675,6 @@
       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
         return I;
 
-      unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
-
       assert(!Known.hasConflict() && "Bits known to be one AND zero?");
       // Compute the new bits that are at the top now plus sign bits.
       APInt HighBits(APInt::getHighBitsSet(
diff --git a/llvm/test/Transforms/InstCombine/ashr-demand.ll b/llvm/test/Transforms/InstCombine/ashr-demand.ll
--- a/llvm/test/Transforms/InstCombine/ashr-demand.ll
+++ b/llvm/test/Transforms/InstCombine/ashr-demand.ll
@@ -5,9 +5,8 @@
 
 define i32 @srem2_ashr_mask(i32 %a0) {
 ; CHECK-LABEL: @srem2_ashr_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A0:%.*]], -2147483647
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp eq i32 [[TMP1]], -2147483647
-; CHECK-NEXT:    [[MASK:%.*]] = select i1 [[ISNEG]], i32 2, i32 0
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], 2
 ; CHECK-NEXT:    ret i32 [[MASK]]
 ;
   %srem = srem i32 %a0, 2 ; result = (1,0,-1) num signbits = 31
@@ -31,9 +30,8 @@
 
 define <2 x i32> @srem2_ashr_mask_vector(<2 x i32> %a0) {
 ; CHECK-LABEL: @srem2_ashr_mask_vector(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[A0:%.*]],
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp eq <2 x i32> [[TMP1]],
-; CHECK-NEXT:    [[MASK:%.*]] = select <2 x i1> [[ISNEG]], <2 x i32> , <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]],
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SREM]],
 ; CHECK-NEXT:    ret <2 x i32> [[MASK]]
 ;
   %srem = srem <2 x i32> %a0,
@@ -45,8 +43,7 @@
 define <2 x i32> @srem2_ashr_mask_vector_nonconstant(<2 x i32> %a0, <2 x i32> %a1) {
 ; CHECK-LABEL: @srem2_ashr_mask_vector_nonconstant(
 ; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]],
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr <2 x i32> [[SREM]], [[A1:%.*]]
-; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[ASHR]],
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SREM]],
 ; CHECK-NEXT:    ret <2 x i32> [[MASK]]
 ;
   %srem = srem <2 x i32> %a0,
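
Note on the new fold (illustration only, not part of the patch): `ashr X, ShAmt` is replaced by `X` whenever `ComputeNumSignBits(X)` is at least the number of demanded high bits, because every demanded bit is then already a copy of the sign bit and an arithmetic shift right cannot change it. The standalone C++ sketch below checks that equivalence for the value range of `srem X, 2`, mirroring the scalar srem2_ashr_mask test; `numSignBits` here is a hand-rolled stand-in for ComputeNumSignBits, and the sketch assumes `>>` on a negative int32_t is an arithmetic shift (guaranteed since C++20, the usual behaviour before that).

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Number of high bits of x that are copies of the sign bit (sign bit included).
static unsigned numSignBits(int32_t x) {
  unsigned N = 1;
  while (N < 32 &&
         (((uint32_t)x >> (31 - N)) & 1) == ((uint32_t)x >> 31))
    ++N;
  return N;
}

int main() {
  const int32_t Mask = 2;                    // DemandedMask: only bit 1 is demanded
  const unsigned NumHiDemandedBits = 32 - 1; // BitWidth - countTrailingZeros(Mask)

  // srem X, 2 only produces {-1, 0, 1}, i.e. at least 31 sign bits.
  for (int32_t X : {-1, 0, 1}) {
    assert(numSignBits(X) >= NumHiDemandedBits);
    // Every demanded bit is already a sign-bit copy, so on the demanded bits
    // the ashr is a no-op for any shift amount.
    for (int S = 0; S < 32; ++S)
      assert(((X >> S) & Mask) == (X & Mask));
  }
  return 0;
}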