diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -385,8 +385,26 @@ Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } - case Instruction::ZExt: case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not be greater than the number of undemanded high result bits. + if (C->ult(I->getType()->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, I->getType()); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); diff --git a/llvm/test/Transforms/InstCombine/trunc-demand.ll b/llvm/test/Transforms/InstCombine/trunc-demand.ll --- a/llvm/test/Transforms/InstCombine/trunc-demand.ll +++ b/llvm/test/Transforms/InstCombine/trunc-demand.ll @@ -6,9 +6,9 @@ define i6 @trunc_lshr(i8 %x) { ; CHECK-LABEL: @trunc_lshr( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = and i6 [[T]], 14 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 2 +; CHECK-NEXT: [[R:%.*]] = and i6 [[TMP2]], 14 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 2 @@ -17,12 +17,13 @@ ret 
i6 %r } +; The 'and' is eliminated. + define i6 @trunc_lshr_exact_mask(i8 %x) { ; CHECK-LABEL: @trunc_lshr_exact_mask( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = and i6 [[T]], 15 -; CHECK-NEXT: ret i6 [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 2 +; CHECK-NEXT: ret i6 [[TMP2]] ; %s = lshr i8 %x, 2 %t = trunc i8 %s to i6 @@ -30,6 +31,8 @@ ret i6 %r } +; negative test - a high bit of x is in the result + define i6 @trunc_lshr_big_mask(i8 %x) { ; CHECK-LABEL: @trunc_lshr_big_mask( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -43,6 +46,8 @@ ret i6 %r } +; negative test - too many uses + define i6 @trunc_lshr_use1(i8 %x) { ; CHECK-LABEL: @trunc_lshr_use1( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -58,6 +63,8 @@ ret i6 %r } +; negative test - too many uses + define i6 @trunc_lshr_use2(i8 %x) { ; CHECK-LABEL: @trunc_lshr_use2( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -73,11 +80,13 @@ ret i6 %r } +; Splat vectors are ok. + define <2 x i7> @trunc_lshr_vec_splat(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_vec_splat( -; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], -; CHECK-NEXT: [[T:%.*]] = trunc <2 x i16> [[S]] to <2 x i7> -; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[T]], +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i16> [[X:%.*]] to <2 x i7> +; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i7> [[TMP1]], +; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[TMP2]], ; CHECK-NEXT: ret <2 x i7> [[R]] ; %s = lshr <2 x i16> %x, @@ -86,12 +95,13 @@ ret <2 x i7> %r } +; The 'and' is eliminated. 
+ define <2 x i7> @trunc_lshr_vec_splat_exact_mask(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_vec_splat_exact_mask( -; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], -; CHECK-NEXT: [[T:%.*]] = trunc <2 x i16> [[S]] to <2 x i7> -; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[T]], -; CHECK-NEXT: ret <2 x i7> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i16> [[X:%.*]] to <2 x i7> +; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i7> [[TMP1]], +; CHECK-NEXT: ret <2 x i7> [[TMP2]] ; %s = lshr <2 x i16> %x, %t = trunc <2 x i16> %s to <2 x i7> @@ -99,6 +109,8 @@ ret <2 x i7> %r } +; negative test - the shift is too big for the narrow type + define <2 x i7> @trunc_lshr_big_shift(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_big_shift( ; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], @@ -112,11 +124,13 @@ ret <2 x i7> %r } +; High bits could also be set rather than cleared. + define i6 @or_trunc_lshr(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 1 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = or i6 [[T]], -32 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 1 +; CHECK-NEXT: [[R:%.*]] = or i6 [[TMP2]], -32 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 1 @@ -127,9 +141,9 @@ define i6 @or_trunc_lshr_more(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr_more( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 4 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = or i6 [[T]], -4 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 4 +; CHECK-NEXT: [[R:%.*]] = or i6 [[TMP2]], -4 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 4 @@ -138,6 +152,8 @@ ret i6 %r } +; negative test - need all high bits to be undemanded + define i6 @or_trunc_lshr_small_mask(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr_small_mask( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 4