diff --git a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h --- a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h +++ b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h @@ -15,6 +15,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Support/Alignment.h" namespace llvm { @@ -33,14 +34,15 @@ // the caller should remove Ends to ensure that work done at the other // exits does not happen outside of the lifetime. bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT, - const Instruction *Start, + const LoopInfo &LI, const Instruction *Start, const SmallVectorImpl &Ends, const SmallVectorImpl &RetVec, llvm::function_ref Callback); bool isStandardLifetime(const SmallVectorImpl &LifetimeStart, const SmallVectorImpl &LifetimeEnd, - const DominatorTree *DT, size_t MaxLifetimes); + const DominatorTree *DT, const LoopInfo *LI, + size_t MaxLifetimes); Instruction *getUntagLocationIfFunctionExit(Instruction &Inst); diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp --- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp +++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp @@ -58,6 +58,7 @@ #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include #include +#include #include using namespace llvm; @@ -523,6 +524,15 @@ PDT = DeletePDT.get(); } + std::unique_ptr DeleteLI; + LoopInfo *LI = nullptr; + if (auto *LIWP = getAnalysisIfAvailable()) { + LI = &LIWP->getLoopInfo(); + } else { + DeleteLI = std::make_unique(*DT); + LI = DeleteLI.get(); + } + SetTagFunc = Intrinsic::getDeclaration(F->getParent(), Intrinsic::aarch64_settag); @@ -555,7 +565,7 @@ // statement if return_twice functions are called. 
bool StandardLifetime = SInfo.UnrecognizedLifetimes.empty() && - memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, DT, + memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, DT, LI, ClMaxLifetimes) && !SInfo.CallsReturnTwice; if (StandardLifetime) { @@ -567,7 +577,7 @@ auto TagEnd = [&](Instruction *Node) { untagAlloca(AI, Node, Size); }; if (!DT || !PDT || - !memtag::forAllReachableExits(*DT, *PDT, Start, Info.LifetimeEnd, + !memtag::forAllReachableExits(*DT, *PDT, *LI, Start, Info.LifetimeEnd, SInfo.RetVec, TagEnd)) { for (auto *End : Info.LifetimeEnd) End->eraseFromParent(); diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -292,7 +292,8 @@ Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, - const DominatorTree &DT, const PostDominatorTree &PDT); + const DominatorTree &DT, const PostDominatorTree &PDT, + const LoopInfo &LI); Value *readRegister(IRBuilder<> &IRB, StringRef Name); bool instrumentLandingPads(SmallVectorImpl &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); @@ -1217,7 +1218,8 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, Value *StackTag, const DominatorTree &DT, - const PostDominatorTree &PDT) { + const PostDominatorTree &PDT, + const LoopInfo &LI) { // Ideally, we want to calculate tagged stack base pointer, and rewrite all // alloca addresses using that. Unfortunately, offsets are not known yet // (unless we use ASan-style mega-alloca). 
Instead we keep the base tag in a @@ -1294,13 +1296,13 @@ bool StandardLifetime = SInfo.UnrecognizedLifetimes.empty() && memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, &DT, - ClMaxLifetimes) && + &LI, ClMaxLifetimes) && !SInfo.CallsReturnTwice; if (DetectUseAfterScope && StandardLifetime) { IntrinsicInst *Start = Info.LifetimeStart[0]; IRB.SetInsertPoint(Start->getNextNode()); tagAlloca(IRB, AI, Tag, Size); - if (!memtag::forAllReachableExits(DT, PDT, Start, Info.LifetimeEnd, + if (!memtag::forAllReachableExits(DT, PDT, LI, Start, Info.LifetimeEnd, SInfo.RetVec, TagEnd)) { for (auto *End : Info.LifetimeEnd) End->eraseFromParent(); @@ -1405,9 +1407,10 @@ if (!SInfo.AllocasToInstrument.empty()) { const DominatorTree &DT = FAM.getResult(F); const PostDominatorTree &PDT = FAM.getResult(F); + const LoopInfo &LI = FAM.getResult(F); Value *StackTag = ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB); - instrumentStack(SInfo, StackTag, DT, PDT); + instrumentStack(SInfo, StackTag, DT, PDT, LI); } // If we split the entry block, move any allocas that were originally in the diff --git a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp --- a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp +++ b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp @@ -22,7 +22,8 @@ namespace memtag { namespace { bool maybeReachableFromEachOther(const SmallVectorImpl &Insts, - const DominatorTree *DT, size_t MaxLifetimes) { + const DominatorTree *DT, const LoopInfo *LI, + size_t MaxLifetimes) { // If we have too many lifetime ends, give up, as the algorithm below is N^2. 
if (Insts.size() > MaxLifetimes) return true; @@ -30,7 +31,7 @@ for (size_t J = 0; J < Insts.size(); ++J) { if (I == J) continue; - if (isPotentiallyReachable(Insts[I], Insts[J], nullptr, DT)) + if (isPotentiallyReachable(Insts[I], Insts[J], nullptr, DT, LI)) return true; } } @@ -39,7 +40,7 @@ } // namespace bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT, - const Instruction *Start, + const LoopInfo &LI, const Instruction *Start, const SmallVectorImpl &Ends, const SmallVectorImpl &RetVec, llvm::function_ref Callback) { @@ -54,7 +55,7 @@ SmallVector ReachableRetVec; unsigned NumCoveredExits = 0; for (auto *RI : RetVec) { - if (!isPotentiallyReachable(Start, RI, nullptr, &DT)) + if (!isPotentiallyReachable(Start, RI, nullptr, &DT, &LI)) continue; ReachableRetVec.push_back(RI); // If there is an end in the same basic block as the return, we know for @@ -62,7 +63,7 @@ // is a way to reach the RI from the start of the lifetime without passing // through an end. if (EndBlocks.count(RI->getParent()) > 0 || - !isPotentiallyReachable(Start, RI, &EndBlocks, &DT)) { + !isPotentiallyReachable(Start, RI, &EndBlocks, &DT, &LI)) { ++NumCoveredExits; } } @@ -83,14 +84,15 @@ bool isStandardLifetime(const SmallVectorImpl &LifetimeStart, const SmallVectorImpl &LifetimeEnd, - const DominatorTree *DT, size_t MaxLifetimes) { + const DominatorTree *DT, const LoopInfo *LI, + size_t MaxLifetimes) { // An alloca that has exactly one start and end in every possible execution. // If it has multiple ends, they have to be unreachable from each other, so // at most one of them is actually used for each execution of the function. 
return LifetimeStart.size() == 1 && (LifetimeEnd.size() == 1 || (LifetimeEnd.size() > 0 && - !maybeReachableFromEachOther(LifetimeEnd, DT, MaxLifetimes))); + !maybeReachableFromEachOther(LifetimeEnd, DT, LI, MaxLifetimes))); } Instruction *getUntagLocationIfFunctionExit(Instruction &Inst) { diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-loop.ll b/llvm/test/CodeGen/AArch64/stack-tagging-loop.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/stack-tagging-loop.ll @@ -0,0 +1,60 @@ +; We set a low dom-tree-reachability-max-bbs-to-explore to check whether the +; loop analysis is working. Without skipping over the loop, we would need more +; than 4 BBs to reach end from entry. + +; RUN: opt -S -dom-tree-reachability-max-bbs-to-explore=4 -aarch64-stack-tagging %s -o - | FileCheck %s + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" +target triple = "aarch64" + +define dso_local void @foo(i1 %x, i32 %n) sanitize_memtag { +entry: + %c = alloca [1024 x i8], align 1 + call void @llvm.lifetime.start.p0(i64 1024, ptr nonnull %c) + %cmp2.not = icmp eq i32 %n, 0 + br i1 %x, label %entry2, label %noloop + +entry2: + br i1 %cmp2.not, label %for.cond.cleanup, label %for.body + +for.cond.cleanup: ; preds = %for.body2, %entry2 +; CHECK-LABEL: for.cond.cleanup: +; CHECK: call{{.*}}settag +; CHECK: call{{.*}}lifetime.end + call void @llvm.lifetime.end.p0(i64 1024, ptr nonnull %c) + call void @bar(ptr noundef nonnull inttoptr (i64 120 to ptr)) + br label %end + +for.body: ; preds = %entry2, %for.body2 + %i.03 = phi i32 [ %inc, %for.body2 ], [ 0, %entry2 ] + call void @bar(ptr noundef nonnull %c) #3 + br label %for.body2 + +for.body2: + %inc = add nuw nsw i32 %i.03, 1 + %cmp = icmp ult i32 %inc, %n + br i1 %cmp, label %for.body, label %for.cond.cleanup, !llvm.loop !13 + +noloop: +; CHECK-LABEL: noloop: +; CHECK: call{{.*}}settag +; CHECK: call{{.*}}lifetime.end + call void @llvm.lifetime.end.p0(i64 1024, ptr nonnull %c) + br label %end + 
+end: +; CHECK-LABEL: end: +; CHECK-NOT: call{{.*}}settag + ret void +} + +; Function Attrs: argmemonly mustprogress nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #0 +declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #0 + +declare dso_local void @bar(ptr noundef) + +attributes #0 = { argmemonly mustprogress nocallback nofree nosync nounwind willreturn } + +!13 = distinct !{!13, !14} +!14 = !{!"llvm.loop.mustprogress"}