diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
@@ -68,6 +68,7 @@
   bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion,
                                 const MemRegion *OtherSmartPtrRegion) const;
   void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const;
+  bool handleOstreamOperator(const CallEvent &Call, CheckerContext &C) const;
 
   using SmartPtrMethodHandlerFn =
       void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const;
@@ -81,6 +82,32 @@
 
 REGISTER_MAP_WITH_PROGRAMSTATE(TrackedRegionMap, const MemRegion *, SVal)
 
+// Checks whether RD is a class in namespace std whose name is one of Names.
+// DeclContext::isStdNamespace() also accepts inline namespaces nested in std
+// (e.g. libc++'s std::__1).
+static bool hasStdClassWithName(const CXXRecordDecl *RD,
+                                ArrayRef<llvm::StringLiteral> Names) {
+  if (!RD || !RD->getDeclContext()->isStdNamespace())
+    return false;
+  // getName() is only valid for identifier names.
+  if (!RD->getDeclName().isIdentifier())
+    return false;
+  return llvm::is_contained(Names, RD->getName());
+}
+
+// Class names of the std smart pointers modeled by this checker. A constexpr
+// array of StringLiteral needs no global constructor.
+constexpr llvm::StringLiteral StdPtrNames[] = {"shared_ptr", "unique_ptr",
+                                               "weak_ptr"};
+
+static bool isStdSmartPtr(const CXXRecordDecl *RD) {
+  return hasStdClassWithName(RD, StdPtrNames);
+}
+
+static bool isStdSmartPtr(const Expr *E) {
+  return isStdSmartPtr(E->getType()->getAsCXXRecordDecl());
+}
+
 // Define the inter-checker API.
 namespace clang {
 namespace ento {
@@ -89,16 +116,7 @@
   const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
   if (!MethodDecl || !MethodDecl->getParent())
     return false;
-
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->getDeclContext()->isStdNamespace())
-    return false;
-
-  if (RecordDecl->getDeclName().isIdentifier()) {
-    StringRef Name = RecordDecl->getName();
-    return Name == "shared_ptr" || Name == "unique_ptr" || Name == "weak_ptr";
-  }
-  return false;
+  return isStdSmartPtr(MethodDecl->getParent());
 }
 
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion) {
@@ -175,9 +193,43 @@
   return CD && CD->getConversionType()->isBooleanType();
 }
 
+constexpr llvm::StringLiteral BasicOstreamName[] = {"basic_ostream"};
+
+// Returns true if the type of E is a std::basic_ostream specialization.
+static bool isStdBasicOstream(const Expr *E) {
+  const auto *RD = E->getType()->getAsCXXRecordDecl();
+  return hasStdClassWithName(RD, BasicOstreamName);
+}
+
+// Returns true if Call is a call to a non-member std::operator<< taking a
+// std::basic_ostream and a std smart pointer, e.g. `std::cout << Ptr`.
+static bool isStdOstreamOperatorCall(const CallEvent &Call) {
+  if (Call.getNumArgs() != 2)
+    return false;
+  const auto *FC = dyn_cast<SimpleFunctionCall>(&Call);
+  if (!FC)
+    return false;
+  // The callee declaration may be unknown (null), e.g. for calls through a
+  // function pointer.
+  const FunctionDecl *FD = FC->getDecl();
+  if (!FD || !FD->getDeclContext()->isStdNamespace())
+    return false;
+  // getOverloadedOperator() yields OO_None for ordinary functions.
+  if (FD->getOverloadedOperator() != clang::OO_LessLess)
+    return false;
+  return isStdSmartPtr(Call.getArgExpr(1)) &&
+         isStdBasicOstream(Call.getArgExpr(0));
+}
+
 bool SmartPtrModeling::evalCall(const CallEvent &Call,
                                 CheckerContext &C) const {
   ProgramStateRef State = C.getState();
+
+  // std::operator<< is a free function, so the smart pointer member-call
+  // filter below would reject it; recognize it first.
+  if (isStdOstreamOperatorCall(Call))
+    return handleOstreamOperator(Call, C);
+
   if (!smartptr::isStdSmartPtrCall(Call))
     return false;
 
@@ -272,6 +324,31 @@
   return C.isDifferent();
 }
 
+bool SmartPtrModeling::handleOstreamOperator(const CallEvent &Call,
+                                             CheckerContext &C) const {
+  // operator<< does not modify the smart pointer, and we don't really have
+  // much modeling of basic_ostream. So we are better off:
+  // 1) Invalidating the memory region of the ostream object at hand.
+  // 2) Binding the SVal of that basic_ostream as the return value, since
+  //    operator<< returns its stream argument.
+  // Not very satisfying, but it gets the job done and is better than the
+  // default handling.
+
+  ProgramStateRef State = C.getState();
+  const auto StreamVal = Call.getArgSVal(0);
+  const MemRegion *StreamThisRegion = StreamVal.getAsRegion();
+  // Nothing to invalidate or bind; let the call be evaluated by other means.
+  if (!StreamThisRegion)
+    return false;
+  State = State->invalidateRegions({StreamThisRegion}, Call.getOriginExpr(),
+                                   C.blockCount(), C.getLocationContext(),
+                                   /*CausesPointerEscape=*/false);
+  State =
+      State->BindExpr(Call.getOriginExpr(), C.getLocationContext(), StreamVal);
+  C.addTransition(State);
+  return true;
+}
+
 void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper,
                                         CheckerContext &C) const {
   ProgramStateRef State = C.getState();