Index: llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -439,20 +439,78 @@
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
+  auto CheckFlags = [](Instruction *I, bool Signed) {
+    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+    return (Signed && BinOpI->hasNoSignedWrap()) ||
+           (!Signed && BinOpI->hasNoUnsignedWrap());
+  };
+
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
-      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
-    if (Signed)
-      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
-    else
-      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+      CheckFlags(OpB, Signed))
+    Safe = true;
+
+  // Second attempt: If both OpA and OpB are adds with NSW/NUW and with the
+  // same LHS operand, we can guarantee that the transformation is safe if
+  // we can prove that OpA won't overflow when IdxDiff is added to the RHS
+  // of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  // Both %tmp7 and %tmp12 have the nsw flag and share the first operand
+  // %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow because
+  // %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
+  OpA = dyn_cast<Instruction>(ValA);
+  if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
+      OpB->getOpcode() == Instruction::Add &&
+      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
+      CheckFlags(OpB, Signed)) {
+    Value *RHSA = OpA->getOperand(1);
+    Value *RHSB = OpB->getOperand(1);
+    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
+    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
+    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw c)` and
+    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
+        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
+        isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstValA =
+          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      int64_t CstValB =
+          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
+          IdxDiff.getSExtValue() == (CstValB - CstValA))
+        Safe = true;
+    }
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
-  // Second attempt:
+  // Third attempt:
   // If all set bits of IdxDiff or any higher order bit other than the sign bit
   // are known to be zero in ValA, we can add Diff to it while guaranteeing no
   // overflow of any sort.
Index: llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -o - -S -load-store-vectorizer %s | FileCheck %s
+
+; Make sure LoadStoreVectorizer vectorizes the loads below.
+; In order to prove that the vectorization is safe, it tries to
+; match nested adds and find an expression that adds a constant
+; value to an existing index and the result doesn't overflow.
+
+target triple = "x86_64--"
+
+define void @load_v4s8_nested_add(i32 %v0, i32 %v1, i32 %v3, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @load_v4s8_nested_add(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nsw i32 [[V3:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add nsw i32 [[TMP2]], [[TMP]]
+; CHECK-NEXT:    [[TMP4:%.*]] = sext i32 [[TMP3]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP5]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP61:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP102:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP153:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP204:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP7:%.*]] = add nsw i32 [[TMP2]], [[V0]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = add nsw i32 [[V0]], 1
+; CHECK-NEXT:    [[TMP12:%.*]] = add nsw i32 [[TMP2]], [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = sext i32 [[TMP12]] to i64
+; CHECK-NEXT:    [[TMP16:%.*]] = add nsw i32 [[V0]], 2
+; CHECK-NEXT:    [[TMP17:%.*]] = add nsw i32 [[TMP2]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext i32 [[TMP17]] to i64
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> undef, i8 [[TMP61]], i32 0
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP102]], i32 1
+; CHECK-NEXT:    [[TMP23:%.*]] = insertelement <4 x i8> [[TMP22]], i8 [[TMP153]], i32 2
+; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i8> [[TMP23]], i8 [[TMP204]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP24]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, -1
+  %tmp2 = mul nsw i32 %v3, %tmp1
+  %tmp3 = add nsw i32 %tmp2, %tmp
+  %tmp4 = sext i32 %tmp3 to i64
+  %tmp5 = getelementptr inbounds i8, i8* %src, i64 %tmp4
+  %tmp6 = load i8, i8* %tmp5, align 1
+  %tmp7 = add nsw i32 %tmp2, %v0
+  %tmp8 = sext i32 %tmp7 to i64
+  %tmp9 = getelementptr inbounds i8, i8* %src, i64 %tmp8
+  %tmp10 = load i8, i8* %tmp9, align 1
+  %tmp11 = add nsw i32 %v0, 1
+  %tmp12 = add nsw i32 %tmp2, %tmp11
+  %tmp13 = sext i32 %tmp12 to i64
+  %tmp14 = getelementptr inbounds i8, i8* %src, i64 %tmp13
+  %tmp15 = load i8, i8* %tmp14, align 1
+  %tmp16 = add nsw i32 %v0, 2
+  %tmp17 = add nsw i32 %tmp2, %tmp16
+  %tmp18 = sext i32 %tmp17 to i64
+  %tmp19 = getelementptr inbounds i8, i8* %src, i64 %tmp18
+  %tmp20 = load i8, i8* %tmp19, align 1
+  %tmp21 = insertelement <4 x i8> undef, i8 %tmp6, i32 0
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp10, i32 1
+  %tmp23 = insertelement <4 x i8> %tmp22, i8 %tmp15, i32 2
+  %tmp24 = insertelement <4 x i8> %tmp23, i8 %tmp20, i32 3
+  store <4 x i8> %tmp24, <4 x i8>* %dst
+  ret void
+}
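
For illustration only, here is a hypothetical contrasting input (not part of the patch or of its test file; the function name and values are made up). It has the same nested-add index shape as @load_v4s8_nested_add, but the index arithmetic carries no nsw/nuw flags, so CheckFlags fails for both the first and the second attempt and the new reasoning cannot prove that adding IdxDiff to the index is overflow-free; safety could then only come from the known-bits "Third attempt" fallback, which does not help for arbitrary %x and %y.

; Hypothetical sketch, not part of the patch: same structure as the test
; above but without nsw on the index adds.
define void @load_v2s8_nested_add_no_nsw(i32 %x, i32 %y, i8* %src, <2 x i8>* %dst) {
bb:
  %a = add i32 %x, %y                  ; no nsw: CheckFlags(OpA, Signed) fails
  %a.ext = sext i32 %a to i64
  %gep.a = getelementptr inbounds i8, i8* %src, i64 %a.ext
  %ld.a = load i8, i8* %gep.a, align 1
  %y.next = add i32 %y, 1              ; no nsw: nested add not provably safe
  %b = add i32 %x, %y.next             ; no nsw: CheckFlags(OpB, Signed) fails
  %b.ext = sext i32 %b to i64
  %gep.b = getelementptr inbounds i8, i8* %src, i64 %b.ext
  %ld.b = load i8, i8* %gep.b, align 1
  %v = insertelement <2 x i8> undef, i8 %ld.a, i32 0
  %v1 = insertelement <2 x i8> %v, i8 %ld.b, i32 1
  store <2 x i8> %v1, <2 x i8>* %dst
  ret void
}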