diff --git a/llvm/include/llvm/Analysis/MemoryProfileInfo.h b/llvm/include/llvm/Analysis/MemoryProfileInfo.h
--- a/llvm/include/llvm/Analysis/MemoryProfileInfo.h
+++ b/llvm/include/llvm/Analysis/MemoryProfileInfo.h
@@ -128,6 +128,9 @@
   CallStackIterator begin() const;
   CallStackIterator end() const { return CallStackIterator(N, /*End*/ true); }
   CallStackIterator beginAfterSharedPrefix(CallStack &Other);
+  /// Returns the last entry of the call stack: a StackIdIndices entry in the
+  /// generic form, or the stack id itself for the MDNode specialization.
+  uint64_t back() const;
 
 private:
   const NodeT *N = nullptr;
@@ -137,9 +140,8 @@
 CallStack<NodeT, IteratorT>::CallStackIterator::CallStackIterator(
     const NodeT *N, bool End)
     : N(N) {
-  if (!N)
-    return;
-  Iter = End ? N->StackIdIndices.end() : N->StackIdIndices.begin();
+  Iter =
+      N ? (End ? N->StackIdIndices.end() : N->StackIdIndices.begin()) : nullptr;
 }
 
 template <class NodeT, class IteratorT>
@@ -148,6 +150,12 @@
   return *Iter;
 }
 
+template <class NodeT, class IteratorT>
+uint64_t CallStack<NodeT, IteratorT>::back() const {
+  assert(N);
+  return N->StackIdIndices.back();
+}
+
 template <class NodeT, class IteratorT>
 typename CallStack<NodeT, IteratorT>::CallStackIterator
 CallStack<NodeT, IteratorT>::begin() const {
@@ -170,6 +178,7 @@
     const MDNode *N, bool End);
 template <>
 uint64_t CallStack<MDNode, MDNode::op_iterator>::CallStackIterator::operator*();
+template <> uint64_t CallStack<MDNode, MDNode::op_iterator>::back() const;
 
 } // end namespace memprof
 } // end namespace llvm
diff --git a/llvm/lib/Analysis/MemoryProfileInfo.cpp b/llvm/lib/Analysis/MemoryProfileInfo.cpp
--- a/llvm/lib/Analysis/MemoryProfileInfo.cpp
+++ b/llvm/lib/Analysis/MemoryProfileInfo.cpp
@@ -242,3 +242,11 @@
   assert(StackIdCInt);
   return StackIdCInt->getZExtValue();
 }
+
+template <> uint64_t CallStack<MDNode, MDNode::op_iterator>::back() const {
+  assert(N);
+  ConstantInt *StackIdCInt =
+      mdconst::dyn_extract<ConstantInt>(N->operands().back());
+  assert(StackIdCInt);
+  return StackIdCInt->getZExtValue();
+}
diff --git a/llvm/unittests/Analysis/MemoryProfileInfoTest.cpp b/llvm/unittests/Analysis/MemoryProfileInfoTest.cpp
--- a/llvm/unittests/Analysis/MemoryProfileInfoTest.cpp
+++ b/llvm/unittests/Analysis/MemoryProfileInfoTest.cpp
@@ -401,6 +401,8 @@
     auto *MIBMD = cast<const MDNode>(MIBOp);
     MDNode *StackNode = getMIBStackNode(MIBMD);
     CallStack<MDNode, MDNode::op_iterator> StackContext(StackNode);
+    uint64_t ExpectedBack = First ? 4 : 5;
+    EXPECT_EQ(StackContext.back(), ExpectedBack);
     std::vector<uint64_t> StackIds;
     for (auto ContextIter = StackContext.beginAfterSharedPrefix(InstCallsite);
          ContextIter != StackContext.end(); ++ContextIter)
@@ -450,6 +452,8 @@
     for (auto &MIB : AI.MIBs) {
       CallStack<MIBInfo, SmallVector<unsigned>::const_iterator> StackContext(
           &MIB);
+      uint64_t ExpectedBack = First ? 4 : 5;
+      EXPECT_EQ(Index->getStackIdAtIndex(StackContext.back()), ExpectedBack);
       std::vector<uint64_t> StackIds;
       for (auto StackIdIndex : StackContext)
         StackIds.push_back(Index->getStackIdAtIndex(StackIdIndex));