Index: lib/Analysis/ValueTracking.cpp
===================================================================
--- lib/Analysis/ValueTracking.cpp
+++ lib/Analysis/ValueTracking.cpp
@@ -2196,6 +2196,18 @@
     if (Tmp == 1) return 1;  // Early out.
     return std::min(Tmp, Tmp2)-1;
 
+  case Instruction::Mul: {
+    // The output of the Mul can be at most twice the valid bits in the inputs.
+    // An operand with S sign bits has TyBits - S + 1 valid (significant) bits,
+    // so the product needs at most the sum of both operands' valid bits.
+    Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
+    if (Tmp == 1) return 1;  // Early out.
+    Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
+    if (Tmp2 == 1) return 1;
+    unsigned OutValidBits = (TyBits - Tmp + 1) + (TyBits - Tmp2 + 1);
+    return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
+  }
+
   case Instruction::PHI: {
     const PHINode *PN = cast<PHINode>(U);
     unsigned NumIncomingValues = PN->getNumIncomingValues();
Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -369,6 +369,21 @@
       }
     }
     break;
+  case Instruction::AShr:
+    // If this is a truncate of an arithmetic shr, we can truncate it to a
+    // smaller ashr iff we know that all the bits from the sign bit of the
+    // original type and the sign bit of the truncate type are similar.
+    // TODO: It is enough to check that the bits we would be shifting in are
+    //       similar to sign bit of the truncate type.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+      uint32_t BitWidth = Ty->getScalarSizeInBits();
+      if (CI->getLimitedValue(BitWidth) < BitWidth &&
+          OrigBitWidth - BitWidth <
+              IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
+        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
+    }
+    break;
   case Instruction::Trunc:
     // trunc(trunc(x)) -> trunc(x)
     return true;
Index: lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -465,7 +465,7 @@
   case Instruction::LShr: {
     const APInt *SA;
     if (match(I->getOperand(1), m_APInt(SA))) {
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
+      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
 
       // Unsigned shift right.
       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
@@ -521,9 +521,12 @@
       if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
         return I;
 
+      unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+
       assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-      // Compute the new bits that are at the top now.
-      APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
+      // Compute the new bits that are at the top now plus sign bits.
+      APInt HighBits(APInt::getHighBitsSet(
+          BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth)));
       Known.Zero.lshrInPlace(ShiftAmt);
       Known.One.lshrInPlace(ShiftAmt);
 
Index: test/Transforms/InstCombine/trunc.ll
===================================================================
--- test/Transforms/InstCombine/trunc.ll
+++ test/Transforms/InstCombine/trunc.ll
@@ -89,6 +89,21 @@
   ret i32 %D
 }
 
+define i16 @test6_ashr_mul(i8 %X, i8 %Y) {
+; CHECK-LABEL: @test6_ashr_mul(
+; CHECK-NEXT:    [[A:%.*]] = sext i8 %X to i16
+; CHECK-NEXT:    [[B:%.*]] = sext i8 %Y to i16
+; CHECK-NEXT:    [[C:%.*]] = mul nsw i16 [[A]], [[B]]
+; CHECK-NEXT:    [[D:%.*]] = ashr i16 [[C]], 15
+; CHECK-NEXT:    ret i16 [[D]]
+  %A = sext i8 %X to i32
+  %B = sext i8 %Y to i32
+  %C = mul i32 %A, %B
+  %D = ashr i32 %C, 15
+  %E = trunc i32 %D to i16
+  ret i16 %E
+}
+
 define i92 @test7(i64 %A) {
 ; CHECK-LABEL: @test7(
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 %A, 32