Index: llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -384,6 +384,104 @@
   return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
 }
 
+/// Returns true if \p I, which must be a BinaryOperator, carries the no-wrap
+/// flag selected by \p Signed: `nsw` if \p Signed is true, `nuw` otherwise.
+static bool checkNoWrapFlags(Instruction *I, bool Signed) {
+  BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+  return (Signed && BinOpI->hasNoSignedWrap()) ||
+         (!Signed && BinOpI->hasNoUnsignedWrap());
+}
+
+/// If \p V is an `add` instruction carrying the no-wrap flag selected by
+/// \p Signed whose second operand is a ConstantInt, returns that constant and
+/// sets \p X to the add's first operand. Otherwise returns null and leaves
+/// \p X untouched.
+static ConstantInt *matchNoWrapAddOfConstant(Value *V, bool Signed, Value *&X) {
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I || I->getOpcode() != Instruction::Add || !checkNoWrapFlags(I, Signed))
+    return nullptr;
+  auto *C = dyn_cast<ConstantInt>(I->getOperand(1));
+  if (C)
+    X = I->getOperand(0);
+  return C;
+}
+
+/// Returns true if the add sequences feeding \p AddOpA and \p AddOpB show
+/// that adding \p IdxDiff to \p AddOpA cannot overflow (signed overflow when
+/// \p Signed, unsigned otherwise). Both must be `add`s with the matching
+/// no-wrap flag. Returns false unless operand \p MatchingOpIdxA of \p AddOpA
+/// is the same value as operand \p MatchingOpIdxB of \p AddOpB.
+static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
+                                   unsigned MatchingOpIdxA, Instruction *AddOpB,
+                                   unsigned MatchingOpIdxB, bool Signed) {
+  // If both OpA and OpB are adds with NSW/NUW and one of their operands is
+  // the same, we can guarantee that the transformation is safe if we can
+  // prove that OpA won't overflow when IdxDiff is added to the other operand
+  // of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  //  Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2.
+  //  It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11
+  //  adds 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
+  assert(AddOpA->getOpcode() == Instruction::Add &&
+         AddOpB->getOpcode() == Instruction::Add &&
+         checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
+  assert(MatchingOpIdxA < 2 && MatchingOpIdxB < 2 && "add has two operands");
+  if (AddOpA->getOperand(MatchingOpIdxA) != AddOpB->getOperand(MatchingOpIdxB))
+    return false;
+
+  // Returns true if Plus - Minus equals IdxDiff as exact signed integers,
+  // where a null constant stands for zero. All values are sign-extended to
+  // one bit more than the widest of them, so the subtraction cannot wrap;
+  // comparing getSExtValue() results as int64_t instead could overflow on
+  // extreme 64-bit constants and asserts when a value does not fit in 64 bits.
+  //
+  // NOTE(review): constants are sign-extended even when !Signed (nuw). For a
+  // nuw add, a constant such as -1 really means 2^N - 1, so
+  // `x +nuw (y +nuw -1)` is only non-poison when x == y == 0 and then equals
+  // 2^N - 1, which is not one below `x +nuw y`. Zero-extension looks like the
+  // sound choice for the unsigned case (the nuw test expectations would need
+  // revisiting too). Verify before landing.
+  auto DiffEqualsIdxDiff = [&IdxDiff](const ConstantInt *Plus,
+                                      const ConstantInt *Minus) {
+    unsigned Width = IdxDiff.getBitWidth();
+    if (Plus)
+      Width = std::max(Width, Plus->getBitWidth());
+    if (Minus)
+      Width = std::max(Width, Minus->getBitWidth());
+    ++Width;
+    APInt Diff(Width, 0);
+    if (Plus)
+      Diff += Plus->getValue().sext(Width);
+    if (Minus)
+      Diff -= Minus->getValue().sext(Width);
+    return Diff == IdxDiff.sext(Width);
+  };
+
+  Value *OtherOperandA = AddOpA->getOperand(1 - MatchingOpIdxA);
+  Value *OtherOperandB = AddOpB->getOperand(1 - MatchingOpIdxB);
+  Value *BaseA = nullptr;
+  Value *BaseB = nullptr;
+  ConstantInt *CstA = matchNoWrapAddOfConstant(OtherOperandA, Signed, BaseA);
+  ConstantInt *CstB = matchNoWrapAddOfConstant(OtherOperandB, Signed, BaseB);
+
+  // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+  if (CstB && BaseB == OtherOperandA && DiffEqualsIdxDiff(CstB, nullptr))
+    return true;
+  // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+  if (CstA && BaseA == OtherOperandB && DiffEqualsIdxDiff(nullptr, CstA))
+    return true;
+  // Match `x +nsw/nuw (y +nsw/nuw c)` and
+  // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+  return CstA && CstB && BaseA == BaseB && DiffEqualsIdxDiff(CstB, CstA);
+}
+
 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
                                               APInt PtrDelta,
                                               unsigned Depth) const {
@@ -438,73 +536,30 @@
 
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
-  auto CheckFlags = [](Instruction *I, bool Signed) {
-    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
-    return (Signed && BinOpI->hasNoSignedWrap()) ||
-           (!Signed && BinOpI->hasNoUnsignedWrap());
-  };
 
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
       IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
-      CheckFlags(OpB, Signed))
+      checkNoWrapFlags(OpB, Signed))
     Safe = true;
 
-  // Second attempt: If both OpA and OpB is an add with NSW/NUW and with
-  // the same LHS operand, we can guarantee that the transformation is safe
-  // if we can prove that OpA won't overflow when IdxDiff added to the RHS
-  // of OpA.
-  // For example:
-  //  %tmp7 = add nsw i32 %tmp2, %v0
-  //  %tmp8 = sext i32 %tmp7 to i64
-  //  ...
-  //  %tmp11 = add nsw i32 %v0, 1
-  //  %tmp12 = add nsw i32 %tmp2, %tmp11
-  //  %tmp13 = sext i32 %tmp12 to i64
-  //
-  //  Both %tmp7 and %tmp2 has the nsw flag and the first operand
-  //  is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
-  //  because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 has the
-  //  nsw flag.
+  // Second attempt: check whether OpA and OpB are eligible add NSW/NUW
+  // instruction sequences (see checkIfSafeAddSequence).
   OpA = dyn_cast<Instruction>(ValA);
   if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
-      OpB->getOpcode() == Instruction::Add &&
-      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
-      CheckFlags(OpB, Signed)) {
-    Value *RHSA = OpA->getOperand(1);
-    Value *RHSB = OpB->getOperand(1);
-    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
-    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
-    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
-    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
-        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
-      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
-      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
-        Safe = true;
-    }
-    // Match `x +nsw/nuw (y +nsw/nuw -Idx)` and `x +nsw/nuw (y +nsw/nuw x)`.
-    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
-        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
-      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
-      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
-        Safe = true;
-    }
-    // Match `x +nsw/nuw (y +nsw/nuw c)` and
-    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
-    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
-        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
-        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
-        isa<ConstantInt>(OpRHSB->getOperand(1))) {
-      int64_t CstValA =
-          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
-      int64_t CstValB =
-          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
-      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
-          IdxDiff.getSExtValue() == (CstValB - CstValA))
-        Safe = true;
-    }
+      OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
+      checkNoWrapFlags(OpB, Signed)) {
+    // checkIfSafeAddSequence needs one operand of OpA to be the same value as
+    // one operand of OpB (the "matching" operand). Since add is commutative,
+    // that shared value may sit at either operand index in either
+    // instruction, so every combination of operand indices is tried below.
+    for (unsigned MatchingOpIdxA : {0, 1})
+      for (unsigned MatchingOpIdxB : {0, 1})
+        if (!Safe)
+          Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
+                                        MatchingOpIdxB, Signed);
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
Index: llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
===================================================================
--- llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
+++ llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
@@ -104,6 +104,107 @@
   ret void
 }
 
+; Apply different operand orders for the nested add sequences
+define void @ld_v4i8_add_nsw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nsw_operand_orders(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nsw i32 %v0, %v1
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add nsw i32 %tmp9, %v1
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add nsw i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Apply different operand orders for the nested add sequences
+; NOTE(review): `add nuw i32 %v0, -1` and `add nuw i32 %v1, %tmp` are poison
+; unless %v0 == %v1 == 0, in which case the first load reads %src + 4294967295
+; and the others read %src + 0..2, so merging them all looks unsound. Verify.
+define void @ld_v4i8_add_nuw_operand_orders(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nuw_operand_orders(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v0, -1
+  %tmp1 = add nuw i32 %v1, %tmp
+  %tmp2 = zext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nuw i32 %v0, %v1
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nuw i32 %v0, 1
+  %tmp10 = add nuw i32 %tmp9, %v1
+  %tmp11 = zext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nuw i32 %v0, 2
+  %tmp15 = add nuw i32 %v1, %tmp14
+  %tmp16 = zext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
 define void @ld_v4i8_add_known_bits(i32 %ind0, i32 %ind1, i8* %src, <4 x i8>* %dst) {
 ; CHECK-LABEL: @ld_v4i8_add_known_bits(
 ; CHECK-NEXT:  bb: