diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
--- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp
+++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
@@ -414,6 +414,13 @@
         }
 
         const auto &CB = cast<CallBase>(*I);
+        // The call evaluates to V when V is passed as its `returned` argument,
+        // so the uses of the call result must be analyzed as uses of V too.
+        if (CB.getReturnedArgOperand() == V) {
+          if (Visited.insert(I).second)
+            WorkList.push_back(cast<const Instruction>(I));
+        }
+
         if (!CB.isArgOperand(&UI)) {
           US.addRange(I, UnknownRange);
           break;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4533,6 +4533,13 @@
       if (OffsetZero && !GEP->hasAllZeroIndices())
         return nullptr;
       AddWork(GEP->getPointerOperand());
+    } else if (auto *CB = dyn_cast<CallBase>(V)) {
+      // A call with a `returned` argument evaluates to that argument, so keep
+      // tracing through it; any other call makes the search give up.
+      Value *Returned = CB->getReturnedArgOperand();
+      if (!Returned)
+        return nullptr;
+      AddWork(Returned);
     } else {
       return nullptr;
     }
diff --git a/llvm/test/Instrumentation/HWAddressSanitizer/stack-safety-analysis.ll b/llvm/test/Instrumentation/HWAddressSanitizer/stack-safety-analysis.ll
--- a/llvm/test/Instrumentation/HWAddressSanitizer/stack-safety-analysis.ll
+++ b/llvm/test/Instrumentation/HWAddressSanitizer/stack-safety-analysis.ll
@@ -136,6 +136,24 @@
   ret i32 0
 }
 
+; Check that a store through a pointer returned via a `returned` arg is safe.
+define i32 @test_retptr(i32* %a) sanitize_hwaddress {
+entry:
+  ; CHECK-LABEL: @test_retptr
+  ; NOSAFETY: call {{.*}}__hwasan_generate_tag
+  ; NOSAFETY: call {{.*}}__hwasan_store
+  ; SAFETY: call {{.*}}__hwasan_generate_tag
+  ; SAFETY-NOT: call {{.*}}__hwasan_store
+  ; NOSTACK-NOT: call {{.*}}__hwasan_generate_tag
+  ; NOSTACK-NOT: call {{.*}}__hwasan_store
+  %buf.sroa.0 = alloca i8, align 4
+  call void @llvm.lifetime.start.p0i8(i64 1, i8* nonnull %buf.sroa.0)
+  %ptr = call i8* @retptr(i8* %buf.sroa.0)
+  store volatile i8 0, i8* %ptr, align 4, !tbaa !8
+  call void @llvm.lifetime.end.p0i8(i64 1, i8* nonnull %buf.sroa.0)
+  ret i32 0
+}
+
 ; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn
 declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
 
@@ -145,6 +163,7 @@
 declare void @use(i8* nocapture)
 declare i32 @getoffset()
 declare i8* @getptr(i8* nocapture)
+declare i8* @retptr(i8* returned)
 
 !8 = !{!9, !9, i64 0}
 !9 = !{!"omnipotent char", !10, i64 0}