diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -384,6 +384,81 @@
   return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
 }
 
+static bool checkNoWrapFlags(Instruction *I, bool Signed) {
+  BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+  return (Signed && BinOpI->hasNoSignedWrap()) ||
+         (!Signed && BinOpI->hasNoUnsignedWrap());
+}
+
+static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
+                                   unsigned MatchingOpIdxA, Instruction *AddOpB,
+                                   unsigned MatchingOpIdxB, bool Signed) {
+  // If both OpA and OpB are adds with NSW/NUW and with
+  // one of the operands being the same, we can guarantee that the
+  // transformation is safe if we can prove that OpA won't overflow when
+  // IdxDiff is added to the other operand of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  // Both %tmp7 and %tmp12 have the nsw flag and their first operand
+  // is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
+  // because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the
+  // nsw flag.
+  assert(AddOpA->getOpcode() == Instruction::Add &&
+         AddOpB->getOpcode() == Instruction::Add &&
+         checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
+  if (AddOpA->getOperand(MatchingOpIdxA) ==
+      AddOpB->getOperand(MatchingOpIdxB)) {
+    Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
+    Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
+    Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
+    Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
+    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+    if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
+        checkNoWrapFlags(OtherInstrB, Signed) &&
+        isa<ConstantInt>(OtherInstrB->getOperand(1))) {
+      int64_t CstVal =
+          cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
+      if (OtherInstrB->getOperand(0) == OtherOperandA &&
+          IdxDiff.getSExtValue() == CstVal)
+        return true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+    if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
+        checkNoWrapFlags(OtherInstrA, Signed) &&
+        isa<ConstantInt>(OtherInstrA->getOperand(1))) {
+      int64_t CstVal =
+          cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
+      if (OtherInstrA->getOperand(0) == OtherOperandB &&
+          IdxDiff.getSExtValue() == -CstVal)
+        return true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw c)` and
+    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+    if (OtherInstrA && OtherInstrB &&
+        OtherInstrA->getOpcode() == Instruction::Add &&
+        OtherInstrB->getOpcode() == Instruction::Add &&
+        checkNoWrapFlags(OtherInstrA, Signed) &&
+        checkNoWrapFlags(OtherInstrB, Signed) &&
+        isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
+        isa<ConstantInt>(OtherInstrB->getOperand(1))) {
+      int64_t CstValA =
+          cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
+      int64_t CstValB =
+          cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
+      if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
+          IdxDiff.getSExtValue() == (CstValB - CstValA))
+        return true;
+    }
+  }
+  return false;
+}
+
 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
                                              APInt PtrDelta,
                                              unsigned Depth) const {
@@ -438,73 +513,30 @@
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
-  auto CheckFlags = [](Instruction *I, bool Signed) {
-    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
-    return (Signed && BinOpI->hasNoSignedWrap()) ||
-           (!Signed && BinOpI->hasNoUnsignedWrap());
-  };
 
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
       IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
-      CheckFlags(OpB, Signed))
+      checkNoWrapFlags(OpB, Signed))
     Safe = true;
 
-  // Second attempt: If both OpA and OpB is an add with NSW/NUW and with
-  // the same LHS operand, we can guarantee that the transformation is safe
-  // if we can prove that OpA won't overflow when IdxDiff added to the RHS
-  // of OpA.
-  // For example:
-  //  %tmp7 = add nsw i32 %tmp2, %v0
-  //  %tmp8 = sext i32 %tmp7 to i64
-  //  ...
-  //  %tmp11 = add nsw i32 %v0, 1
-  //  %tmp12 = add nsw i32 %tmp2, %tmp11
-  //  %tmp13 = sext i32 %tmp12 to i64
-  //
-  // Both %tmp7 and %tmp2 has the nsw flag and the first operand
-  // is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
-  // because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
-  // nsw flag.
+  // Second attempt: check if we have eligible add NSW/NUW instruction
+  // sequences.
   OpA = dyn_cast<Instruction>(ValA);
   if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
-      OpB->getOpcode() == Instruction::Add &&
-      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
-      CheckFlags(OpB, Signed)) {
-    Value *RHSA = OpA->getOperand(1);
-    Value *RHSB = OpB->getOperand(1);
-    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
-    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
-    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
-    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
-        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
-      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
-      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
-        Safe = true;
-    }
-    // Match `x +nsw/nuw (y +nsw/nuw -Idx)` and `x +nsw/nuw (y +nsw/nuw x)`.
-    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
-        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
-      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
-      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
-        Safe = true;
-    }
-    // Match `x +nsw/nuw (y +nsw/nuw c)` and
-    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
-    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
-        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
-        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
-        isa<ConstantInt>(OpRHSB->getOperand(1))) {
-      int64_t CstValA =
-          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
-      int64_t CstValB =
-          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
-      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
-          IdxDiff.getSExtValue() == (CstValB - CstValA))
-        Safe = true;
-    }
+      OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
+      checkNoWrapFlags(OpB, Signed)) {
+    // In the checks below a "matching operand" in OpA and OpB is an operand
+    // which is the same in those two instructions. Below we account for all
+    // possible orders of the operands of these add instructions.
+    for (unsigned MatchingOpIdxA : {0, 1})
+      for (unsigned MatchingOpIdxB : {0, 1})
+        if (!Safe)
+          Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
+                                        MatchingOpIdxB, Signed);
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
diff --git a/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll b/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
--- a/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
+++ b/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
@@ -104,6 +104,104 @@
   ret void
 }
 
+; Apply different operand orders for the nested add sequences
+define void @ld_v4i8_add_nsw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nsw_operand_orders(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nsw i32 %v0, %v1
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add nsw i32 %tmp9, %v1
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add nsw i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Apply different operand orders for the nested add sequences
+define void @ld_v4i8_add_nuw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nuw_operand_orders(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v0, -1
+  %tmp1 = add nuw i32 %v1, %tmp
+  %tmp2 = zext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nuw i32 %v0, %v1
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nuw i32 %v0, 1
+  %tmp10 = add nuw i32 %tmp9, %v1
+  %tmp11 = zext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nuw i32 %v0, 2
+  %tmp15 = add nuw i32 %v1, %tmp14
+  %tmp16 = zext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
 define void @ld_v4i8_add_known_bits(i32 %ind0, i32 %ind1, i8* %src, <4 x i8>* %dst) {
 ; CHECK-LABEL: @ld_v4i8_add_known_bits(
 ; CHECK-NEXT:  bb: