diff --git a/clang/lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
@@ -26,12 +26,13 @@
 
 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
 
+#include "clang/Analysis/AnyCall.h"
 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
 #include "clang/StaticAnalyzer/Core/Checker.h"
 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
-#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h"
-#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h"
 
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Path.h"
@@ -81,7 +82,8 @@
     : public Checker<check::Bind, check::PreCall, check::PreStmt<ReturnStmt>,
                      check::PostCall, check::PostStmt<ExplicitCastExpr>,
                      check::PostObjCMessage, check::DeadSymbols, eval::Assume,
-                     check::Location, check::Event<ImplicitNullDerefEvent>> {
+                     check::Location, check::Event<ImplicitNullDerefEvent>,
+                     check::BeginFunction> {
 
 public:
   // If true, the checker will not diagnose nullabilility issues for calls
@@ -102,6 +104,7 @@
   void checkEvent(ImplicitNullDerefEvent Event) const;
   void checkLocation(SVal Location, bool IsLoad, const Stmt *S,
                      CheckerContext &C) const;
+  void checkBeginFunction(CheckerContext &Ctx) const;
 
   ProgramStateRef evalAssume(ProgramStateRef State, SVal Cond,
                              bool Assumption) const;
@@ -563,6 +566,43 @@
   }
 }
 
+// At the analysis entry point (top frame only), record every pointer parameter
+// whose declared type is annotated _Nullable in NullabilityMap, keyed by the
+// region its initial value points to, so that using such a parameter directly
+// (e.g. passing it to a callee that requires a non-null argument) is diagnosed.
+void NullabilityChecker::checkBeginFunction(CheckerContext &C) const {
+  if (!C.inTopFrame())
+    return;
+
+  const LocationContext *LCtx = C.getLocationContext();
+  auto AbstractCall = AnyCall::forDecl(LCtx->getDecl());
+  if (!AbstractCall || AbstractCall->parameters().empty())
+    return;
+
+  ProgramStateRef State = C.getState();
+  for (const ParmVarDecl *Param : AbstractCall->parameters()) {
+    if (!isValidPointerType(Param->getType()))
+      continue;
+
+    Nullability RequiredNullability =
+        getNullabilityAnnotation(Param->getType());
+    if (RequiredNullability != Nullability::Nullable)
+      continue;
+
+    const VarRegion *ParamRegion = State->getRegion(Param, LCtx);
+    const MemRegion *ParamPointeeRegion =
+        State->getSVal(ParamRegion).getAsRegion();
+    if (!ParamPointeeRegion)
+      continue;
+
+    // Chain onto the accumulated State, not C.getState(): restarting from the
+    // context's state on each iteration would drop every entry but the last.
+    State = State->set<NullabilityMap>(
+        ParamPointeeRegion, NullabilityState(RequiredNullability));
+  }
+  C.addTransition(State);
+}
+
 // Whenever we see a load from a typed memory region that's been annotated as
 // 'nonnull', we want to trust the user on that and assume that it is is indeed
 // non-null.
diff --git a/clang/test/Analysis/nullability.mm b/clang/test/Analysis/nullability.mm
--- a/clang/test/Analysis/nullability.mm
+++ b/clang/test/Analysis/nullability.mm
@@ -145,6 +145,25 @@
   }
 }
 
+void testArgumentTrackingDirectly(Dummy *_Nonnull nonnull, Dummy *_Nullable nullable) {
+  switch(getRandom()) {
+  case 1: testMultiParamChecking(nonnull, nullable, nonnull); break;
+  case 2: testMultiParamChecking(nonnull, nonnull, nonnull); break;
+  case 3: testMultiParamChecking(nonnull, nullable, nullable); break; // expected-warning {{Nullable pointer is passed to a callee that requires a non-null 3rd parameter}}
+  case 4: testMultiParamChecking(nullable, nullable, nonnull); // expected-warning {{Nullable pointer is passed to a callee that requires a non-null 1st parameter}}
+  case 5: testMultiParamChecking(nullable, nullable, nullable); // expected-warning {{Nullable pointer is passed to a callee that requires a non-null 1st parameter}}
+  case 6: testMultiParamChecking((Dummy *_Nonnull)0, nullable, nonnull); break;
+  }
+}
+
+// Every _Nullable parameter must be tracked, not only the last one.
+void testArgumentTrackingDirectlyMultipleNullables(Dummy *_Nullable first, Dummy *_Nullable second, Dummy *_Nonnull nonnull) {
+  switch(getRandom()) {
+  case 1: testMultiParamChecking(first, nonnull, nonnull); break; // expected-warning {{Nullable pointer is passed to a callee that requires a non-null 1st parameter}}
+  case 2: testMultiParamChecking(second, nonnull, nonnull); break; // expected-warning {{Nullable pointer is passed to a callee that requires a non-null 1st parameter}}
+  }
+}
+
 Dummy *_Nonnull testNullableReturn(Dummy *_Nullable a) {
   Dummy *p = a;
   return p; // expected-warning {{Nullable pointer is returned from a function that is expected to return a non-null value}}