Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -587,15 +587,18 @@
   // the sign bit of the original value; performing ashr instead of lshr
   // generates bits of the same value as the sign bit.
   if (Src->hasOneUse() &&
-      match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst))) &&
-      cast<Instruction>(Src)->getOperand(0)->hasOneUse()) {
+      match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) {
+    Value *SExt = cast<Instruction>(Src)->getOperand(0);
+    const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits();
     const unsigned ASize = A->getType()->getPrimitiveSizeInBits();
+    unsigned ShiftAmt = Cst->getZExtValue();
     // This optimization can be only performed when zero bits generated by
     // the original lshr aren't pulled into the value after truncation, so we
-    // can only shift by values smaller than the size of destination type (in
-    // bits).
-    if (Cst->getValue().ult(ASize)) {
-      Value *Shift = Builder->CreateAShr(A, Cst->getZExtValue());
+    // can only shift by values no larger than the number of extension bits.
+    if (SExt->hasOneUse() && ShiftAmt <= SExtSize - ASize) {
+      // If shifting by the size of the original value in bits or more, every
+      // result bit is a copy of the sign bit; clamp to ASize - 1 to avoid UB.
+      Value *Shift = Builder->CreateAShr(A, std::min(ShiftAmt, ASize - 1));
       Shift->takeName(Src);
       return CastInst::CreateIntegerCast(Shift, CI.getType(), true);
     }
Index: test/Transforms/InstCombine/cast.ll
===================================================================
--- test/Transforms/InstCombine/cast.ll
+++ test/Transforms/InstCombine/cast.ll
@@ -1432,3 +1432,38 @@
   %tmp6 = bitcast <4 x half> <half undef, half undef, half undef, half 0xH3C00> to <2 x i32>
   ret <2 x i32> %tmp6
 }
+
+; Do not optimize to ashr i64 (shift by 48 > 96 - 64)
+define i64 @test91(i64 %A) {
+; CHECK-LABEL: @test91(
+; CHECK-NEXT: [[B:%.*]] = sext i64 %A to i96
+; CHECK-NEXT: [[C:%.*]] = lshr i96 [[B]], 48
+; CHECK-NEXT: [[D:%.*]] = trunc i96 [[C]] to i64
+; CHECK-NEXT: ret i64 [[D]]
+  %B = sext i64 %A to i96
+  %C = lshr i96 %B, 48
+  %D = trunc i96 %C to i64
+  ret i64 %D
+}
+
+; Do optimize to ashr i64 (shift by 32 <= 96 - 64)
+define i64 @test92(i64 %A) {
+; CHECK-LABEL: @test92(
+; CHECK-NEXT: [[C:%.*]] = ashr i64 %A, 32
+; CHECK-NEXT: ret i64 [[C]]
+  %B = sext i64 %A to i96
+  %C = lshr i96 %B, 32
+  %D = trunc i96 %C to i64
+  ret i64 %D
+}
+
+; When optimizing to ashr i32, don't shift by more than 31.
+define i32 @test93(i32 %A) {
+; CHECK-LABEL: @test93(
+; CHECK-NEXT: [[C:%.*]] = ashr i32 %A, 31
+; CHECK-NEXT: ret i32 [[C]]
+  %B = sext i32 %A to i96
+  %C = lshr i96 %B, 64
+  %D = trunc i96 %C to i32
+  ret i32 %D
+}
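
For reference, below is a small standalone C++ model (not part of the patch) of the equivalence the new bound relies on: for ShiftAmt <= SExtSize - ASize, trunc(lshr(sext A, ShiftAmt)) equals ashr(A, min(ShiftAmt, ASize - 1)). The helper names viaSextLshrTrunc/viaAshr, the i32/i96 width choice, and the use of unsigned __int128 as a stand-in for i96 are illustrative assumptions for this sketch, not LLVM API; it also assumes the usual two's-complement narrowing and arithmetic >> on negative ints that mainstream compilers provide.

// Sketch only: models the rewrite for i32 sign-extended to i96, using
// unsigned __int128 (a GCC/Clang extension) to hold the 96-bit value.
#include <algorithm>
#include <cassert>
#include <cstdint>

// The original pattern: lshr(sext A to i96, Shift), truncated back to i32.
int32_t viaSextLshrTrunc(int32_t A, unsigned Shift) {
  unsigned __int128 Wide = (unsigned __int128)(__int128)A; // sext to 128 bits
  Wide &= ((unsigned __int128)1 << 96) - 1;                // keep the low 96
  return (int32_t)(uint32_t)(Wide >> Shift);               // lshr, then trunc
}

// The replacement: ashr A, min(Shift, 31). Shifting an i32 by 32 or more is
// UB in C++ (and poison for LLVM's ashr), hence the clamp the patch adds.
int32_t viaAshr(int32_t A, unsigned Shift) {
  return A >> std::min(Shift, 31u);
}

int main() {
  // The transform only fires for Shift <= SExtSize - ASize (96 - 32 = 64 in
  // this model); above that bound, zeros produced by the lshr would survive
  // truncation, which is why @test91 (shift by 48 with i64) is rejected.
  for (int32_t A : {0, 1, -1, INT32_MIN, INT32_MAX, -123456})
    for (unsigned Shift : {0u, 7u, 31u, 32u, 48u, 64u})
      assert(viaSextLshrTrunc(A, Shift) == viaAshr(A, Shift));
  return 0;
}

The interesting region is ShiftAmt in [ASize, SExtSize - ASize]: there every truncated bit is a copy of A's sign bit, so clamping the shift to ASize - 1 preserves the result while keeping the ashr well defined, exactly what @test93 checks.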