Index: llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -439,20 +439,89 @@
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
 
+  // Returns true if I carries the no-wrap flag that matches the extension of
+  // the index: nsw for a sext, nuw for a zext.
+  auto CheckFlags = [](Instruction *I, bool Signed) {
+    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+    return (Signed && BinOpI->hasNoSignedWrap()) ||
+           (!Signed && BinOpI->hasNoUnsignedWrap());
+  };
+
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
-      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
-    if (Signed)
-      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
-    else
-      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+      CheckFlags(OpB, Signed))
+    Safe = true;
+
+  // Second attempt: if ValA and OpB are both adds with NSW/NUW that share the
+  // same LHS operand, the transformation is safe if we can prove that the RHS
+  // of OpB is the RHS of ValA plus IdxDiff, without overflow.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  //  Both %tmp7 and %tmp12 have the nsw flag and the first operand %tmp2.
+  //  It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11
+  //  adds 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
+  //
+  // With the matching no-wrap flag, ext(a + b) == ext(a) + ext(b) holds
+  // exactly, where ext is sext for nsw and zext for nuw. Each pattern below
+  // thus reduces ext(OpB) - ext(ValA) to an expression over the add
+  // constants, which is evaluated exactly and compared with IdxDiff. Note
+  // that under nuw a constant such as -1 stands for 2^N - 1, not -1.
+  Instruction *AddA = dyn_cast<Instruction>(ValA);
+  if (!Safe && AddA && AddA->getOpcode() == Instruction::Add &&
+      OpB->getOpcode() == Instruction::Add &&
+      AddA->getOperand(0) == OpB->getOperand(0) && CheckFlags(AddA, Signed) &&
+      CheckFlags(OpB, Signed)) {
+    // One spare bit over both widths keeps the extended constants, their
+    // negations and their differences from wrapping.
+    unsigned ValWidth = ValA->getType()->getScalarSizeInBits();
+    unsigned CmpWidth = std::max(IdxDiff.getBitWidth(), ValWidth) + 1;
+    APInt WideIdxDiff = IdxDiff.sext(CmpWidth);
+    // Extends the ConstantInt C to CmpWidth bits the way the index is.
+    auto ExtCst = [&](Value *C) {
+      const APInt &CV = cast<ConstantInt>(C)->getValue();
+      return Signed ? CV.sext(CmpWidth) : CV.zext(CmpWidth);
+    };
+    // Returns V if it is an add with the matching no-wrap flag and a
+    // constant RHS, nullptr otherwise.
+    auto MatchAddCst = [&](Value *V) -> Instruction * {
+      auto *I = dyn_cast<Instruction>(V);
+      if (I && I->getOpcode() == Instruction::Add && CheckFlags(I, Signed) &&
+          isa<ConstantInt>(I->getOperand(1)))
+        return I;
+      return nullptr;
+    };
+    Value *RHSA = AddA->getOperand(1);
+    Value *RHSB = OpB->getOperand(1);
+    Instruction *InnerA = MatchAddCst(RHSA);
+    Instruction *InnerB = MatchAddCst(RHSB);
+    // Match `x + y` and `x + (y + c)`: ext(OpB) - ext(ValA) == ext(c).
+    if (InnerB && InnerB->getOperand(0) == RHSA &&
+        ExtCst(InnerB->getOperand(1)) == WideIdxDiff)
+      Safe = true;
+    // Match `x + (y + c)` and `x + y`: ext(OpB) - ext(ValA) == -ext(c).
+    if (InnerA && InnerA->getOperand(0) == RHSB &&
+        -ExtCst(InnerA->getOperand(1)) == WideIdxDiff)
+      Safe = true;
+    // Match `x + (y + cA)` and `x + (y + cB)`:
+    // ext(OpB) - ext(ValA) == ext(cB) - ext(cA).
+    if (InnerA && InnerB && InnerA->getOperand(0) == InnerB->getOperand(0) &&
+        ExtCst(InnerB->getOperand(1)) - ExtCst(InnerA->getOperand(1)) ==
+            WideIdxDiff)
+      Safe = true;
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
 
-  // Second attempt:
+  // Third attempt:
   // If all set bits of IdxDiff or any higher order bit other than the sign bit
   // are known to be zero in ValA, we can add Diff to it while guaranteeing no
   // overflow of any sort.
Index: llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
@@ -0,0 +1,202 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -o - -S -load-store-vectorizer -dce %s | FileCheck %s
+
+; Make sure LoadStoreVectorizer vectorizes the loads below.
+; In order to prove that the vectorization is safe, it tries to
+; match nested adds and find an expression that adds a constant
+; value to an existing index and the result doesn't overflow.
+
+target triple = "x86_64--"
+
+define void @ld_v4i8_add_nsw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nsw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nsw i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add nsw i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add nsw i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Same as above with nuw and zext. All the constants are non-negative, so
+; the zero-extended indices are provably consecutive.
+define void @ld_v4i8_add_nuw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nuw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V1:%.*]], [[V0:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[TMP]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP2]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP31:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP31]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v1, %v0
+  %tmp1 = zext i32 %tmp to i64
+  %tmp2 = getelementptr inbounds i8, i8* %src, i64 %tmp1
+  %tmp3 = load i8, i8* %tmp2, align 1
+  %tmp4 = add nuw i32 %v0, 1
+  %tmp5 = add nuw i32 %v1, %tmp4
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nuw i32 %v0, 2
+  %tmp10 = add nuw i32 %v1, %tmp9
+  %tmp11 = zext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nuw i32 %v0, 3
+  %tmp15 = add nuw i32 %v1, %tmp14
+  %tmp16 = zext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp3, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Make sure we don't vectorize the loads below. Under nuw the -1 is the
+; unsigned value 4294967295, so %tmp1 can only be 4294967295 (when %v0 and
+; %v1 are 0) and %tmp5 is then 0: the zero-extended indices are not adjacent.
+define void @ld_v2i8_add_nuw_negative_cst(i32 %v0, i32 %v1, i8* %src, <2 x i8>* %dst) {
+; CHECK-LABEL: @ld_v2i8_add_nuw_negative_cst(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = load i8, i8* [[TMP3]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = add nuw i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load i8, i8* [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <2 x i8> undef, i8 [[TMP4]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x i8> [[TMP9]], i8 [[TMP8]], i32 1
+; CHECK-NEXT:    store <2 x i8> [[TMP10]], <2 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v0, -1
+  %tmp1 = add nuw i32 %v1, %tmp
+  %tmp2 = zext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nuw i32 %v1, %v0
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = insertelement <2 x i8> undef, i8 %tmp4, i32 0
+  %tmp10 = insertelement <2 x i8> %tmp9, i8 %tmp8, i32 1
+  store <2 x i8> %tmp10, <2 x i8>* %dst
+  ret void
+}
+
+; Make sure we don't vectorize the loads below because the adds feeding the
+; sext instructions don't have the nsw flag.
+
+define void @ld_v4i8_add_not_safe(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_not_safe(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = load i8, i8* [[TMP3]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load i8, i8* [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw i32 [[V0]], 1
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[V1]], [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP10]] to i64
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = load i8, i8* [[TMP12]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i32 [[V0]], 2
+; CHECK-NEXT:    [[TMP15:%.*]] = add i32 [[V1]], [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i32 [[TMP15]] to i64
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, i8* [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP4]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP8]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP13]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP18]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}