diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3088,20 +3088,32 @@
 }
 
 unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const {
-  unsigned N;
-  Type *EltTy;
-  auto *ST = dyn_cast<StructType>(T);
-  if (ST) {
-    N = ST->getNumElements();
-    EltTy = *ST->element_begin();
-  } else {
-    N = cast<ArrayType>(T)->getNumElements();
-    EltTy = cast<ArrayType>(T)->getElementType();
-  }
-
-  if (auto *VT = dyn_cast<VectorType>(EltTy)) {
-    EltTy = VT->getElementType();
-    N *= VT->getNumElements();
+  unsigned N = 1;
+  Type *EltTy = T;
+
+  // Flatten nested homogeneous structs/arrays/vectors down to one scalar
+  // element type, accumulating the total element count in N.
+  while (isa<CompositeType>(EltTy)) {
+    if (auto *ST = dyn_cast<StructType>(EltTy)) {
+      // An empty struct has no element type to map onto vector lanes; bail
+      // out before dereferencing element_begin().
+      if (ST->getNumElements() == 0)
+        return 0;
+      // Check that struct is homogeneous.
+      for (const auto *Ty : ST->elements())
+        if (Ty != *ST->element_begin())
+          return 0;
+      N *= ST->getNumElements();
+      EltTy = *ST->element_begin();
+    } else {
+      auto *SeqT = cast<SequentialType>(EltTy);
+      N *= SeqT->getNumElements();
+      EltTy = SeqT->getElementType();
+    }
+    // VectorType::get() needs a non-zero element count: reject zero-length
+    // arrays and an element-count product that wrapped around to zero.
+    if (N == 0)
+      return 0;
   }
 
   if (!isValidElementType(EltTy))
@@ -3109,12 +3121,6 @@
   uint64_t VTSize = DL.getTypeStoreSizeInBits(VectorType::get(EltTy, N));
   if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize ||
       VTSize != DL.getTypeStoreSizeInBits(T))
     return 0;
-  if (ST) {
-    // Check that struct is homogeneous.
-    for (const auto *Ty : ST->elements())
-      if (Ty != *ST->element_begin())
-        return 0;
-  }
   return N;
 }
 
@@ -6940,57 +6946,58 @@
 /// %rb = insertelement <4 x float> %ra, float %s1, i32 1
 /// %rc = insertelement <4 x float> %rb, float %s2, i32 2
 /// %rd = insertelement <4 x float> %rc, float %s3, i32 3
-/// starting from the last insertelement instruction.
+/// starting from the last insertelement or insertvalue instruction.
 ///
-/// Returns true if it matches
-static bool findBuildVector(InsertElementInst *LastInsertElem,
-                            TargetTransformInfo *TTI,
-                            SmallVectorImpl<Value *> &BuildVectorOpds,
-                            int &UserCost) {
-  UserCost = 0;
-  Value *V = nullptr;
-  do {
-    if (auto *CI = dyn_cast<ConstantInt>(LastInsertElem->getOperand(2))) {
-      UserCost += TTI->getVectorInstrCost(Instruction::InsertElement,
-                                          LastInsertElem->getType(),
-                                          CI->getZExtValue());
-    }
-    BuildVectorOpds.push_back(LastInsertElem->getOperand(1));
-    V = LastInsertElem->getOperand(0);
-    if (isa<UndefValue>(V))
-      break;
-    LastInsertElem = dyn_cast<InsertElementInst>(V);
-    if (!LastInsertElem || !LastInsertElem->hasOneUse())
-      return false;
-  } while (true);
-  std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
-  return true;
-}
-
-/// Like findBuildVector, but looks for construction of aggregate.
-/// Accepts homegeneous aggregate of vectors like { <2 x float>, <2 x float> }.
+/// Also recognize homogeneous aggregates like {<2 x float>, <2 x float>},
+/// {{float, float}, {float, float}}, [2 x {float, float}] and so on.
+/// See llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll for examples.
+///
+/// \p LastInsertInst must be an InsertElementInst or an InsertValueInst.
 ///
 /// \return true if it matches.
-static bool findBuildAggregate(InsertValueInst *IV, TargetTransformInfo *TTI,
+static bool findBuildAggregate(Value *LastInsertInst, TargetTransformInfo *TTI,
                                SmallVectorImpl<Value *> &BuildVectorOpds,
                                int &UserCost) {
+  assert((isa<InsertElementInst>(LastInsertInst) ||
+          isa<InsertValueInst>(LastInsertInst)) &&
+         "Expected insertelement or insertvalue instruction!");
   UserCost = 0;
   do {
-    if (auto *IE = dyn_cast<InsertElementInst>(IV->getInsertedValueOperand())) {
+    Value *InsertedOperand;
+    if (auto *IE = dyn_cast<InsertElementInst>(LastInsertInst)) {
+      InsertedOperand = IE->getOperand(1);
+      LastInsertInst = IE->getOperand(0);
+      if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) {
+        UserCost += TTI->getVectorInstrCost(Instruction::InsertElement,
+                                            IE->getType(), CI->getZExtValue());
+      }
+    } else {
+      // Only insertvalue can reach here: see the assertion above and the
+      // chain check at the bottom of the loop.
+      auto *IV = cast<InsertValueInst>(LastInsertInst);
+      InsertedOperand = IV->getInsertedValueOperand();
+      LastInsertInst = IV->getAggregateOperand();
+    }
+    // Recurse into nested vectors/aggregates built by their own insert chain.
+    if (isa<InsertElementInst>(InsertedOperand) ||
+        isa<InsertValueInst>(InsertedOperand)) {
       int TmpUserCost;
-      SmallVector<Value *, 4> TmpBuildVectorOpds;
-      if (!findBuildVector(IE, TTI, TmpBuildVectorOpds, TmpUserCost))
+      SmallVector<Value *, 8> TmpBuildVectorOpds;
+      if (!findBuildAggregate(InsertedOperand, TTI, TmpBuildVectorOpds,
+                              TmpUserCost))
         return false;
+      // BuildVectorOpds is filled back-to-front and reversed before returning,
+      // so the nested (front-to-back) operands are appended in reverse.
       BuildVectorOpds.append(TmpBuildVectorOpds.rbegin(), TmpBuildVectorOpds.rend());
       UserCost += TmpUserCost;
     } else {
-      BuildVectorOpds.push_back(IV->getInsertedValueOperand());
+      BuildVectorOpds.push_back(InsertedOperand);
     }
-    Value *V = IV->getAggregateOperand();
-    if (isa<UndefValue>(V))
+    if (isa<UndefValue>(LastInsertInst))
       break;
-    IV = dyn_cast<InsertValueInst>(V);
-    if (!IV || !IV->hasOneUse())
+    if ((!isa<InsertValueInst>(LastInsertInst) &&
+         !isa<InsertElementInst>(LastInsertInst)) ||
+        !LastInsertInst->hasOneUse())
       return false;
   } while (true);
   std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
@@ -7177,7 +7184,7 @@
                                                    BasicBlock *BB, BoUpSLP &R) {
   int UserCost;
   SmallVector<Value *, 16> BuildVectorOpds;
-  if (!findBuildVector(IEI, TTI, BuildVectorOpds, UserCost) ||
+  if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, UserCost) ||
       (llvm::all_of(BuildVectorOpds,
                     [](Value *V) { return isa<ExtractElementInst>(V); }) &&
        isShuffle(BuildVectorOpds)))
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll
b/llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll --- a/llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll @@ -55,21 +55,20 @@ define [2 x %StructTy] @ArrayOfStruct(float *%Ptr) { ; CHECK-LABEL: @ArrayOfStruct( ; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds float, float* [[PTR:%.*]], i64 0 -; CHECK-NEXT: [[L0:%.*]] = load float, float* [[GEP0]] ; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 1 -; CHECK-NEXT: [[L1:%.*]] = load float, float* [[GEP1]] ; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 2 -; CHECK-NEXT: [[L2:%.*]] = load float, float* [[GEP2]] ; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 3 -; CHECK-NEXT: [[L3:%.*]] = load float, float* [[GEP3]] -; CHECK-NEXT: [[FADD0:%.*]] = fadd fast float [[L0]], 1.100000e+01 -; CHECK-NEXT: [[FADD1:%.*]] = fadd fast float [[L1]], 1.200000e+01 -; CHECK-NEXT: [[FADD2:%.*]] = fadd fast float [[L2]], 1.300000e+01 -; CHECK-NEXT: [[FADD3:%.*]] = fadd fast float [[L3]], 1.400000e+01 -; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[FADD0]], 0 -; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[FADD1]], 1 -; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[FADD2]], 0 -; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[FADD3]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[GEP0]] to <4 x float>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <4 x float> [[TMP2]], +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP3]], i32 0 +; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[TMP4]], 0 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[TMP3]], i32 1 +; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[TMP5]], 1 +; 
CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP3]], i32 2 +; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[TMP6]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP3]], i32 3 +; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[TMP7]], 1 ; CHECK-NEXT: [[RET0:%.*]] = insertvalue [2 x %StructTy] undef, [[STRUCTTY]] %StructIn1, 0 ; CHECK-NEXT: [[RET1:%.*]] = insertvalue [2 x %StructTy] [[RET0]], [[STRUCTTY]] %StructIn3, 1 ; CHECK-NEXT: ret [2 x %StructTy] [[RET1]] @@ -102,21 +101,20 @@ define {%StructTy, %StructTy} @StructOfStruct(float *%Ptr) { ; CHECK-LABEL: @StructOfStruct( ; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds float, float* [[PTR:%.*]], i64 0 -; CHECK-NEXT: [[L0:%.*]] = load float, float* [[GEP0]] ; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 1 -; CHECK-NEXT: [[L1:%.*]] = load float, float* [[GEP1]] ; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 2 -; CHECK-NEXT: [[L2:%.*]] = load float, float* [[GEP2]] ; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds float, float* [[PTR]], i64 3 -; CHECK-NEXT: [[L3:%.*]] = load float, float* [[GEP3]] -; CHECK-NEXT: [[FADD0:%.*]] = fadd fast float [[L0]], 1.100000e+01 -; CHECK-NEXT: [[FADD1:%.*]] = fadd fast float [[L1]], 1.200000e+01 -; CHECK-NEXT: [[FADD2:%.*]] = fadd fast float [[L2]], 1.300000e+01 -; CHECK-NEXT: [[FADD3:%.*]] = fadd fast float [[L3]], 1.400000e+01 -; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[FADD0]], 0 -; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[FADD1]], 1 -; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[FADD2]], 0 -; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[FADD3]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[GEP0]] to <4 x float>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 +; 
CHECK-NEXT: [[TMP3:%.*]] = fadd fast <4 x float> [[TMP2]], +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[TMP3]], i32 0 +; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCTTY:%.*]] undef, float [[TMP4]], 0 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[TMP3]], i32 1 +; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCTTY]] %StructIn0, float [[TMP5]], 1 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x float> [[TMP3]], i32 2 +; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCTTY]] undef, float [[TMP6]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <4 x float> [[TMP3]], i32 3 +; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCTTY]] %StructIn2, float [[TMP7]], 1 ; CHECK-NEXT: [[RET0:%.*]] = insertvalue { [[STRUCTTY]], [[STRUCTTY]] } undef, [[STRUCTTY]] %StructIn1, 0 ; CHECK-NEXT: [[RET1:%.*]] = insertvalue { [[STRUCTTY]], [[STRUCTTY]] } [[RET0]], [[STRUCTTY]] %StructIn3, 1 ; CHECK-NEXT: ret { [[STRUCTTY]], [[STRUCTTY]] } [[RET1]] @@ -196,37 +194,32 @@ define {%Struct2Ty, %Struct2Ty} @StructOfStructOfStruct(i16 *%Ptr) { ; CHECK-LABEL: @StructOfStructOfStruct( ; CHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds i16, i16* [[PTR:%.*]], i64 0 -; CHECK-NEXT: [[L0:%.*]] = load i16, i16* [[GEP0]] ; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 1 -; CHECK-NEXT: [[L1:%.*]] = load i16, i16* [[GEP1]] ; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 2 -; CHECK-NEXT: [[L2:%.*]] = load i16, i16* [[GEP2]] ; CHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 3 -; CHECK-NEXT: [[L3:%.*]] = load i16, i16* [[GEP3]] ; CHECK-NEXT: [[GEP4:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 4 -; CHECK-NEXT: [[L4:%.*]] = load i16, i16* [[GEP4]] ; CHECK-NEXT: [[GEP5:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 5 -; CHECK-NEXT: [[L5:%.*]] = load i16, i16* [[GEP5]] ; CHECK-NEXT: [[GEP6:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 6 -; CHECK-NEXT: [[L6:%.*]] = load 
i16, i16* [[GEP6]] ; CHECK-NEXT: [[GEP7:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 7 -; CHECK-NEXT: [[L7:%.*]] = load i16, i16* [[GEP7]] -; CHECK-NEXT: [[FADD0:%.*]] = add i16 [[L0]], 1 -; CHECK-NEXT: [[FADD1:%.*]] = add i16 [[L1]], 2 -; CHECK-NEXT: [[FADD2:%.*]] = add i16 [[L2]], 3 -; CHECK-NEXT: [[FADD3:%.*]] = add i16 [[L3]], 4 -; CHECK-NEXT: [[FADD4:%.*]] = add i16 [[L4]], 5 -; CHECK-NEXT: [[FADD5:%.*]] = add i16 [[L5]], 6 -; CHECK-NEXT: [[FADD6:%.*]] = add i16 [[L6]], 7 -; CHECK-NEXT: [[FADD7:%.*]] = add i16 [[L7]], 8 -; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCT1TY:%.*]] undef, i16 [[FADD0]], 0 -; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCT1TY]] %StructIn0, i16 [[FADD1]], 1 -; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD2]], 0 -; CHECK-NEXT: [[STRUCTIN3:%.*]] = insertvalue [[STRUCT1TY]] %StructIn2, i16 [[FADD3]], 1 -; CHECK-NEXT: [[STRUCTIN4:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD4]], 0 -; CHECK-NEXT: [[STRUCTIN5:%.*]] = insertvalue [[STRUCT1TY]] %StructIn4, i16 [[FADD5]], 1 -; CHECK-NEXT: [[STRUCTIN6:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[FADD6]], 0 -; CHECK-NEXT: [[STRUCTIN7:%.*]] = insertvalue [[STRUCT1TY]] %StructIn6, i16 [[FADD7]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i16* [[GEP0]] to <8 x i16>* +; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i16>, <8 x i16>* [[TMP1]], align 2 +; CHECK-NEXT: [[TMP3:%.*]] = add <8 x i16> [[TMP2]], +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x i16> [[TMP3]], i32 0 +; CHECK-NEXT: [[STRUCTIN0:%.*]] = insertvalue [[STRUCT1TY:%.*]] undef, i16 [[TMP4]], 0 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <8 x i16> [[TMP3]], i32 1 +; CHECK-NEXT: [[STRUCTIN1:%.*]] = insertvalue [[STRUCT1TY]] %StructIn0, i16 [[TMP5]], 1 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <8 x i16> [[TMP3]], i32 2 +; CHECK-NEXT: [[STRUCTIN2:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP6]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <8 x i16> [[TMP3]], i32 3 +; CHECK-NEXT: 
[[STRUCTIN3:%.*]] = insertvalue [[STRUCT1TY]] %StructIn2, i16 [[TMP7]], 1 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <8 x i16> [[TMP3]], i32 4 +; CHECK-NEXT: [[STRUCTIN4:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP8]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <8 x i16> [[TMP3]], i32 5 +; CHECK-NEXT: [[STRUCTIN5:%.*]] = insertvalue [[STRUCT1TY]] %StructIn4, i16 [[TMP9]], 1 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <8 x i16> [[TMP3]], i32 6 +; CHECK-NEXT: [[STRUCTIN6:%.*]] = insertvalue [[STRUCT1TY]] undef, i16 [[TMP10]], 0 +; CHECK-NEXT: [[TMP11:%.*]] = extractelement <8 x i16> [[TMP3]], i32 7 +; CHECK-NEXT: [[STRUCTIN7:%.*]] = insertvalue [[STRUCT1TY]] %StructIn6, i16 [[TMP11]], 1 ; CHECK-NEXT: [[STRUCT2IN0:%.*]] = insertvalue [[STRUCT2TY:%.*]] undef, [[STRUCT1TY]] %StructIn1, 0 ; CHECK-NEXT: [[STRUCT2IN1:%.*]] = insertvalue [[STRUCT2TY]] %Struct2In0, [[STRUCT1TY]] %StructIn3, 1 ; CHECK-NEXT: [[STRUCT2IN2:%.*]] = insertvalue [[STRUCT2TY]] undef, [[STRUCT1TY]] %StructIn5, 0