diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -629,7 +629,7 @@ /// the stored value. Otherwise, the size is the width of the largest loaded /// value reaching V. This method is used by the vectorizer to calculate /// vectorization factors. - unsigned getVectorElementSize(Value *V) const; + unsigned getVectorElementSize(Value *V); /// Compute the minimum type sizes required to represent the entries in a /// vectorizable tree. @@ -1715,6 +1715,9 @@ /// Maps a specific scalar to its tree entry. SmallDenseMap ScalarToTreeEntry; + /// Maps a value to the proposed vectorizable size. + SmallDenseMap InstrElementSize; + /// A list of scalars that we found that we need to keep as scalars. ValueSet MustGather; @@ -4785,6 +4788,7 @@ } Builder.ClearInsertionPoint(); + InstrElementSize.clear(); return VectorizableTree[0]->VectorizedValue; } @@ -5321,12 +5325,16 @@ BS->ScheduleStart = nullptr; } -unsigned BoUpSLP::getVectorElementSize(Value *V) const { +unsigned BoUpSLP::getVectorElementSize(Value *V) { // If V is a store, just return the width of the stored value without // traversing the expression tree. This is the common case. if (auto *Store = dyn_cast(V)) return DL->getTypeSizeInBits(Store->getValueOperand()->getType()); + auto E = InstrElementSize.find(V); + if (E != InstrElementSize.end()) + return E->second; + // If V is not a store, we can traverse the expression tree to find loads // that feed it. The type of the loaded value may indicate a more suitable // width than V's type. We want to base the vector element size on the width @@ -5372,13 +5380,17 @@ FoundUnknownInst = true; } + int Width = MaxWidth; // If we didn't encounter a memory access in the expression tree, or if we - // gave up for some reason, just return the width of V. + // gave up for some reason, just return the width of V. Otherwise, return the + // maximum width we found. if (!MaxWidth || FoundUnknownInst) - return DL->getTypeSizeInBits(V->getType()); + Width = DL->getTypeSizeInBits(V->getType()); + + for (Instruction *I : Visited) + InstrElementSize[I] = Width; - // Otherwise, return the maximum width we found. - return MaxWidth; + return Width; } // Determine if a value V in a vectorizable expression Expr can be demoted to a