diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7075,6 +7075,135 @@
 }
 } // end anonymous namespace
 
+/// Returns the number of scalar positions in the flattened form of the
+/// homogeneous aggregate built by \p InsertInst (an insertelement or
+/// insertvalue), or None if its type cannot be flattened into one vector:
+/// e.g. a scalable vector, a struct with differing element types, an empty
+/// aggregate, or a size that does not fit in an unsigned.
+static Optional<unsigned> getAggregateSize(Instruction *InsertInst) {
+  if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
+    if (auto *VT = dyn_cast<FixedVectorType>(IE->getType()))
+      return VT->getNumElements();
+    return None;
+  }
+
+  uint64_t AggregateSize = 1;
+  Type *CurrentType = cast<InsertValueInst>(InsertInst)->getType();
+  while (true) {
+    if (auto *ST = dyn_cast<StructType>(CurrentType)) {
+      // Only structs whose elements all share one type can be flattened.
+      if (ST->getNumElements() == 0 ||
+          !llvm::all_of(ST->elements(), [ST](Type *EltTy) {
+            return EltTy == ST->getElementType(0);
+          }))
+        return None;
+      AggregateSize *= ST->getNumElements();
+      CurrentType = ST->getElementType(0);
+    } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) {
+      AggregateSize *= AT->getNumElements();
+      CurrentType = AT->getElementType();
+    } else if (auto *VT = dyn_cast<FixedVectorType>(CurrentType)) {
+      AggregateSize *= VT->getNumElements();
+      break;
+    } else if (CurrentType->isSingleValueType() &&
+               !CurrentType->isVectorTy()) {
+      break;
+    } else {
+      return None;
+    }
+  }
+  // Reject empty aggregates and sizes that overflow an unsigned position.
+  if (AggregateSize == 0 ||
+      static_cast<unsigned>(AggregateSize) != AggregateSize)
+    return None;
+  return static_cast<unsigned>(AggregateSize);
+}
+
+/// Returns the position, in the flattened aggregate, of the element written
+/// by \p InsertInst, or None if it cannot be determined statically (e.g. a
+/// non-constant or out-of-range insertelement lane). \p OperandOffset is the
+/// flattened position of the (sub-)aggregate built by \p InsertInst within
+/// its parent, or 0 for the outermost aggregate.
+static Optional<unsigned> getOperandIndex(Instruction *InsertInst,
+                                          unsigned OperandOffset) {
+  unsigned OperandIndex = OperandOffset;
+  if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
+    auto *VT = dyn_cast<FixedVectorType>(IE->getType());
+    auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
+    if (!VT || !CI || CI->getValue().uge(VT->getNumElements()))
+      return None;
+    return OperandIndex * VT->getNumElements() +
+           static_cast<unsigned>(CI->getZExtValue());
+  }
+
+  // Flatten the index path in mixed radix: each level scales the position
+  // accumulated so far by that level's element count.
+  Type *CurrentType = InsertInst->getType();
+  for (unsigned Index : cast<InsertValueInst>(InsertInst)->indices()) {
+    if (auto *ST = dyn_cast<StructType>(CurrentType)) {
+      OperandIndex *= ST->getNumElements();
+      CurrentType = ST->getElementType(Index);
+    } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) {
+      OperandIndex *= AT->getNumElements();
+      CurrentType = AT->getElementType();
+    } else {
+      return None;
+    }
+    OperandIndex += Index;
+  }
+  return OperandIndex;
+}
+
+/// Recursive worker for findBuildAggregate. Walks the insert chain ending at
+/// \p LastInsertInst back to its undef base, storing each inserted scalar and
+/// the instruction that inserted it at the scalar's flattened position, and
+/// recursing into inserted operands that are themselves insert chains.
+/// \p OperandOffset is the flattened position of the (sub-)aggregate built by
+/// \p LastInsertInst.
+/// \return false if the chain is not a recognizable build sequence.
+static bool findBuildAggregateRec(Instruction *LastInsertInst,
+                                  SmallVectorImpl<Value *> &BuildVectorOpds,
+                                  SmallVectorImpl<Value *> &InsertElts,
+                                  unsigned OperandOffset) {
+  do {
+    Value *InsertedOperand = LastInsertInst->getOperand(1);
+    Optional<unsigned> OperandIndex =
+        getOperandIndex(LastInsertInst, OperandOffset);
+    if (!OperandIndex)
+      return false;
+    if (isa<InsertElementInst>(InsertedOperand) ||
+        isa<InsertValueInst>(InsertedOperand)) {
+      if (!findBuildAggregateRec(cast<Instruction>(InsertedOperand),
+                                 BuildVectorOpds, InsertElts, *OperandIndex))
+        return false;
+    } else {
+      // Only a leaf scalar has a single flattened position; a whole vector or
+      // sub-aggregate inserted in one step does not.
+      Type *OperandTy = InsertedOperand->getType();
+      if (OperandTy->isVectorTy() || !OperandTy->isSingleValueType())
+        return false;
+      // Walking backwards, an already filled position means a later insert
+      // overwrote this one; bail out rather than track the dead insert.
+      if (*OperandIndex >= BuildVectorOpds.size() ||
+          BuildVectorOpds[*OperandIndex])
+        return false;
+      BuildVectorOpds[*OperandIndex] = InsertedOperand;
+      InsertElts[*OperandIndex] = LastInsertInst;
+    }
+    Value *BaseAggregate = LastInsertInst->getOperand(0);
+    if (isa<UndefValue>(BaseAggregate))
+      break;
+    // Each earlier link must be an insert whose only use is this one.
+    LastInsertInst = dyn_cast<Instruction>(BaseAggregate);
+    if (!LastInsertInst ||
+        (!isa<InsertValueInst>(LastInsertInst) &&
+         !isa<InsertElementInst>(LastInsertInst)) ||
+        !LastInsertInst->hasOneUse())
+      return false;
+  } while (true);
+  return true;
+}
+
 /// Recognize construction of vectors like
 ///   %ra = insertelement <4 x float> undef, float %s0, i32 0
 ///   %rb = insertelement <4 x float> %ra, float %s1, i32 1
@@ -7082,53 +7211,48 @@
 ///   %rd = insertelement <4 x float> %rc, float %s3, i32 3
 /// starting from the last insertelement or insertvalue instruction.
 ///
-/// Also recognize aggregates like {<2 x float>, <2 x float>},
+/// Also recognize homogeneous aggregates like {<2 x float>, <2 x float>},
 /// {{float, float}, {float, float}}, [2 x {float, float}] and so on.
 /// See llvm/test/Transforms/SLPVectorizer/X86/pr42022.ll for examples.
 ///
 /// Assume LastInsertInst is of InsertElementInst or InsertValueInst type.
 ///
+/// The inserts may appear in any order. On success, BuildVectorOpds holds the
+/// inserted scalars ordered by their position in the flattened aggregate,
+/// skipping positions that were never written, and InsertElts holds the
+/// instruction that inserted each of them. Both vectors must be empty on
+/// entry and are cleared on failure.
+///
 /// \return true if it matches.
-static bool findBuildAggregate(Value *LastInsertInst, TargetTransformInfo *TTI,
+static bool findBuildAggregate(Instruction *LastInsertInst,
+                               TargetTransformInfo *TTI,
                                SmallVectorImpl<Value *> &BuildVectorOpds,
                                SmallVectorImpl<Value *> &InsertElts) {
   assert((isa<InsertElementInst>(LastInsertInst) ||
           isa<InsertValueInst>(LastInsertInst)) &&
          "Expected insertelement or insertvalue instruction!");
+  assert(BuildVectorOpds.empty() && InsertElts.empty() &&
+         "Expected empty result vectors!");
 
-  do {
-    Value *InsertedOperand;
-    auto *IE = dyn_cast<InsertElementInst>(LastInsertInst);
-    if (IE) {
-      InsertedOperand = IE->getOperand(1);
-      LastInsertInst = IE->getOperand(0);
-    } else {
-      auto *IV = cast<InsertValueInst>(LastInsertInst);
-      InsertedOperand = IV->getInsertedValueOperand();
-      LastInsertInst = IV->getAggregateOperand();
-    }
-    if (isa<InsertElementInst>(InsertedOperand) ||
-        isa<InsertValueInst>(InsertedOperand)) {
-      SmallVector<Value *, 8> TmpBuildVectorOpds;
-      SmallVector<Value *, 8> TmpInsertElts;
-      if (!findBuildAggregate(InsertedOperand, TTI, TmpBuildVectorOpds,
-                              TmpInsertElts))
-        return false;
-      BuildVectorOpds.append(TmpBuildVectorOpds.rbegin(),
-                             TmpBuildVectorOpds.rend());
-      InsertElts.append(TmpInsertElts.rbegin(), TmpInsertElts.rend());
-    } else {
-      BuildVectorOpds.push_back(InsertedOperand);
-      InsertElts.push_back(IE);
-    }
-    if (isa<UndefValue>(LastInsertInst))
-      break;
-    if ((!isa<InsertValueInst>(LastInsertInst) &&
-         !isa<InsertElementInst>(LastInsertInst)) ||
-        !LastInsertInst->hasOneUse())
-      return false;
-  } while (true);
-  std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end());
-  std::reverse(InsertElts.begin(), InsertElts.end());
+  Optional<unsigned> AggregateSize = getAggregateSize(LastInsertInst);
+  if (!AggregateSize)
+    return false;
+  BuildVectorOpds.assign(*AggregateSize, nullptr);
+  InsertElts.assign(*AggregateSize, nullptr);
+
+  if (!findBuildAggregateRec(LastInsertInst, BuildVectorOpds, InsertElts,
+                             /*OperandOffset=*/0)) {
+    BuildVectorOpds.clear();
+    InsertElts.clear();
+    return false;
+  }
+
+  // Drop positions that no insert wrote (they still hold the undef base);
+  // both vectors are null at exactly the same positions.
+  BuildVectorOpds.erase(
+      std::remove(BuildVectorOpds.begin(), BuildVectorOpds.end(), nullptr),
+      BuildVectorOpds.end());
+  InsertElts.erase(std::remove(InsertElts.begin(), InsertElts.end(), nullptr),
+                   InsertElts.end());
   return true;
 }