diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -659,13 +659,36 @@ continue; for (Value *M : llvm::make_range(ECs.member_begin(I), ECs.member_end())) { - if (!isa(M)) + auto *MI = dyn_cast(M); + if (!MI) continue; Type *Ty = M->getType(); if (Roots.count(M)) - Ty = cast(M)->getOperand(0)->getType(); - if (MinBW < Ty->getScalarSizeInBits()) - MinBWs[cast(M)] = MinBW; + Ty = MI->getOperand(0)->getType(); + + if (MinBW >= Ty->getScalarSizeInBits()) + continue; + + // Check if any of M's operands demand more bits than MinBW or if it is a + // constant that cannot be safely truncated to MinBW. In any of those cases, M + // cannot be performed safely in MinBW. + if (any_of(MI->operands(), [&DB, MinBW](Use &U) { + if (auto *CI = dyn_cast(U)) { + APInt I = CI->getValue(); + // The shift would result in poison when narrowed to MinBW bits. + if (isa(U.getUser()) && + I.uge(MinBW)) + return true; + + if (I.trunc(MinBW).zext(I.getBitWidth()) == I) + return false; + } + uint64_t BW = bit_width(DB.getDemandedBits(&U).getZExtValue()); + return bit_ceil(BW) > MinBW; + })) + continue; + + MinBWs[MI] = MinBW; } } diff --git a/llvm/test/Transforms/LoopVectorize/trunc-shifts.ll b/llvm/test/Transforms/LoopVectorize/trunc-shifts.ll --- a/llvm/test/Transforms/LoopVectorize/trunc-shifts.ll +++ b/llvm/test/Transforms/LoopVectorize/trunc-shifts.ll @@ -17,17 +17,15 @@ ; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[OFFSET_IDX:%.*]] = trunc i32 [[INDEX]] to i8 ; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[OFFSET_IDX]], 0 -; CHECK-NEXT: [[TMP1:%.*]] = trunc <4 x i32> [[BROADCAST_SPLAT]] to <4 x i8> -; CHECK-NEXT: [[TMP2:%.*]] = lshr <4 x i8> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = trunc <4 x i32> [[TMP3]] to <4 x i8> -; CHECK-NEXT: [[TMP5:%.*]] = zext i8 [[TMP0]] to i64 -; CHECK-NEXT: 
[[TMP6:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP5]] -; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[TMP6]], i32 0 -; CHECK-NEXT: store <4 x i8> [[TMP4]], ptr [[TMP7]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = lshr <4 x i32> [[BROADCAST_SPLAT]], +; CHECK-NEXT: [[TMP2:%.*]] = trunc <4 x i32> [[TMP1]] to <4 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = zext i8 [[TMP0]] to i64 +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0 +; CHECK-NEXT: store <4 x i8> [[TMP2]], ptr [[TMP5]], align 8 ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4 -; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 -; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] +; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 +; CHECK-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] ; CHECK: middle.block: ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 100, 100 ; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]] @@ -80,17 +78,15 @@ ; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[OFFSET_IDX:%.*]] = trunc i32 [[INDEX]] to i8 ; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[OFFSET_IDX]], 0 -; CHECK-NEXT: [[TMP1:%.*]] = trunc <4 x i32> [[BROADCAST_SPLAT]] to <4 x i8> -; CHECK-NEXT: [[TMP2:%.*]] = shl <4 x i8> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = trunc <4 x i32> [[TMP3]] to <4 x i8> -; CHECK-NEXT: [[TMP5:%.*]] = zext i8 [[TMP0]] to i64 -; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP5]] -; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[TMP6]], i32 0 -; CHECK-NEXT: store <4 x i8> [[TMP4]], ptr [[TMP7]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = shl <4 x i32> 
[[BROADCAST_SPLAT]], +; CHECK-NEXT: [[TMP2:%.*]] = trunc <4 x i32> [[TMP1]] to <4 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = zext i8 [[TMP0]] to i64 +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i32 0 +; CHECK-NEXT: store <4 x i8> [[TMP2]], ptr [[TMP5]], align 8 ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4 -; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 -; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] +; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 +; CHECK-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]] ; CHECK: middle.block: ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 100, 100 ; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]] @@ -143,17 +139,15 @@ ; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] ; CHECK-NEXT: [[OFFSET_IDX:%.*]] = trunc i32 [[INDEX]] to i8 ; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[OFFSET_IDX]], 0 -; CHECK-NEXT: [[TMP1:%.*]] = trunc <4 x i32> [[BROADCAST_SPLAT]] to <4 x i8> -; CHECK-NEXT: [[TMP2:%.*]] = ashr <4 x i8> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[TMP2]] to <4 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = trunc <4 x i32> [[TMP3]] to <4 x i8> -; CHECK-NEXT: [[TMP5:%.*]] = zext i8 [[TMP0]] to i64 -; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP5]] -; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[TMP6]], i32 0 -; CHECK-NEXT: store <4 x i8> [[TMP4]], ptr [[TMP7]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = ashr <4 x i32> [[BROADCAST_SPLAT]], +; CHECK-NEXT: [[TMP2:%.*]] = trunc <4 x i32> [[TMP1]] to <4 x i8> +; CHECK-NEXT: [[TMP3:%.*]] = zext i8 [[TMP0]] to i64 +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr [[DST]], i64 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = 
getelementptr inbounds i8, ptr [[TMP4]], i32 0 +; CHECK-NEXT: store <4 x i8> [[TMP2]], ptr [[TMP5]], align 8 ; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4 -; CHECK-NEXT: [[TMP8:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 -; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]] +; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i32 [[INDEX_NEXT]], 100 +; CHECK-NEXT: br i1 [[TMP6]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]] ; CHECK: middle.block: ; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i32 100, 100 ; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]