diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h
--- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h
+++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h
@@ -13,7 +13,9 @@
 #ifndef LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_CHECKERHELPERS_H
 #define LLVM_CLANG_STATICANALYZER_CORE_PATHSENSITIVE_CHECKERHELPERS_H
 
+#include "clang/AST/OperationKinds.h"
 #include "clang/AST/Stmt.h"
+#include "clang/Basic/OperatorKinds.h"
 #include "llvm/ADT/Optional.h"
 #include <tuple>
 
@@ -69,6 +71,60 @@
 /// token for an integer. If we cannot parse the value then None is returned.
 llvm::Optional<int> tryExpandAsInteger(StringRef Macro, const Preprocessor &PP);
 
+/// Holds either a BinaryOperatorKind or a UnaryOperatorKind, together with a
+/// flag recording which of the two is stored.
+class OperatorKind {
+  union {
+    BinaryOperatorKind Bin;
+    UnaryOperatorKind Un;
+  } Op;
+  bool IsBinary;
+
+public:
+  explicit OperatorKind(BinaryOperatorKind Bin) : Op{Bin}, IsBinary{true} {}
+  explicit OperatorKind(UnaryOperatorKind Un) : IsBinary{false} { Op.Un = Un; }
+
+  /// Returns true if a binary operator kind is stored.
+  bool IsBinaryOp() const { return IsBinary; }
+
+  /// Returns the stored binary operator kind. Must only be called when
+  /// IsBinaryOp() is true (this is asserted).
+  BinaryOperatorKind GetBinaryOpUnsafe() const {
+    assert(IsBinary && "cannot get binary operator - we have a unary operator");
+    return Op.Bin;
+  }
+
+  /// Returns the stored binary operator kind, or None if a unary operator
+  /// kind is stored.
+  Optional<BinaryOperatorKind> GetBinaryOp() const {
+    if (IsBinary)
+      return Op.Bin;
+    return {};
+  }
+
+  /// Returns the stored unary operator kind. Must only be called when
+  /// IsBinaryOp() is false (this is asserted).
+  UnaryOperatorKind GetUnaryOpUnsafe() const {
+    assert(!IsBinary &&
+           "cannot get unary operator - we have a binary operator");
+    return Op.Un;
+  }
+
+  /// Returns the stored unary operator kind, or None if a binary operator
+  /// kind is stored.
+  Optional<UnaryOperatorKind> GetUnaryOp() const {
+    if (!IsBinary)
+      return Op.Un;
+    return {};
+  }
+};
+
+/// Maps the overloaded operator \p OOK to the binary (if \p IsBinary) or unary
+/// operator kind with the same spelling. An operator without such a
+/// counterpart is treated as unreachable (llvm_unreachable).
+OperatorKind operationKindFromOverloadedOperator(OverloadedOperatorKind OOK,
+                                                 bool IsBinary);
+
 } // namespace ento
 
 } // namespace clang
diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
@@ -22,6 +22,11 @@
 
 /// Returns true if the event call is on smart pointer.
 bool isStdSmartPtrCall(const CallEvent &Call);
+/// Returns true if \p RD is std::shared_ptr, std::unique_ptr or std::weak_ptr.
+bool isStdSmartPtr(const CXXRecordDecl *RD);
+/// Returns true if \p E is non-null and its type is std::shared_ptr,
+/// std::unique_ptr or std::weak_ptr.
+bool isStdSmartPtr(const Expr *E);
 
 /// Returns whether the smart pointer is null or not.
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion);
diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
@@ -25,10 +25,13 @@
 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerHelpers.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/MemRegion.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymExpr.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <string>
 
 using namespace clang;
@@ -68,6 +71,10 @@
   bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion,
                                 const MemRegion *OtherSmartPtrRegion) const;
   void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const;
+  bool handleComparisionOp(const CallEvent &Call, CheckerContext &C) const;
+  std::pair<SVal, ProgramStateRef>
+  retrieveOrConjureInnerPtrVal(const MemRegion *ThisRegion, const Expr *E,
+                               QualType Type, CheckerContext &C) const;
 
   using SmartPtrMethodHandlerFn =
       void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const;
@@ -89,18 +96,24 @@
   const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
   if (!MethodDecl || !MethodDecl->getParent())
     return false;
+  return isStdSmartPtr(MethodDecl->getParent());
+}
 
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->getDeclContext()->isStdNamespace())
+bool isStdSmartPtr(const CXXRecordDecl *RD) {
+  if (!RD || !RD->getDeclContext()->isStdNamespace())
     return false;
 
-  if (RecordDecl->getDeclName().isIdentifier()) {
-    StringRef Name = RecordDecl->getName();
+  if (RD->getDeclName().isIdentifier()) {
+    StringRef Name = RD->getName();
     return Name == "shared_ptr" || Name == "unique_ptr" || Name == "weak_ptr";
   }
   return false;
 }
 
+bool isStdSmartPtr(const Expr *E) {
+  return E && isStdSmartPtr(E->getType()->getAsCXXRecordDecl());
+}
+
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion) {
   const auto *InnerPointVal = State->get<TrackedRegionMap>(ThisRegion);
   return InnerPointVal &&
@@ -135,18 +148,14 @@
   return State;
 }
 
-// Helper method to get the inner pointer type of specialized smart pointer
-// Returns empty type if not found valid inner pointer type.
-static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
-  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
-  if (!MethodDecl || !MethodDecl->getParent())
-    return {};
-
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->isInStdNamespace())
+// Returns the inner pointer type of the smart pointer class \p RD, or an empty
+// QualType if RD is not a class template specialization in namespace std.
+static QualType getInnerPointerType(CheckerContext &C,
+                                    const CXXRecordDecl *RD) {
+  if (!RD || !RD->isInStdNamespace())
     return {};
 
-  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
+  const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD);
   if (!TSD)
     return {};
 
@@ -157,6 +166,14 @@
   return C.getASTContext().getPointerType(InnerValueType.getCanonicalType());
 }
 
+// Overload for a method call: uses the class that declares the callee.
+static QualType getInnerPointerType(const CallEvent &Call, CheckerContext &C) {
+  const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
+  if (!MethodDecl)
+    return {};
+  return getInnerPointerType(C, MethodDecl->getParent());
+}
+
 // Helper method to pretty print region and avoid extra spacing.
static void checkAndPrettyPrintRegion(llvm::raw_ostream &OS, const MemRegion *Region) { @@ -178,6 +195,16 @@ bool SmartPtrModeling::evalCall(const CallEvent &Call, CheckerContext &C) const { ProgramStateRef State = C.getState(); + + // If any one of the arg is a unique_ptr, then + // we can try this function + if (Call.getNumArgs() == 2 && + Call.getDecl()->getDeclContext()->isStdNamespace()) + if (smartptr::isStdSmartPtr(Call.getArgExpr(0)) || + smartptr::isStdSmartPtr(Call.getArgExpr(1))) + if (handleComparisionOp(Call, C)) + return true; + if (!smartptr::isStdSmartPtrCall(Call)) return false; @@ -272,6 +299,91 @@ return C.isDifferent(); } +std::pair +SmartPtrModeling::retrieveOrConjureInnerPtrVal(const MemRegion *ThisRegion, + const Expr *E, QualType Type, + CheckerContext &C) const { + ProgramStateRef State = C.getState(); + const auto *Ptr = State->get(ThisRegion); + if (Ptr) + return {*Ptr, State}; + auto Val = C.getSValBuilder().conjureSymbolVal(E, C.getLocationContext(), + Type, C.blockCount()); + State = State->set(ThisRegion, Val); + return {Val, State}; +} + +bool SmartPtrModeling::handleComparisionOp(const CallEvent &Call, + CheckerContext &C) const { + const auto *FC = dyn_cast(&Call); + if (!FC) + return false; + const FunctionDecl *FD = FC->getDecl(); + if (!FD->isOverloadedOperator()) + return false; + const OverloadedOperatorKind OOK = FD->getOverloadedOperator(); + if (!(OOK == OO_EqualEqual || OOK == OO_ExclaimEqual || OOK == OO_Less || + OOK == OO_LessEqual || OOK == OO_Greater || OOK == OO_GreaterEqual || + OOK == OO_Spaceship)) + return false; + + // There are some special cases about which we can infer about + // the resulting answer. + // For reference, there is a discussion at https://reviews.llvm.org/D104616. + // Also, the cppreference page is good to look at + // https://en.cppreference.com/w/cpp/memory/unique_ptr/operator_cmp. 
+ + ProgramStateRef State = C.getState(); + + auto makeSValFor = [&C, State, + this](const Expr *E, + SVal S) -> std::pair { + if (S.isZeroConstant()) { + return {S, State}; + } + const MemRegion *Reg = S.getAsRegion(); + assert(Reg && + "this pointer of std::unique_ptr should be obtainable as MemRegion"); + SVal Val; + ProgramStateRef NewState; + QualType Type = getInnerPointerType(C, E->getType()->getAsCXXRecordDecl()); + std::tie(Val, NewState) = retrieveOrConjureInnerPtrVal(Reg, E, Type, C); + return {Val, NewState}; + }; + + SVal First = Call.getArgSVal(0); + SVal Second = Call.getArgSVal(1); + const auto *FirstExpr = Call.getArgExpr(0); + const auto *SecondExpr = Call.getArgExpr(1); + + const auto *ResultExpr = Call.getOriginExpr(); + const auto *LCtx = C.getLocationContext(); + auto &Bldr = C.getSValBuilder(); + + SVal FirstPtrVal, SecondPtrVal; + std::tie(FirstPtrVal, State) = makeSValFor(FirstExpr, First); + std::tie(SecondPtrVal, State) = makeSValFor(SecondExpr, Second); + BinaryOperatorKind BOK = + operationKindFromOverloadedOperator(OOK, true).GetBinaryOpUnsafe(); + auto RetVal = Bldr.evalBinOp(State, BOK, FirstPtrVal, SecondPtrVal, + Call.getResultType()); + + if (OOK != OO_Spaceship) { + ProgramStateRef TrueState, FalseState; + std::tie(TrueState, FalseState) = + State->assume(*RetVal.getAs()); + if (TrueState) + C.addTransition( + TrueState->BindExpr(ResultExpr, LCtx, Bldr.makeTruthVal(true))); + if (FalseState) + C.addTransition( + FalseState->BindExpr(ResultExpr, LCtx, Bldr.makeTruthVal(false))); + } else { + C.addTransition(State->BindExpr(ResultExpr, LCtx, RetVal)); + } + return true; +} + void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper, CheckerContext &C) const { ProgramStateRef State = C.getState(); @@ -446,15 +558,8 @@ return; SVal InnerPointerVal; - if (const auto *InnerValPtr = State->get(ThisRegion)) { - InnerPointerVal = *InnerValPtr; - } else { - const auto *CallExpr = Call.getOriginExpr(); - InnerPointerVal = 
C.getSValBuilder().conjureSymbolVal( - CallExpr, C.getLocationContext(), Call.getResultType(), C.blockCount()); - State = State->set(ThisRegion, InnerPointerVal); - } - + std::tie(InnerPointerVal, State) = retrieveOrConjureInnerPtrVal( + ThisRegion, Call.getOriginExpr(), Call.getResultType(), C); State = State->BindExpr(Call.getOriginExpr(), C.getLocationContext(), InnerPointerVal); // TODO: Add NoteTag, for how the raw pointer got using 'get' method. diff --git a/clang/lib/StaticAnalyzer/Core/CheckerHelpers.cpp b/clang/lib/StaticAnalyzer/Core/CheckerHelpers.cpp --- a/clang/lib/StaticAnalyzer/Core/CheckerHelpers.cpp +++ b/clang/lib/StaticAnalyzer/Core/CheckerHelpers.cpp @@ -148,5 +148,39 @@ return IntValue.getSExtValue(); } +OperatorKind operationKindFromOverloadedOperator(OverloadedOperatorKind OOK, + bool IsBinary) { + llvm::StringMap BinOps{ +#define BINARY_OPERATION(Name, Spelling) {Spelling, BO_##Name}, +#include "clang/AST/OperationKinds.def" + }; + llvm::StringMap UnOps{ +#define UNARY_OPERATION(Name, Spelling) {Spelling, UO_##Name}, +#include "clang/AST/OperationKinds.def" + }; + + switch (OOK) { +#define OVERLOADED_OPERATOR(Name, Spelling, Token, Unary, Binary, MemberOnly) \ + case OO_##Name: \ + if (IsBinary) { \ + auto BinOpIt = BinOps.find(Spelling); \ + if (BinOpIt != BinOps.end()) \ + return OperatorKind(BinOpIt->second); \ + else \ + llvm_unreachable("operator was expected to be binary but is not"); \ + } else { \ + auto UnOpIt = UnOps.find(Spelling); \ + if (UnOpIt != UnOps.end()) \ + return OperatorKind(UnOpIt->second); \ + else \ + llvm_unreachable("operator was expected to be unary but is not"); \ + } \ + break; +#include "clang/Basic/OperatorKinds.def" + default: + llvm_unreachable("unexpected operator kind"); + } +} + } // namespace ento } // namespace clang diff --git a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h --- 
a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h +++ b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h @@ -978,6 +978,61 @@ void swap(unique_ptr &x, unique_ptr &y) noexcept { x.swap(y); } + +template +bool operator==(const unique_ptr &x, const unique_ptr &y); + +template +bool operator!=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator<(const unique_ptr &x, const unique_ptr &y); + +template +bool operator>(const unique_ptr &x, const unique_ptr &y); + +template +bool operator<=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator>=(const unique_ptr &x, const unique_ptr &y); + +template +bool operator==(const unique_ptr &x, nullptr_t y); + +template +bool operator!=(const unique_ptr &x, nullptr_t y); + +template +bool operator<(const unique_ptr &x, nullptr_t y); + +template +bool operator>(const unique_ptr &x, nullptr_t y); + +template +bool operator<=(const unique_ptr &x, nullptr_t y); + +template +bool operator>=(const unique_ptr &x, nullptr_t y); + +template +bool operator==(nullptr_t x, const unique_ptr &y); + +template +bool operator!=(nullptr_t x, const unique_ptr &y); + +template +bool operator>(nullptr_t x, const unique_ptr &y); + +template +bool operator<(nullptr_t x, const unique_ptr &y); + +template +bool operator>=(nullptr_t x, const unique_ptr &y); + +template +bool operator<=(nullptr_t x, const unique_ptr &y); + } // namespace std #endif diff --git a/clang/test/Analysis/smart-ptr.cpp b/clang/test/Analysis/smart-ptr.cpp --- a/clang/test/Analysis/smart-ptr.cpp +++ b/clang/test/Analysis/smart-ptr.cpp @@ -457,3 +457,52 @@ P->foo(); // expected-warning {{Dereference of null smart pointer 'P' [alpha.cplusplus.SmartPtr]}} } } + +// The following is a silly function, +// but serves to test if we are picking out +// standard comparision functions from custom ones. 
+template +bool operator<(std::unique_ptr &x, double d); + +void uniquePtrComparision(std::unique_ptr unknownPtr) { + auto ptr = std::unique_ptr(new int(13)); + auto nullPtr = std::unique_ptr(); + auto otherPtr = std::unique_ptr(new int(29)); + + clang_analyzer_eval(ptr == ptr); // expected-warning{{TRUE}} + clang_analyzer_eval(ptr > ptr); // expected-warning{{FALSE}} + clang_analyzer_eval(ptr <= ptr); // expected-warning{{TRUE}} + + clang_analyzer_eval(nullPtr <= unknownPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(unknownPtr >= nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != otherPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(ptr > nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != nullptr); // expected-warning{{TRUE}} + clang_analyzer_eval(nullPtr != nullptr); // expected-warning{{FALSE}} + clang_analyzer_eval(nullptr <= unknownPtr); // expected-warning{{TRUE}} +} + +void uniquePtrComparisionStateSplitting(std::unique_ptr unknownPtr) { + auto ptr = std::unique_ptr(new int(13)); + + clang_analyzer_eval(ptr > unknownPtr); // expected-warning{{TRUE}} + // expected-warning@-1{{FALSE}} +} + +void uniquePtrComparisionDifferingTypes(std::unique_ptr unknownPtr) { + auto ptr = std::unique_ptr(new int(13)); + auto nullPtr = std::unique_ptr(); + auto otherPtr = std::unique_ptr(new double(3.14)); + + clang_analyzer_eval(nullPtr <= unknownPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(unknownPtr >= nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != otherPtr); // expected-warning{{TRUE}} + clang_analyzer_eval(ptr > nullPtr); // expected-warning{{TRUE}} + + clang_analyzer_eval(ptr != nullptr); // expected-warning{{TRUE}} + clang_analyzer_eval(nullPtr != nullptr); // expected-warning{{FALSE}} + clang_analyzer_eval(nullptr <= unknownPtr); // expected-warning{{TRUE}} +}