Index: lib/StaticAnalyzer/Checkers/CMakeLists.txt
===================================================================
--- lib/StaticAnalyzer/Checkers/CMakeLists.txt
+++ lib/StaticAnalyzer/Checkers/CMakeLists.txt
@@ -48,6 +48,7 @@
   NSErrorChecker.cpp
   NoReturnFunctionChecker.cpp
   NonNullParamChecker.cpp
+  NullabilityChecker.cpp
   ObjCAtSyncChecker.cpp
   ObjCContainersASTChecker.cpp
   ObjCContainersChecker.cpp
Index: lib/StaticAnalyzer/Checkers/Checkers.td
===================================================================
--- lib/StaticAnalyzer/Checkers/Checkers.td
+++ lib/StaticAnalyzer/Checkers/Checkers.td
@@ -128,6 +128,10 @@
   HelpText<"Check for division by variable that is later compared against 0. Either the comparison is useless or there is division by zero.">,
   DescFile<"TestAfterDivZeroChecker.cpp">;
 
+def NullabilityChecker : Checker<"Nullability">,
+  HelpText<"Warn about nullability misuses">,
+  DescFile<"NullabilityChecker.cpp">;
+
 } // end "alpha.core"
 
 //===----------------------------------------------------------------------===//
Index: lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
===================================================================
--- /dev/null
+++ lib/StaticAnalyzer/Checkers/NullabilityChecker.cpp
@@ -0,0 +1,615 @@
+//== NullabilityChecker.cpp - Nullability checker ---------------*- C++ -*--==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This checker tries to find nullability violations. The checker assumes that
+// the user runs it after all the nullability warnings emitted by the compiler
+// have been fixed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ClangSACheckers.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace clang;
+using namespace ento;
+
+namespace {
+/// Nullability of a pointer value. getMostNullable() selects the enumerator
+/// declared first, so the declaration order below is significant.
+enum class Nullability : char {
+  Contradicted, // Tracked nullability is contradicted by an explicit cast.
+  Nullable,
+  Unspecified, // Optimization: Most pointers expected to be unspecified. When
+               // memory region is not stored in the state, it implicitly means
+               // unspecified.
+  Nonnull
+};
+
+/// Returns the spelling of \p Nullab used in path notes.
+static const char *getNullabilityString(Nullability Nullab) {
+  switch (Nullab) {
+  case Nullability::Contradicted:
+    return "contradicted";
+  case Nullability::Nullable:
+    return "nullable";
+  case Nullability::Unspecified:
+    return "unspecified";
+  case Nullability::Nonnull:
+    return "nonnull";
+  }
+  llvm_unreachable("Unexpected enumeration.");
+}
+
+/// Returns the more nullable of \p Lhs and \p Rhs, i.e. the one whose
+/// enumerator is declared first in Nullability.
+static Nullability getMostNullable(Nullability Lhs, Nullability Rhs) {
+  return static_cast<Nullability>(
+      std::min(static_cast<char>(Lhs), static_cast<char>(Rhs)));
+}
+
+/// The kinds of violations this checker reports; see reportBug() for the
+/// message emitted for each.
+enum class ErrorKind {
+  NilAssignedToNonnull,
+  NilPassedToNonnull,
+  NilReturnedToNonnull,
+  NullableAssignedToNonnull,
+  NullableReturnedToNonnull,
+  NullableDereferenced,
+  NullableAssignedToReference,
+  NullablePassedToNonnull
+};
+
+class NullabilityChecker
+    : public Checker<check::Bind, check::PreCall, check::PreStmt<ReturnStmt>,
+                     check::PostStmt<ExplicitCastExpr>, check::PostObjCMessage,
+                     check::DeadSymbols, check::Event<ImplicitNullDerefEvent>> {
+  mutable std::unique_ptr<BugType> BT;
+
+public:
+  void checkBind(SVal L, SVal V, const Stmt *S, CheckerContext &C) const;
+  void checkPostStmt(const ExplicitCastExpr *CE, CheckerContext &C) const;
+  void checkPreStmt(const ReturnStmt *S, CheckerContext &C) const;
+  void checkPostObjCMessage(const ObjCMethodCall &M, CheckerContext &C) const;
+  void checkPreCall(const CallEvent &Call, CheckerContext &C) const;
+  void checkDeadSymbols(SymbolReaper &SR, CheckerContext &C) const;
+  void checkEvent(ImplicitNullDerefEvent Event) const;
+
+  /// Adds a path note where the tracked nullability of Region is first set or
+  /// changes, showing where the reported nullability was inferred.
+  class NullabilityBugVisitor
+      : public BugReporterVisitorImpl<NullabilityBugVisitor> {
+  public:
+    NullabilityBugVisitor(const MemRegion *M) : Region(M) {}
+    ~NullabilityBugVisitor() override {}
+
+    void Profile(llvm::FoldingSetNodeID &ID) const override {
+      static int X = 0;
+      ID.AddPointer(&X);
+      ID.AddPointer(Region);
+    }
+
+    PathDiagnosticPiece *VisitNode(const ExplodedNode *N,
+                                   const ExplodedNode *PrevN,
+                                   BugReporterContext &BRC,
+                                   BugReport &BR) override;
+
+  private:
+    // The tracked region.
+    const MemRegion *Region;
+  };
+
+  /// Emits a report of kind \p Error at node \p N. When \p Region is non-null
+  /// it is marked interesting and tracked by a NullabilityBugVisitor.
+  void reportBug(ErrorKind Error, ExplodedNode *N, const MemRegion *Region,
+                 BugReporter &BR) const {
+    if (!BT)
+      BT.reset(new BugType(this, "Nullability", "Memory error"));
+    const char *Msg = nullptr;
+    switch (Error) {
+    case ErrorKind::NilAssignedToNonnull:
+      Msg = "Null pointer is assigned to nonnull pointer";
+      break;
+    case ErrorKind::NilPassedToNonnull:
+      Msg = "Null pointer is passed to a nonnull parameter";
+      break;
+    case ErrorKind::NilReturnedToNonnull:
+      Msg = "Null pointer is returned from a nonnull returning function";
+      break;
+    case ErrorKind::NullableAssignedToNonnull:
+      Msg = "Nullable pointer is assigned to nonnull without a defensive check";
+      break;
+    case ErrorKind::NullableReturnedToNonnull:
+      Msg = "Nullable pointer is returned from a nonnull returning function "
+            "without a defensive check";
+      break;
+    case ErrorKind::NullableDereferenced:
+      Msg = "Nullable pointer is dereferenced without a defensive check";
+      break;
+    case ErrorKind::NullableAssignedToReference:
+      Msg = "Nullable pointer is assigned to a reference without a defensive "
+            "check";
+      break;
+    case ErrorKind::NullablePassedToNonnull:
+      Msg = "Nullable pointer is passed to a nonnull parameter without a "
+            "defensive check";
+      break;
+    }
+    assert(Msg);
+    std::unique_ptr<BugReport> R(new BugReport(*BT, Msg, N));
+    if (Region) {
+      R->markInteresting(Region);
+      R->addVisitor(llvm::make_unique<NullabilityBugVisitor>(Region));
+    }
+    BR.emitReport(std::move(R));
+  }
+};
+
+/// The value stored per region in the NullabilityMap program state trait.
+class NullabilityState {
+public:
+  NullabilityState(Nullability Nullab) : Nullab(Nullab) {}
+
+  Nullability getValue() const { return Nullab; }
+
+  void Profile(llvm::FoldingSetNodeID &ID) const {
+    ID.AddInteger(static_cast<char>(Nullab));
+  }
+
+private:
+  Nullability Nullab;
+};
+
+bool operator==(NullabilityState Lhs, NullabilityState Rhs) {
+  return Lhs.getValue() == Rhs.getValue();
+}
+
+} // end anonymous namespace
+
+REGISTER_MAP_WITH_PROGRAMSTATE(NullabilityMap, const MemRegion *,
+                               NullabilityState)
+
+PathDiagnosticPiece *NullabilityChecker::NullabilityBugVisitor::VisitNode(
+    const ExplodedNode *N, const ExplodedNode *PrevN, BugReporterContext &BRC,
+    BugReport &BR) {
+  ProgramStateRef State = N->getState();
+  ProgramStateRef StatePrev = PrevN->getState();
+
+  const NullabilityState *TrackedNullab = State->get<NullabilityMap>(Region);
+  const NullabilityState *TrackedNullabPrev =
+      StatePrev->get<NullabilityMap>(Region);
+  if (!TrackedNullab)
+    return nullptr;
+
+  // Only annotate the node where the tracked nullability changes.
+  if (TrackedNullabPrev &&
+      TrackedNullabPrev->getValue() == TrackedNullab->getValue())
+    return nullptr;
+
+  // Retrieve the associated statement.
+  const Stmt *S = nullptr;
+  ProgramPoint ProgLoc = N->getLocation();
+  if (Optional<StmtPoint> SP = ProgLoc.getAs<StmtPoint>()) {
+    S = SP->getStmt();
+  }
+
+  if (!S)
+    return nullptr;
+
+  std::string InfoText = (llvm::Twine("Nullability '") +
+                          getNullabilityString(TrackedNullab->getValue()) +
+                          "' is inferred from this context")
+                             .str();
+
+  // Generate the extra diagnostic.
+  PathDiagnosticLocation Pos(S, BRC.getSourceManager(),
+                             N->getLocationContext());
+  return new PathDiagnosticEventPiece(Pos, InfoText, true, nullptr);
+}
+
+/// Returns the nullability written on \p Type through the nullable/nonnull
+/// type attributes, or Unspecified when there is none.
+static Nullability getNullability(QualType Type) {
+  const auto *AttrType = Type->getAs<AttributedType>();
+  if (!AttrType)
+    return Nullability::Unspecified;
+  if (AttrType->getAttrKind() == AttributedType::attr_nullable)
+    return Nullability::Nullable;
+  else if (AttrType->getAttrKind() == AttributedType::attr_nonnull)
+    return Nullability::Nonnull;
+  return Nullability::Unspecified;
+}
+
+/// Stops tracking the nullability of regions that are no longer live.
+void NullabilityChecker::checkDeadSymbols(SymbolReaper &SR,
+                                          CheckerContext &C) const {
+  ProgramStateRef State = C.getState();
+  NullabilityMapTy Nullabilities = State->get<NullabilityMap>();
+  for (NullabilityMapTy::iterator I = Nullabilities.begin(),
+                                  E = Nullabilities.end();
+       I != E; ++I) {
+    if (!SR.isLiveRegion(I->first)) {
+      State = State->remove<NullabilityMap>(I->first);
+    }
+  }
+  // Commit the cleaned-up state; without this the removals above are lost.
+  C.addTransition(State);
+}
+
+/// Handles ImplicitNullDerefEvent: reports the dereference when the
+/// dereferenced region (or, failing that, its super region) is tracked as
+/// Nullable.
+void NullabilityChecker::checkEvent(ImplicitNullDerefEvent Event) const {
+  SVal DereferencedSVal = Event.Location;
+
+  auto RegionSVal = DereferencedSVal.getAs<loc::MemRegionVal>();
+  if (!RegionSVal)
+    return;
+
+  ProgramStateRef State = Event.SinkNode->getState();
+  const MemRegion *Region = RegionSVal->getRegion();
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+
+  if (!TrackedNullability) {
+    // Maybe a field or an element of a nullable pointer is loaded. Only a
+    // SubRegion has a super region to look at.
+    const auto *SubR = dyn_cast<SubRegion>(Region);
+    if (!SubR)
+      return;
+    TrackedNullability = State->get<NullabilityMap>(SubR->getSuperRegion());
+    if (!TrackedNullability)
+      return;
+  }
+
+  Nullability TrackedNullabValue = TrackedNullability->getValue();
+
+  if (TrackedNullabValue == Nullability::Nullable) {
+    BugReporter &BR = *Event.BR;
+    reportBug(ErrorKind::NullableDereferenced, Event.SinkNode, Region, BR);
+  }
+}
+
+/// Checks a returned pointer against the nullability of the function's return
+/// type; records that nullability on an untracked returned value.
+void NullabilityChecker::checkPreStmt(const ReturnStmt *S,
+                                      CheckerContext &C) const {
+  auto RetExpr = S->getRetValue();
+  if (!RetExpr)
+    return;
+
+  QualType RetExprType = RetExpr->getType();
+  // FIXME: What about references?
+  if (!RetExprType->isPointerType() && !RetExprType->isObjCObjectPointerType())
+    return;
+
+  ProgramStateRef State = C.getState();
+  SVal RetSVal = State->getSVal(S, C.getLocationContext());
+  if (RetSVal.isUndef())
+    return;
+
+  AnalysisDeclContext *DeclCtxt =
+      C.getLocationContext()->getAnalysisDeclContext();
+
+  const FunctionType *FuncType = DeclCtxt->getDecl()->getFunctionType();
+  if (!FuncType)
+    return;
+
+  QualType ReturnType = FuncType->getReturnType();
+  Nullability StaticNullability = getNullability(ReturnType);
+
+  DefinedOrUnknownSVal ReturnValue = RetSVal.castAs<DefinedOrUnknownSVal>();
+
+  ProgramStateRef StNonNull, StNull;
+  std::tie(StNonNull, StNull) = State->assume(ReturnValue);
+  bool IsNotNull = !StNull && StNonNull;
+  bool IsNull = StNull && !StNonNull;
+  if (IsNull && StaticNullability == Nullability::Nonnull) {
+    ExplodedNode *N = C.addTransition();
+    reportBug(ErrorKind::NilReturnedToNonnull, N, nullptr, C.getBugReporter());
+    return;
+  }
+
+  auto RetRegionSVal = ReturnValue.getAs<loc::MemRegionVal>();
+  if (!RetRegionSVal)
+    return;
+
+  const MemRegion *Region = RetRegionSVal->getRegion();
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+  if (TrackedNullability) {
+    Nullability TrackedNullabValue = TrackedNullability->getValue();
+    if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+        StaticNullability == Nullability::Nonnull) {
+      ExplodedNode *N = C.addTransition();
+      reportBug(ErrorKind::NullableReturnedToNonnull, N, Region,
+                C.getBugReporter());
+      return;
+    }
+  } else if (StaticNullability != Nullability::Unspecified) {
+    State = State->set<NullabilityMap>(Region, StaticNullability);
+    C.addTransition(State);
+  }
+}
+
+/// Checks every pointer/reference argument against the nullability of its
+/// parameter (or, when the parameter has none, of the argument expression),
+/// and records that nullability on untracked argument values.
+void NullabilityChecker::checkPreCall(const CallEvent &Call,
+                                      CheckerContext &C) const {
+  const Decl *FD = Call.getDecl();
+  if (!FD)
+    return;
+
+  ProgramStateRef State = C.getState();
+  ProgramStateRef OrigState = State;
+
+  unsigned Idx = 0;
+  for (const ParmVarDecl *Param : Call.parameters()) {
+    if (Param->isParameterPack())
+      break;
+
+    const Expr *ArgExpr = nullptr;
+    if (Idx < Call.getNumArgs())
+      ArgExpr = Call.getArgExpr(Idx);
+    SVal ArgSVal = Call.getArgSVal(Idx++);
+    if (ArgSVal.isUndef())
+      continue;
+
+    if (!Param->getType()->isPointerType() &&
+        !Param->getType()->isReferenceType() &&
+        !Param->getType()->isObjCObjectPointerType()) {
+      continue;
+    }
+
+    ProgramStateRef StNonNull, StNull;
+    DefinedOrUnknownSVal DefArgSVal = ArgSVal.castAs<DefinedOrUnknownSVal>();
+    std::tie(StNonNull, StNull) = State->assume(DefArgSVal);
+    bool IsNotNull = !StNull && StNonNull;
+    bool IsNull = StNull && !StNonNull;
+
+    Nullability StaticNullability = getNullability(Param->getType());
+    // When the static type of the parameter has no nullability information,
+    // the static type of the argument might have it.
+    if (StaticNullability == Nullability::Unspecified && ArgExpr) {
+      StaticNullability = getNullability(ArgExpr->getType());
+    }
+
+    if (IsNull && StaticNullability == Nullability::Nonnull) {
+      ExplodedNode *N = C.addTransition();
+      reportBug(ErrorKind::NilPassedToNonnull, N, nullptr, C.getBugReporter());
+      return;
+    }
+
+    auto ArgRegionSVal = ArgSVal.getAs<loc::MemRegionVal>();
+    if (!ArgRegionSVal)
+      continue;
+
+    const MemRegion *Region = ArgRegionSVal->getRegion();
+    const NullabilityState *TrackedNullability =
+        State->get<NullabilityMap>(Region);
+
+    if (TrackedNullability) {
+      Nullability TrackedNullabValue = TrackedNullability->getValue();
+      if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+          StaticNullability == Nullability::Nonnull) {
+        ExplodedNode *N = C.addTransition();
+        reportBug(ErrorKind::NullablePassedToNonnull, N, Region,
+                  C.getBugReporter());
+        return; // FIXME: What if multiple parameters should be reported?
+      } else if (!IsNotNull && TrackedNullabValue == Nullability::Nullable &&
+                 Param->getType()->isReferenceType()) {
+        ExplodedNode *N = C.addTransition();
+        reportBug(ErrorKind::NullableAssignedToReference, N, Region,
+                  C.getBugReporter());
+        return;
+      }
+    } else if (StaticNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(Region, StaticNullability);
+    }
+  }
+  if (State != OrigState)
+    C.addTransition(State);
+}
+
+/// Infers the nullability of an Objective-C message result as the most
+/// nullable of the receiver's nullability (nonnull for messages to super) and
+/// the result's tracked or declared nullability.
+void NullabilityChecker::checkPostObjCMessage(const ObjCMethodCall &M,
+                                              CheckerContext &C) const {
+  auto Decl = M.getDecl();
+  if (!Decl)
+    return;
+  QualType RetType = Decl->getReturnType();
+  if (!RetType->isPointerType() && !RetType->isObjCObjectPointerType())
+    return;
+
+  const ObjCMessageExpr *Message = M.getOriginExpr();
+
+  ProgramStateRef State = C.getState();
+  SVal ResultSVal = M.getReturnValue();
+  auto MemRegVal = ResultSVal.getAs<loc::MemRegionVal>();
+  if (!MemRegVal)
+    return;
+
+  Nullability SelfNullability = Nullability::Unspecified;
+  if (Message->getReceiverKind() == ObjCMessageExpr::SuperClass ||
+      Message->getReceiverKind() == ObjCMessageExpr::SuperInstance) {
+    SelfNullability = Nullability::Nonnull;
+  } else {
+    SVal Receiver = M.getReceiverSVal();
+    auto ValueRegionSVal = Receiver.getAs<loc::MemRegionVal>();
+    if (ValueRegionSVal) {
+      const MemRegion *SelfRegion = ValueRegionSVal->getRegion();
+      assert(SelfRegion);
+
+      const NullabilityState *TrackedSelfNullability =
+          State->get<NullabilityMap>(SelfRegion);
+      if (TrackedSelfNullability) {
+        SelfNullability = TrackedSelfNullability->getValue();
+      }
+    }
+  }
+
+  const MemRegion *ReturnRegion = MemRegVal->getRegion();
+  assert(ReturnRegion);
+
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(ReturnRegion);
+  if (TrackedNullability) {
+    Nullability RetValTracked = TrackedNullability->getValue();
+    Nullability NewNullability =
+        getMostNullable(RetValTracked, SelfNullability);
+    if (NewNullability != RetValTracked &&
+        NewNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(ReturnRegion, NewNullability);
+      C.addTransition(State);
+    }
+  } else {
+    // Use static type information for return value.
+    Nullability RetNullability = getNullability(RetType);
+    RetNullability = getMostNullable(RetNullability, SelfNullability);
+    if (RetNullability != Nullability::Unspecified) {
+      State = State->set<NullabilityMap>(ReturnRegion, RetNullability);
+      C.addTransition(State);
+    }
+  }
+}
+
+/// Applies the nullability written on the destination type of an explicit
+/// pointer cast to the cast value. A cast that disagrees with the tracked
+/// nullability, or a cast of a value known to be null to nonnull, marks the
+/// value as Contradicted.
+void NullabilityChecker::checkPostStmt(const ExplicitCastExpr *CE,
+                                       CheckerContext &C) const {
+  QualType OriginType = CE->getSubExpr()->getType();
+  QualType DestType = CE->getType();
+  if (!OriginType->isPointerType() && !OriginType->isObjCObjectPointerType())
+    return;
+  if (!DestType->isPointerType() && !DestType->isObjCObjectPointerType())
+    return;
+
+  Nullability DestNullability = getNullability(DestType);
+
+  if (DestNullability == Nullability::Unspecified)
+    return;
+
+  ProgramStateRef State = C.getState();
+  SVal ExprSVal = State->getSVal(CE, C.getLocationContext());
+  // Key the map by the region of the pointer value itself; that is the key
+  // the other callbacks use when they look the value up.
+  auto RegionSVal = ExprSVal.getAs<loc::MemRegionVal>();
+  if (!RegionSVal)
+    return;
+  const MemRegion *Region = RegionSVal->getRegion();
+
+  // When a value known to be null is converted to nonnull, mark it as
+  // contradicted.
+  if (DestNullability == Nullability::Nonnull) {
+    ProgramStateRef StNonNull, StNull;
+    std::tie(StNonNull, StNull) = State->assume(*RegionSVal);
+    if (StNull && !StNonNull) {
+      State = State->set<NullabilityMap>(Region, Nullability::Contradicted);
+      C.addTransition(State);
+      return;
+    }
+  }
+
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(Region);
+
+  if (!TrackedNullability) {
+    State = State->set<NullabilityMap>(Region, DestNullability);
+    C.addTransition(State);
+  } else if (TrackedNullability->getValue() != DestNullability) {
+    // Do not add redundant transitions.
+    if (TrackedNullability->getValue() == Nullability::Contradicted)
+      return;
+    State = State->set<NullabilityMap>(Region, Nullability::Contradicted);
+    C.addTransition(State);
+  }
+}
+
+/// Reports binding a null or tracked-nullable value to a location of nonnull
+/// type. Otherwise records on an untracked value the nullability of its own
+/// static type or, failing that, of the location's type.
+void NullabilityChecker::checkBind(SVal L, SVal V, const Stmt *S,
+                                   CheckerContext &C) const {
+  const MemRegion *MR = L.getAsRegion();
+  const TypedValueRegion *TVR = dyn_cast_or_null<TypedValueRegion>(MR);
+  if (!TVR)
+    return;
+
+  QualType LocType = TVR->getValueType();
+  if (!LocType->isPointerType() && !LocType->isReferenceType())
+    return;
+
+  // An undefined value carries no nullability information, and assume()
+  // requires a DefinedOrUnknownSVal, so skip such bindings.
+  auto ValDefOrUnknown = V.getAs<DefinedOrUnknownSVal>();
+  if (!ValDefOrUnknown)
+    return;
+
+  Nullability LocNullability = getNullability(LocType);
+
+  ProgramStateRef State = C.getState();
+  ProgramStateRef StNonNull, StNull;
+  std::tie(StNonNull, StNull) = State->assume(*ValDefOrUnknown);
+  bool RhsIsNull = !StNonNull && StNull;
+  bool RhsIsNotNull = StNonNull && !StNull;
+
+  // Binding a null pointer to a reference is handled by another checker.
+  if (RhsIsNull && LocNullability == Nullability::Nonnull) {
+    ExplodedNode *N = C.addTransition();
+    reportBug(ErrorKind::NilAssignedToNonnull, N, nullptr, C.getBugReporter());
+    return;
+  }
+
+  auto ValueRegionSVal = V.getAs<loc::MemRegionVal>();
+  if (!ValueRegionSVal)
+    return;
+
+  const MemRegion *ValueRegion = ValueRegionSVal->getRegion();
+  assert(ValueRegion);
+
+  Nullability ValNullability = Nullability::Unspecified;
+  if (SymbolRef Sym = V.getAsSymbol())
+    ValNullability = getNullability(Sym->getType());
+
+  const NullabilityState *TrackedNullability =
+      State->get<NullabilityMap>(ValueRegion);
+
+  if (TrackedNullability) {
+    ValNullability = TrackedNullability->getValue();
+    if (!RhsIsNotNull && ValNullability == Nullability::Nullable) {
+      if (LocNullability == Nullability::Nonnull) {
+        ExplodedNode *N = C.addTransition();
+        reportBug(ErrorKind::NullableAssignedToNonnull, N, ValueRegion,
+                  C.getBugReporter());
+      }
+    }
+  } else if (ValNullability != Nullability::Unspecified) {
+    // Trust the static information of the value more than the static
+    // information on the location.
+    State = State->set<NullabilityMap>(ValueRegion, ValNullability);
+    C.addTransition(State);
+  } else if (LocNullability != Nullability::Unspecified) {
+    State = State->set<NullabilityMap>(ValueRegion, LocNullability);
+    C.addTransition(State);
+  }
+}
+
+void ento::registerNullabilityChecker(CheckerManager &mgr) {
+  mgr.registerChecker<NullabilityChecker>();
+}
Index: test/Analysis/nullability.mm
===================================================================
--- /dev/null
+++ test/Analysis/nullability.mm
@@ -0,0 +1,121 @@
+// RUN: %clang_cc1 -analyze -analyzer-checker=core,alpha.core.Nullability -verify %s
+
+#define nil 0
+
+@protocol NSObject
++ (id)alloc;
+- (id)init;
+@end
+
+@protocol NSCopying
+@end
+
+__attribute__((objc_root_class))
+@interface NSObject <NSObject>
+@end
+
+@interface TestObject : NSObject
+- (int * _Nonnull)returnsNonnull;
+- (int * _Nullable)returnsNullable;
+- (int *) returnsUnspecified;
+- (void)takesNonnull:(int * _Nonnull)p;
+- (void)takesNullable:(int * _Nullable)p;
+- (void)takesUnspecified:(int *)p;
+@end
+
+TestObject *getUnspecifiedTestObject();
+TestObject * _Nonnull getNonnullTestObject();
+TestObject * _Nullable getNullableTestObject();
+
+int getRandom();
+
+typedef struct Dummy {
+  int val;
+} Dummy;
+
+void takesNullable(Dummy * _Nullable);
+void takesNonnull(Dummy * _Nonnull);
+void takesUnspecified(Dummy *);
+
+Dummy * _Nullable returnsNullable();
+Dummy * _Nonnull returnsNonnull();
+Dummy * returnsUnspecified();
+int * _Nullable returnsNullableInt();
+
+template <typename T>
+T* eraseNullab(T* p) {
+  return p;
+}
+
+void testBasicRules() {
+  Dummy *p = returnsNullable();
+  int *ptr = returnsNullableInt();
+  // Make every dereference a different path to avoid nonnull assumptions.
+  switch(getRandom()) {
+  case 0: { Dummy &r = *p; } break; // expected-warning {{}}
+  case 1: { int b = p->val; } break; // expected-warning {{}}
+  case 2: { int stuff = *ptr; } break; // expected-warning {{}}
+  case 3: takesNonnull(p); break; // expected-warning {{}}
+  default: { Dummy d = *p; } break; // expected-warning {{}}
+  }
+  if (p) {
+    takesNonnull(p);
+    if (getRandom()) {
+      Dummy &r = *p;
+    } else {
+      int b = p->val;
+    }
+  }
+  Dummy *q = 0;
+  takesNullable(q);
+  takesNonnull(q); // expected-warning {{}}
+  Dummy a;
+  Dummy * _Nonnull nonnull = &a;
+  nonnull = q; // expected-warning {{}}
+  q = &a;
+  takesNullable(q);
+  takesNonnull(q);
+}
+
+void testArgumentTracking(Dummy * _Nonnull nonnull, Dummy * _Nullable nullable) {
+  Dummy *p = nullable;
+  nonnull = p; // expected-warning {{}}
+  p = 0;
+  Dummy *q = nonnull;
+  q = p;
+}
+
+Dummy * _Nonnull testNullableReturn(Dummy * _Nullable a) {
+  Dummy *p = a;
+  return p; // expected-warning {{}}
+}
+
+Dummy * _Nonnull testNullReturn() {
+  Dummy *p = 0;
+  return p; // expected-warning {{}}
+}
+
+void testObjCMessageResultNullability() {
+  // The expected result: the most nullable of self and method return type.
+  TestObject *o = getUnspecifiedTestObject();
+  int *shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsNonnull];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsUnspecified];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNonnullTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getUnspecifiedTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  shouldBeNullable = [eraseNullab(getNullableTestObject()) returnsNullable];
+  [o takesNonnull: shouldBeNullable]; // expected-warning {{}}
+  int * shouldBeNonnull = [eraseNullab(getNonnullTestObject()) returnsNonnull];
+  [o takesNonnull: shouldBeNonnull];
+}
+
+void testExplicitCastSuppressesWarning(Dummy * _Nullable nullable) {
+  Dummy *p = nullable;
+  Dummy *q = (Dummy * _Nonnull)p;
+  takesNonnull(q); // no-warning
+}