diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
@@ -22,6 +22,12 @@
 
 /// Returns true if the event call is on smart pointer.
 bool isStdSmartPtrCall(const CallEvent &Call);
+/// Returns true if \p RD is non-null, declared directly in namespace std and
+/// named shared_ptr, unique_ptr or weak_ptr.
+bool isStdSmartPtr(const CXXRecordDecl *RD);
+/// Returns true if \p E is non-null and its type is such a std smart pointer
+/// class.
+bool isStdSmartPtr(const Expr *E);
 
 /// Returns whether the smart pointer is null or not.
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion);
diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
@@ -29,7 +29,10 @@
 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymExpr.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <string>
+#include <utility>
 
 using namespace clang;
 using namespace ento;
@@ -68,6 +71,7 @@
   bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion,
                                 const MemRegion *OtherSmartPtrRegion) const;
   void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const;
+  bool handleComparisionOp(const CallEvent &Call, CheckerContext &C) const;
 
   using SmartPtrMethodHandlerFn =
       void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const;
@@ -89,18 +93,24 @@
   const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
   if (!MethodDecl || !MethodDecl->getParent())
     return false;
+  return isStdSmartPtr(MethodDecl->getParent());
+}
 
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->getDeclContext()->isStdNamespace())
+bool isStdSmartPtr(const CXXRecordDecl *RD) {
+  if (!RD || !RD->getDeclContext()->isStdNamespace())
     return false;
 
-  if (RecordDecl->getDeclName().isIdentifier()) {
-    StringRef Name = RecordDecl->getName();
+  if (RD->getDeclName().isIdentifier()) {
+    StringRef Name = RD->getName();
     return Name == "shared_ptr" || Name == "unique_ptr" || Name == "weak_ptr";
   }
   return false;
 }
 
+bool isStdSmartPtr(const Expr *E) {
+  return E && isStdSmartPtr(E->getType()->getAsCXXRecordDecl());
+}
+
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion) {
   const auto *InnerPointVal = State->get<TrackedRegionMap>(ThisRegion);
   return InnerPointVal &&
@@ -135,18 +145,15 @@
   return State;
 }
 
-// Helper method to get the inner pointer type of specialized smart pointer
-// Returns empty type if not found valid inner pointer type.
-static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
-  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
-  if (!MethodDecl || !MethodDecl->getParent())
-    return {};
-
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->isInStdNamespace())
+// Returns a pointer to the canonical inner value type of the std class
+// template specialization RD, or an empty QualType when that is not possible
+// (e.g. RD is null, not in namespace std, or not a template specialization).
+static QualType getInnerPointerType(CheckerContext &C,
+                                    const CXXRecordDecl *RD) {
+  if (!RD || !RD->isInStdNamespace())
     return {};
 
-  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
+  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD);
   if (!TSD)
     return {};
 
@@ -157,6 +164,13 @@
   return C.getASTContext().getPointerType(InnerValueType.getCanonicalType());
 }
 
+// Returns the inner pointer type of the smart pointer class whose method is
+// being called, or an empty QualType if it cannot be determined.
+static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
+  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
+  return getInnerPointerType(C, MethodDecl ? MethodDecl->getParent() : nullptr);
+}
+
 // Helper method to pretty print region and avoid extra spacing.
static void checkAndPrettyPrintRegion(llvm::raw_ostream &OS, const MemRegion *Region) { @@ -178,6 +192,16 @@ bool SmartPtrModeling::evalCall(const CallEvent &Call, CheckerContext &C) const { ProgramStateRef State = C.getState(); + + // If any one of the arg is a unique_ptr, then + // we can try this function + if (Call.getNumArgs() == 2 && + Call.getDecl()->getDeclContext()->isStdNamespace()) + if (smartptr::isStdSmartPtr(Call.getArgExpr(0)) || + smartptr::isStdSmartPtr(Call.getArgExpr(1))) + if (handleComparisionOp(Call, C)) + return true; + if (!smartptr::isStdSmartPtrCall(Call)) return false; @@ -272,6 +296,132 @@ return C.isDifferent(); } +class OperatorKind { + union { + BinaryOperatorKind Bin; + UnaryOperatorKind Un; + } Op; + bool IsBinary; + +public: + explicit OperatorKind(BinaryOperatorKind Bin) : Op{Bin}, IsBinary{true} {} + explicit OperatorKind(UnaryOperatorKind Un) : IsBinary{false} { Op.Un = Un; } + bool IsBinaryOp() const { return IsBinary; } + + BinaryOperatorKind GetBinaryOpUnsafe() const { + assert(IsBinary && "cannot get binary operator - we have a unary operator"); + return Op.Bin; + } + + Optional GetBinaryOp() const { + if (IsBinary) + return Op.Bin; + return {}; + } + + UnaryOperatorKind GetUnaryOpUnsafe() const { + assert(!IsBinary && + "cannot get unary operator - we have a binary operator"); + return Op.Un; + } + + Optional GetUnaryOp() const { + if (!IsBinary) + return Op.Un; + return {}; + } +}; + +OperatorKind operationKindFromOverloadedOperator(OverloadedOperatorKind OOK) { + llvm::StringMap BinOps{ +#define BINARY_OPERATION(Name, Spelling) {Spelling, BO_##Name}, +#include "clang/AST/OperationKinds.def" + }; + llvm::StringMap UnOps{ +#define UNARY_OPERATION(Name, Spelling) {Spelling, UO_##Name}, +#include "clang/AST/OperationKinds.def" + }; + + llvm::StringMap::iterator BinOpIt; + llvm::StringMap::iterator UnOpIt; + + switch (OOK) { +#define OVERLOADED_OPERATOR(Name, Spelling, Token, Unary, Binary, MemberOnly) \ + case OO_##Name: \ + 
BinOpIt = BinOps.find(Spelling); \ + if (BinOpIt != BinOps.end()) { \ + return OperatorKind(BinOpIt->second); \ + } else { \ + UnOpIt = UnOps.find(Spelling); \ + if (UnOpIt != UnOps.end()) \ + return OperatorKind(UnOpIt->second); \ + else \ + llvm_unreachable("expected operator to be either unary or binary"); \ + } \ + break; +#include "clang/Basic/OperatorKinds.def" + default: + llvm_unreachable("unexpected operator kind"); + } +} + +bool SmartPtrModeling::handleComparisionOp(const CallEvent &Call, + CheckerContext &C) const { + const auto *FC = dyn_cast(&Call); + if (!FC) + return false; + const FunctionDecl *FD = FC->getDecl(); + if (!FD->isOverloadedOperator()) + return false; + const OverloadedOperatorKind OOK = FD->getOverloadedOperator(); + if (!(OOK == OO_EqualEqual || OOK == OO_ExclaimEqual || OOK == OO_Less || + OOK == OO_LessEqual || OOK == OO_Greater || OOK == OO_GreaterEqual || + OOK == OO_Spaceship)) + return false; + + // There are some special cases about which we can infer about + // the resulting answer. + // For reference, there is a discussion at https://reviews.llvm.org/D104616. + // Also, the cppreference page is good to look at + // https://en.cppreference.com/w/cpp/memory/unique_ptr/operator_cmp. 
+ + ProgramStateRef State = C.getState(); + + auto retrieveOrConjureInnerPtrVal = [&C, &State](const Expr *E, + SVal S) -> SVal { + if (S.isZeroConstant()) { + return C.getSValBuilder().makeZeroVal(E->getType()); + } + const MemRegion *Reg = S.getAsRegion(); + assert(Reg && + "this pointer of std::unique_ptr should be obtainable as MemRegion"); + const auto *Ptr = State->get(Reg); + if (Ptr) + return *Ptr; + auto Val = C.getSValBuilder().conjureSymbolVal( + E, C.getLocationContext(), + getInnerPointerType(C, E->getType()->getAsCXXRecordDecl()), + C.blockCount()); + State = State->set(Reg, Val); + return Val; + }; + + SVal First = Call.getArgSVal(0); + SVal Second = Call.getArgSVal(1); + const auto *FirstExpr = Call.getArgExpr(0); + const auto *SecondExpr = Call.getArgExpr(1); + + SVal FirstPtrVal = retrieveOrConjureInnerPtrVal(FirstExpr, First); + SVal SecondPtrVal = retrieveOrConjureInnerPtrVal(SecondExpr, Second); + BinaryOperatorKind BOK = + operationKindFromOverloadedOperator(OOK).GetBinaryOpUnsafe(); + auto RetVal = C.getSValBuilder().evalBinOp( + State, BOK, FirstPtrVal, SecondPtrVal, Call.getResultType()); + State = State->BindExpr(Call.getOriginExpr(), C.getLocationContext(), RetVal); + C.addTransition(State); + return true; +} + void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper, CheckerContext &C) const { ProgramStateRef State = C.getState(); diff --git a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h --- a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h +++ b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h @@ -978,6 +978,61 @@ void swap(unique_ptr &x, unique_ptr &y) noexcept { x.swap(y); } + +template +bool operator==(const unique_ptr &x, const unique_ptr &y); + +template +bool operator!=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator<(const unique_ptr &x, const unique_ptr &y); + +template +bool operator>(const unique_ptr &x, const 
unique_ptr &y); + +template +bool operator<=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator>=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator==(const unique_ptr &x, nullptr_t y); + +template +bool operator!=(const unique_ptr &x, nullptr_t y); + +template +bool operator<(const unique_ptr &x, nullptr_t y); + +template +bool operator>(const unique_ptr &x, nullptr_t y); + +template +bool operator<=(const unique_ptr &x, nullptr_t y); + +template +bool operator>=(const unique_ptr &x, nullptr_t y); + +template +bool operator==(nullptr_t x, const unique_ptr &y); + +template +bool operator!=(nullptr_t x, const unique_ptr &y); + +template +bool operator>(nullptr_t x, const unique_ptr &y); + +template +bool operator<(nullptr_t x, const unique_ptr &y); + +template +bool operator>=(nullptr_t x, const unique_ptr &y); + +template +bool operator<=(nullptr_t x, const unique_ptr &y); + } // namespace std #endif diff --git a/clang/test/Analysis/smart-ptr.cpp b/clang/test/Analysis/smart-ptr.cpp --- a/clang/test/Analysis/smart-ptr.cpp +++ b/clang/test/Analysis/smart-ptr.cpp @@ -457,3 +457,32 @@ P->foo(); // expected-warning {{Dereference of null smart pointer 'P' [alpha.cplusplus.SmartPtr]}} } } + +// The following is a silly function, +// but serves to test if we are picking out +// standard comparision functions from custom ones. 
+template +bool operator<(std::unique_ptr &x, double d); + +void uniquePtrComparision(std::unique_ptr unknownPtr) { + auto ptr = std::unique_ptr(new int(13)); + auto nullPtr = std::unique_ptr(); + auto otherPtr = std::unique_ptr(new int(29)); + + clang_analyzer_eval(ptr == ptr); // expected-warning{{TRUE}} + clang_analyzer_eval(ptr > ptr); // expected-warning{{FALSE}} + clang_analyzer_eval(ptr <= ptr); // expected-warning{{TRUE}} + + clang_analyzer_eval(nullPtr <= unknownPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(unknownPtr >= nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != otherPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(ptr > nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != nullptr); // expected-warning{{TRUE}} + clang_analyzer_eval(nullPtr != nullptr); // expected-warning{{FALSE}} + clang_analyzer_eval(nullptr <= unknownPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(nullptr < unknownPtr); // expected-warning{{UNKNOWN}} + + clang_analyzer_eval(ptr < 2.0); // expected-warning{{UNKNOWN}} +}