diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
@@ -22,6 +22,12 @@
 
 /// Returns true if the event call is on smart pointer.
 bool isStdSmartPtrCall(const CallEvent &Call);
+
+/// Returns true if \p RD is std::shared_ptr, std::unique_ptr or std::weak_ptr.
+bool isStdSmartPtr(const CXXRecordDecl *RD);
+
+/// Returns true if the type of \p E is one of the standard smart pointers.
+bool isStdSmartPtr(const Expr *E);
 
 /// Returns whether the smart pointer is null or not.
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion);
diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
@@ -68,6 +68,7 @@
   bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion,
                                 const MemRegion *OtherSmartPtrRegion) const;
   void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const;
+  bool handleComparisionOp(const CallEvent &Call, CheckerContext &C) const;
 
   using SmartPtrMethodHandlerFn =
       void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const;
@@ -89,18 +90,24 @@
   const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
   if (!MethodDecl || !MethodDecl->getParent())
     return false;
+  return isStdSmartPtr(MethodDecl->getParent());
+}
 
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->getDeclContext()->isStdNamespace())
+bool isStdSmartPtr(const CXXRecordDecl *RD) {
+  if (!RD || !RD->getDeclContext()->isStdNamespace())
     return false;
 
-  if (RecordDecl->getDeclName().isIdentifier()) {
-    StringRef Name = RecordDecl->getName();
+  if (RD->getDeclName().isIdentifier()) {
+    StringRef Name = RD->getName();
     return Name == "shared_ptr" || Name == "unique_ptr" || Name == "weak_ptr";
   }
   return false;
 }
 
+bool isStdSmartPtr(const Expr *E) {
+  return isStdSmartPtr(E->getType()->getAsCXXRecordDecl());
+}
+
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion) {
   const auto *InnerPointVal = State->get<TrackedRegionMap>(ThisRegion);
   return InnerPointVal &&
@@ -135,18 +142,14 @@
   return State;
 }
 
-// Helper method to get the inner pointer type of specialized smart pointer
-// Returns empty type if not found valid inner pointer type.
-static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
-  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
-  if (!MethodDecl || !MethodDecl->getParent())
+// Helper to get the inner pointer type of the std class template
+// specialization \p RD. Returns an empty type if \p RD is not one.
+static QualType getInnerPointerType(CheckerContext &C,
+                                    const CXXRecordDecl *RD) {
+  if (!RD || !RD->isInStdNamespace())
     return {};
 
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->isInStdNamespace())
-    return {};
-
-  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
+  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD);
   if (!TSD)
     return {};
 
@@ -157,6 +160,17 @@
   return C.getASTContext().getPointerType(InnerValueType.getCanonicalType());
 }
 
+// Helper method to get the inner pointer type of specialized smart pointer
+// Returns empty type if not found valid inner pointer type.
+static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
+  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
+  if (!MethodDecl || !MethodDecl->getParent())
+    return {};
+
+  const auto *RecordDecl = MethodDecl->getParent();
+  return getInnerPointerType(C, RecordDecl);
+}
+
 // Helper method to pretty print region and avoid extra spacing.
 static void checkAndPrettyPrintRegion(llvm::raw_ostream &OS,
                                       const MemRegion *Region) {
@@ -178,6 +192,18 @@
 bool SmartPtrModeling::evalCall(const CallEvent &Call,
                                 CheckerContext &C) const {
   ProgramStateRef State = C.getState();
+
+  // If either argument is a standard smart pointer, the call may be one of
+  // the std comparison operators, so try to model it. Call.getDecl() may be
+  // null, so check it before looking at its DeclContext.
+  const Decl *CallDecl = Call.getDecl();
+  if (CallDecl && Call.getNumArgs() == 2 &&
+      CallDecl->getDeclContext()->isStdNamespace())
+    if (smartptr::isStdSmartPtr(Call.getArgExpr(0)) ||
+        smartptr::isStdSmartPtr(Call.getArgExpr(1)))
+      if (handleComparisionOp(Call, C))
+        return true;
+
   if (!smartptr::isStdSmartPtrCall(Call))
     return false;
 
@@ -272,6 +298,92 @@
   return C.isDifferent();
 }
 
+// Models the std comparison operators (==, !=, <, <=, >, >=, <=>) that take a
+// smart pointer by comparing the inner pointer values of the two operands.
+// Returns true if the call was evaluated here and false otherwise.
+bool SmartPtrModeling::handleComparisionOp(const CallEvent &Call,
+                                           CheckerContext &C) const {
+  const auto *FC = dyn_cast<SimpleFunctionCall>(&Call);
+  if (!FC)
+    return false;
+  const FunctionDecl *FD = FC->getDecl();
+  if (!FD || !FD->isOverloadedOperator())
+    return false;
+  const OverloadedOperatorKind OOK = FD->getOverloadedOperator();
+  if (!(OOK == OO_EqualEqual || OOK == OO_ExclaimEqual || OOK == OO_Less ||
+        OOK == OO_LessEqual || OOK == OO_Greater || OOK == OO_GreaterEqual ||
+        OOK == OO_Spaceship))
+    return false;
+
+  // There are some special cases about which we can infer about
+  // the resulting answer.
+  // For reference, there is a discussion at https://reviews.llvm.org/D104616.
+  // Also, the cppreference page is good to look at
+  // https://en.cppreference.com/w/cpp/memory/unique_ptr/operator_cmp.
+
+  ProgramStateRef State = C.getState();
+
+  // Yields the inner pointer value for an operand: a null value for a null
+  // constant (e.g. nullptr), the tracked inner pointer if there is one, or a
+  // freshly conjured symbol otherwise.
+  auto retrieveOrConjureInnerPtrVal = [&C, &State](const Expr *E,
+                                                   SVal S) -> SVal {
+    if (S.isZeroConstant()) {
+      return C.getSValBuilder().makeZeroVal(E->getType());
+    }
+    const MemRegion *Reg = S.getAsRegion();
+    assert(Reg &&
+           "this pointer of std::unique_ptr should be obtainable as MemRegion");
+    const auto *Ptr = State->get<TrackedRegionMap>(Reg);
+    if (Ptr)
+      return *Ptr;
+    return C.getSValBuilder().conjureSymbolVal(
+        E, C.getLocationContext(),
+        getInnerPointerType(C, E->getType()->getAsCXXRecordDecl()),
+        C.blockCount());
+  };
+
+  BinaryOperatorKind BOK;
+  switch (OOK) {
+  case OO_EqualEqual:
+    BOK = BO_EQ;
+    break;
+  case OO_ExclaimEqual:
+    BOK = BO_NE;
+    break;
+  case OO_GreaterEqual:
+    BOK = BO_GE;
+    break;
+  case OO_LessEqual:
+    BOK = BO_LE;
+    break;
+  case OO_Less:
+    BOK = BO_LT;
+    break;
+  case OO_Greater:
+    BOK = BO_GT;
+    break;
+  case OO_Spaceship:
+    BOK = BO_Cmp;
+    break;
+  default:
+    llvm_unreachable("unexpected overloaded operator kind");
+  }
+
+  SVal First = Call.getArgSVal(0);
+  SVal Second = Call.getArgSVal(1);
+  const auto *FirstExpr = Call.getArgExpr(0);
+  const auto *SecondExpr = Call.getArgExpr(1);
+
+  SVal FirstPtrVal = retrieveOrConjureInnerPtrVal(FirstExpr, First);
+  SVal SecondPtrVal = retrieveOrConjureInnerPtrVal(SecondExpr, Second);
+  auto RetVal = C.getSValBuilder().evalBinOp(
+      State, BOK, FirstPtrVal, SecondPtrVal, Call.getResultType());
+  State = State->BindExpr(Call.getOriginExpr(), C.getLocationContext(), RetVal);
+  C.addTransition(State);
+  return true;
+}
+
 void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper,
                                         CheckerContext &C) const {
   ProgramStateRef State = C.getState();
diff --git a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
--- a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
+++ b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
@@ -978,6 +978,61 @@
 void swap(unique_ptr<T> &x, unique_ptr<T> &y) noexcept {
   x.swap(y);
 }
+
+template <typename T1, typename T2>
+bool operator==(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T1, typename T2>
+bool operator!=(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T1, typename T2>
+bool operator<(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T1, typename T2>
+bool operator>(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T1, typename T2>
+bool operator<=(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T1, typename T2>
+bool operator>=(const unique_ptr<T1> &x, const unique_ptr<T2> &y);
+
+template <typename T>
+bool operator==(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator!=(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator<(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator>(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator<=(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator>=(const unique_ptr<T> &x, nullptr_t y);
+
+template <typename T>
+bool operator==(nullptr_t x, const unique_ptr<T> &y);
+
+template <typename T>
+bool operator!=(nullptr_t x, const unique_ptr<T> &y);
+
+template <typename T>
+bool operator>(nullptr_t x, const unique_ptr<T> &y);
+
+template <typename T>
+bool operator<(nullptr_t x, const unique_ptr<T> &y);
+
+template <typename T>
+bool operator>=(nullptr_t x, const unique_ptr<T> &y);
+
+template <typename T>
+bool operator<=(nullptr_t x, const unique_ptr<T> &y);
+
 } // namespace std
 
 #endif
diff --git a/clang/test/Analysis/smart-ptr.cpp b/clang/test/Analysis/smart-ptr.cpp
--- a/clang/test/Analysis/smart-ptr.cpp
+++ b/clang/test/Analysis/smart-ptr.cpp
@@ -457,3 +457,32 @@
     P->foo(); // expected-warning {{Dereference of null smart pointer 'P' [alpha.cplusplus.SmartPtr]}}
   }
 }
+
+// The following is a silly function,
+// but serves to test if we are picking out
+// standard comparison functions from custom ones.
+template <typename T>
+bool operator<(std::unique_ptr<T> &x, double d);
+
+void uniquePtrComparision(std::unique_ptr<int> unknownPtr) {
+  auto ptr = std::unique_ptr<int>(new int(13));
+  auto nullPtr = std::unique_ptr<int>();
+  auto otherPtr = std::unique_ptr<int>(new int(29));
+
+  clang_analyzer_eval(ptr == ptr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(ptr > ptr);  // expected-warning{{FALSE}}
+  clang_analyzer_eval(ptr <= ptr); // expected-warning{{TRUE}}
+
+  clang_analyzer_eval(nullPtr <= unknownPtr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(unknownPtr >= nullPtr); // expected-warning{{TRUE}}
+
+  clang_analyzer_eval(ptr != otherPtr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(ptr > nullPtr);   // expected-warning{{TRUE}}
+
+  clang_analyzer_eval(ptr != nullptr);        // expected-warning{{TRUE}}
+  clang_analyzer_eval(nullPtr != nullptr);    // expected-warning{{FALSE}}
+  clang_analyzer_eval(nullptr <= unknownPtr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(nullptr < unknownPtr);  // expected-warning{{UNKNOWN}}
+
+  clang_analyzer_eval(ptr < 2.0); // expected-warning{{UNKNOWN}}
+}