diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -851,6 +851,28 @@ // TODO: Mask high bits with 'and'. } + // trunc (*shr (trunc A), C) --> trunc(*shr A, C) + Constant *C0; + if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C0))))) { + Type *AType = A->getType(); + unsigned AWidth = AType->getScalarSizeInBits(); + unsigned MaxShiftAmt = std::min(DestWidth, AWidth - DestWidth); + + // If the shift is small enough, all zero/sign bits created by the shift are + // removed by the trunc. + if (match(C0, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { + auto *OldShift = cast<Instruction>(Src); + auto *ShAmt = ConstantExpr::getIntegerCast(C0, AType, true); + bool IsExact = OldShift->isExact(); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } + } + if (Instruction *I = narrowBinOp(Trunc)) return I; diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll --- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll +++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll @@ -9,9 +9,8 @@ define i8 @trunc_lshr_trunc(i64 %a) { ; CHECK-LABEL: @trunc_lshr_trunc( -; CHECK-NEXT: [[B:%.*]] = trunc i64 [[A:%.*]] to i32 -; CHECK-NEXT: [[C:%.*]] = lshr i32 [[B]], 8 -; CHECK-NEXT: [[D:%.*]] = trunc i32 [[C]] to i8 +; CHECK-NEXT: [[C1:%.*]] = lshr i64 [[A:%.*]], 8 +; CHECK-NEXT: [[D:%.*]] = trunc i64 [[C1]] to i8 ; CHECK-NEXT: ret i8 [[D]] ; %b = trunc i64 %a to i32 @@ -22,9 +21,8 @@ define <2 x i8> @trunc_lshr_trunc_uniform(<2 x i64> %a) { ; CHECK-LABEL: @trunc_lshr_trunc_uniform( -; CHECK-NEXT: [[B:%.*]] = 
trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32> @@ -35,9 +33,8 @@ define <2 x i8> @trunc_lshr_trunc_nonuniform(<2 x i64> %a) { ; CHECK-LABEL: @trunc_lshr_trunc_nonuniform( -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32> @@ -48,9 +45,8 @@ define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef( -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32> @@ -87,9 +83,8 @@ define i8 @trunc_ashr_trunc(i64 %a) { ; CHECK-LABEL: @trunc_ashr_trunc( -; CHECK-NEXT: [[B:%.*]] = trunc i64 [[A:%.*]] to i32 -; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[B]], 8 -; CHECK-NEXT: [[D:%.*]] = trunc i32 [[TMP1]] to i8 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A:%.*]], 8 +; CHECK-NEXT: [[D:%.*]] = trunc i64 [[TMP1]] to i8 ; CHECK-NEXT: ret i8 [[D]] ; %b = trunc i64 %a to i32 @@ -98,11 +93,22 @@ ret i8 %d } +define i8 @trunc_ashr_trunc_exact(i64 %a) { +; CHECK-LABEL: @trunc_ashr_trunc_exact( +; CHECK-NEXT: [[TMP1:%.*]] = lshr exact i64 [[A:%.*]], 8 +; CHECK-NEXT: [[D:%.*]] = trunc i64 [[TMP1]] to i8 +; CHECK-NEXT: ret i8 [[D]] +; + %b = trunc i64 %a 
to i32 + %c = ashr exact i32 %b, 8 + %d = trunc i32 %c to i8 + ret i8 %d +} + define <2 x i8> @trunc_ashr_trunc_uniform(<2 x i64> %a) { ; CHECK-LABEL: @trunc_ashr_trunc_uniform( -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[TMP1]] to <2 x i8> +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[TMP1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32> @@ -113,9 +119,8 @@ define <2 x i8> @trunc_ashr_trunc_nonuniform(<2 x i64> %a) { ; CHECK-LABEL: @trunc_ashr_trunc_nonuniform( -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = ashr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32> @@ -126,9 +131,8 @@ define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef( -; CHECK-NEXT: [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32> -; CHECK-NEXT: [[C:%.*]] = ashr <2 x i32> [[B]], -; CHECK-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; %b = trunc <2 x i64> %a to <2 x i32>