diff --git a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
--- a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
+++ b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
@@ -12,6 +12,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_MEMORYTAGGINGSUPPORT_H
 #define LLVM_TRANSFORMS_UTILS_MEMORYTAGGINGSUPPORT_H
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/Dominators.h"
@@ -68,6 +69,13 @@
 bool isStandardLifetime(const SmallVectorImpl<IntrinsicInst *> &LifetimeStart,
                         const SmallVectorImpl<IntrinsicInst *> &LifetimeEnd,
                         const DominatorTree *DT, size_t MaxLifetimes);
+
+/// If \p Inst exits the function (ret, resume or cleanupret), return the
+/// instruction to use as the stack-untagging location for that exit: for a
+/// ret preceded by a musttail call this is the call, otherwise \p Inst itself.
+/// Returns nullptr if \p Inst is not a function exit.
+Instruction *getUntagLocationIfFunctionExit(Instruction &Inst);
+
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -1511,14 +1511,8 @@
         }
       }
 
-      if (isa<ReturnInst>(Inst)) {
-        if (CallInst *CI = Inst.getParent()->getTerminatingMustTailCall())
-          RetVec.push_back(CI);
-        else
-          RetVec.push_back(&Inst);
-      } else if (isa<ResumeInst, CleanupReturnInst>(Inst)) {
-        RetVec.push_back(&Inst);
-      }
+      if (Instruction *ExitUntag = getUntagLocationIfFunctionExit(Inst))
+        RetVec.push_back(ExitUntag);
 
       if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) {
         for (Value *V : DVI->location_ops()) {
diff --git a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
--- a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
+++ b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
@@ -42,4 +42,18 @@
           (LifetimeEnd.size() > 0 &&
            !maybeReachableFromEachOther(LifetimeEnd, DT, MaxLifetimes)));
 }
+
+Instruction *getUntagLocationIfFunctionExit(Instruction &Inst) {
+  if (isa<ReturnInst>(Inst)) {
+    // A musttail call must be immediately followed by the ret (modulo an
+    // optional bitcast), so nothing may be inserted between them; report the
+    // call itself as the untag location instead.
+    if (CallInst *CI = Inst.getParent()->getTerminatingMustTailCall())
+      return CI;
+    return &Inst;
+  }
+  if (isa<ResumeInst, CleanupReturnInst>(Inst))
+    return &Inst;
+  return nullptr;
+}
 } // namespace llvm