diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -836,6 +836,27 @@
     // TODO: Mask high bits with 'and'.
   }
 
+  // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
+  if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) {
+    unsigned MaxShiftAmt = SrcWidth - DestWidth;
+
+    // If the shift is small enough, all zero/sign bits created by the shift are
+    // removed by the trunc.
+    // TODO: Support passing through undef shift amounts - these currently get
+    // zero'd by getIntegerCast.
+    if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+                                    APInt(SrcWidth, MaxShiftAmt)))) {
+      auto *OldShift = cast<Instruction>(Src);
+      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+      bool IsExact = OldShift->isExact();
+      Value *Shift =
+          OldShift->getOpcode() == Instruction::AShr
+              ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
+              : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
+      return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+    }
+  }
+
   if (Instruction *I = narrowBinOp(Trunc))
     return I;
 
diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
--- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
@@ -9,9 +9,8 @@
 
 define i8 @trunc_lshr_trunc(i64 %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc(
-; CHECK-NEXT:    [[B:%.*]] = trunc i64 [[A:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = lshr i32 [[B]], 8
-; CHECK-NEXT:    [[D:%.*]] = trunc i32 [[C]] to i8
+; CHECK-NEXT:    [[C1:%.*]] = lshr i64 [[A:%.*]], 8
+; CHECK-NEXT:    [[D:%.*]] = trunc i64 [[C1]] to i8
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %b = trunc i64 %a to i32
@@ -22,9 +21,8 @@
 
 define <2 x i8> @trunc_lshr_trunc_uniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_uniform(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 8>
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 8, i64 8>
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
@@ -35,9 +33,8 @@
 
 define <2 x i8> @trunc_lshr_trunc_nonuniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_nonuniform(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]],
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]],
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
@@ -48,13 +45,12 @@
 
 define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]],
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 8, i64 0>
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
-  %c = lshr <2 x i32> %b,
+  %c = lshr <2 x i32> %b, <i32 8, i32 undef>
   %d = trunc <2 x i32> %c to <2 x i8>
   ret <2 x i8> %d
 }
@@ -87,9 +83,8 @@
 
 define i8 @trunc_ashr_trunc(i64 %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc(
-; CHECK-NEXT:    [[B:%.*]] = trunc i64 [[A:%.*]] to i32
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[B]], 8
-; CHECK-NEXT:    [[D:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[A:%.*]], 8
+; CHECK-NEXT:    [[D:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %b = trunc i64 %a to i32
@@ -100,9 +95,8 @@
 
 define i8 @trunc_ashr_trunc_exact(i64 %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_exact(
-; CHECK-NEXT:    [[B:%.*]] = trunc i64 [[A:%.*]] to i32
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[B]], 8
-; CHECK-NEXT:    [[D:%.*]] = trunc i32 [[TMP1]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i64 [[A:%.*]], 8
+; CHECK-NEXT:    [[D:%.*]] = trunc i64 [[TMP1]] to i8
 ; CHECK-NEXT:    ret i8 [[D]]
 ;
   %b = trunc i64 %a to i32
@@ -113,9 +107,8 @@
 
 define <2 x i8> @trunc_ashr_trunc_uniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_uniform(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 8>
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[TMP1]] to <2 x i8>
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i64> [[A:%.*]], <i64 8, i64 8>
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[TMP1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
@@ -126,22 +119,20 @@
 
 define <2 x i8> @trunc_ashr_trunc_nonuniform(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_nonuniform(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = ashr <2 x i32> [[B]],
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]],
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>
-  %c = ashr <2 x i32> %b,
+  %c = ashr <2 x i32> %b,
   %d = trunc <2 x i32> %c to <2 x i8>
   ret <2 x i8> %d
 }
 
 define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef(
-; CHECK-NEXT:    [[B:%.*]] = trunc <2 x i64> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = ashr <2 x i32> [[B]],
-; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8>
+; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], <i64 8, i64 0>
+; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
   %b = trunc <2 x i64> %a to <2 x i32>