diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -851,6 +851,27 @@
     // TODO: Mask high bits with 'and'.
   }
 
+  // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
+  if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_APInt(C))))) {
+    Type *AType = A->getType();
+    // The outer trunc keeps bits [C, C + DestWidth) of the narrow shift
+    // operand. Shifting the wide value A instead is only equivalent when that
+    // whole range lies inside the narrow operand: C + DestWidth <= ShiftWidth.
+    // (E.g. i64 -> i12 -> i8 with C == 8 must NOT be folded.)
+    unsigned ShiftWidth = Src->getType()->getScalarSizeInBits();
+    unsigned MaxShiftAmt = ShiftWidth - DestWidth;
+
+    // If the shift is small enough, all zero/sign bits created by the shift are
+    // removed by the trunc.
+    if (C->ule(MaxShiftAmt)) {
+      auto *ShAmt = ConstantInt::get(AType, C->getZExtValue());
+      Value *Shift = cast<Instruction>(Src)->getOpcode() == Instruction::AShr
+                         ? Builder.CreateAShr(A, ShAmt)
+                         : Builder.CreateLShr(A, ShAmt);
+      return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+    }
+  }
+
   if (Instruction *I = narrowBinOp(Trunc))
     return I;
 
diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
--- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
@@ -9,9 +9,8 @@
 
 define i8 @trunc_lshr_trunc(i64 %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc(
-; CHECK-NEXT: [[B:%.*]] = trunc i64 [[A:%.*]] to i32
-; CHECK-NEXT: [[C:%.*]] = lshr i32 [[B]], 8
-; CHECK-NEXT: [[D:%.*]] = trunc i32 [[C]] to i8
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A:%.*]], 8
+; CHECK-NEXT: [[D:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT: ret i8 [[D]]
 ;
   %b = trunc i64 %a to i32
@@ -22,9 +21,8 @@
 
 define <2 x i8> @trunc_lshr_trunc_uniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_uniform(
-; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 8>
-; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 8, i64 8>
+; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[TMP1]] to <2 x i8>
 ; CHECK-NEXT: ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
@@ -61,9 +59,8 @@
 
 define i8 @trunc_ashr_trunc(i64 %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc(
-; CHECK-NEXT: [[B:%.*]] = trunc i64 [[A:%.*]] to i32
-; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[B]], 8
-; CHECK-NEXT: [[D:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A:%.*]], 8
+; CHECK-NEXT: [[D:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT: ret i8 [[D]]
 ;
   %b = trunc i64 %a to i32
@@ -74,9 +71,8 @@
 
 define <2 x i8> @trunc_ashr_trunc_uniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_uniform(
-; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 8>
-; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[TMP1]] to <2 x i8>
+; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 8, i64 8>
+; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[TMP1]] to <2 x i8>
 ; CHECK-NEXT: ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>