diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1526,13 +1526,20 @@ ShAmt); } - // If the input is a trunc from the destination type, then turn sext(trunc(x)) - // into shifts. Value *X; - if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) { - // sext (trunc X) --> ashr (shl X, C), C - Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); - return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt); + if (match(Src, m_Trunc(m_Value(X)))) { + // If the input has more sign bits than bits truncated, then convert + // directly to final type. + unsigned XBitSize = X->getType()->getScalarSizeInBits(); + if (ComputeNumSignBits(X, 0, &CI) > XBitSize - SrcBitSize) + return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true); + + // If input is a trunc from the destination type, then convert into shifts. + if (Src->hasOneUse() && X->getType() == DestTy) { + // sext (trunc X) --> ashr (shl X, C), C + Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShAmt), ShAmt); + } } if (ICmpInst *ICI = dyn_cast(Src)) diff --git a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll --- a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll +++ b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll @@ -13,8 +13,7 @@ ; CHECK-LABEL: @t0( ; CHECK-NEXT: [[A:%.*]] = ashr i8 [[X:%.*]], 5 ; CHECK-NEXT: call void @use8(i8 [[A]]) -; CHECK-NEXT: [[B:%.*]] = trunc i8 [[A]] to i4 -; CHECK-NEXT: [[C:%.*]] = sext i4 [[B]] to i16 +; CHECK-NEXT: [[C:%.*]] = sext i8 [[A]] to i16 ; CHECK-NEXT: ret i16 [[C]] ; %a = ashr i8 %x, 5 @@ -28,8 +27,7 @@ ; CHECK-LABEL: @t1( ; CHECK-NEXT: [[A:%.*]] = ashr i8 [[X:%.*]], 4 ; CHECK-NEXT: call void @use8(i8 [[A]]) -; CHECK-NEXT: [[B:%.*]] = trunc i8 [[A]] to i4 -; CHECK-NEXT: [[C:%.*]] = sext i4 [[B]] to i16 +; CHECK-NEXT: [[C:%.*]] = sext i8 [[A]] to i16 ; CHECK-NEXT: ret i16 [[C]] ; %a = ashr i8 %x, 4 @@ -59,8 +57,7 @@ ; CHECK-LABEL: @t3_vec( ; CHECK-NEXT: [[A:%.*]] = ashr <2 x i8> [[X:%.*]], ; CHECK-NEXT: call void @usevec(<2 x i8> [[A]]) -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i8> [[A]] to <2 x i4> -; CHECK-NEXT: [[C:%.*]] = sext <2 x i4> [[B]] to <2 x i16> +; CHECK-NEXT: [[C:%.*]] = sext <2 x i8> [[A]] to <2 x i16> ; CHECK-NEXT: ret <2 x i16> [[C]] ; %a = ashr <2 x i8> %x, @@ -91,7 +88,7 @@ ; CHECK-NEXT: call void @use8(i8 [[A]]) ; CHECK-NEXT: [[B:%.*]] = trunc i8 [[A]] to i4 ; CHECK-NEXT: call void @use4(i4 [[B]]) -; CHECK-NEXT: [[C:%.*]] = sext i4 [[B]] to i16 +; CHECK-NEXT: [[C:%.*]] = sext i8 [[A]] to i16 ; CHECK-NEXT: ret i16 [[C]] ; %a = ashr i8 %x, 5 @@ -106,8 +103,7 @@ ; CHECK-LABEL: @narrow_source_matching_signbits( ; CHECK-NEXT: [[M:%.*]] = and i32 [[X:%.*]], 7 ; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M]] -; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8 -; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i64 +; CHECK-NEXT: [[C:%.*]] = sext i32 [[A]] to i64 ; CHECK-NEXT: ret i64 [[C]] ; %m = and i32 %x, 7 @@ -117,6 +113,8 @@ ret i64 %c } +; negative test - not enough sign-bits + define i64 @narrow_source_not_matching_signbits(i32 %x) { ; CHECK-LABEL: @narrow_source_not_matching_signbits( ; CHECK-NEXT: [[M:%.*]] = and i32 [[X:%.*]], 8 @@ -136,8 +134,7 @@ ; CHECK-LABEL: @wide_source_matching_signbits( ; CHECK-NEXT: [[M:%.*]] = and i32 [[X:%.*]], 7 ; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M]] -; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8 -; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i24 +; CHECK-NEXT: [[C:%.*]] = trunc i32 [[A]] to i24 ; CHECK-NEXT: ret i24 [[C]] ; %m = and i32 %x, 7 @@ -147,6 +144,8 @@ ret i24 %c } +; negative test - not enough sign-bits + define i24 @wide_source_not_matching_signbits(i32 %x) { ; CHECK-LABEL: @wide_source_not_matching_signbits( ; CHECK-NEXT: [[M2:%.*]] = and i32 [[X:%.*]], 8 @@ -165,9 +164,8 @@ define i32 @same_source_matching_signbits(i32 %x) { ; CHECK-LABEL: @same_source_matching_signbits( ; CHECK-NEXT: [[M:%.*]] = and i32 [[X:%.*]], 7 -; CHECK-NEXT: [[TMP1:%.*]] = shl i32 -16777216, [[M]] -; CHECK-NEXT: [[C:%.*]] = ashr exact i32 [[TMP1]], 24 -; CHECK-NEXT: ret i32 [[C]] +; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M]] +; CHECK-NEXT: ret i32 [[A]] ; %m = and i32 %x, 7 %a = shl nsw i32 -1, %m @@ -176,6 +174,8 @@ ret i32 %c } +; negative test - not enough sign-bits + define i32 @same_source_not_matching_signbits(i32 %x) { ; CHECK-LABEL: @same_source_not_matching_signbits( ; CHECK-NEXT: [[M2:%.*]] = and i32 [[X:%.*]], 8 @@ -196,8 +196,7 @@ ; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M]] ; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8 ; CHECK-NEXT: call void @use8(i8 [[B]]) -; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i32 -; CHECK-NEXT: ret i32 [[C]] +; CHECK-NEXT: ret i32 [[A]] ; %m = and i32 %x, 7 %a = shl nsw i32 -1, %m