diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtr.h
@@ -22,6 +22,13 @@
 
 /// Returns true if the event call is on smart pointer.
 bool isStdSmartPtrCall(const CallEvent &Call);
+
+/// Returns true if \p RD is std::shared_ptr, std::unique_ptr or std::weak_ptr.
+bool isStdSmartPtr(const CXXRecordDecl *RD);
+
+/// Returns true if the type of \p E is std::shared_ptr, std::unique_ptr or
+/// std::weak_ptr. A null \p E yields false.
+bool isStdSmartPtr(const Expr *E);
 
 /// Returns whether the smart pointer is null or not.
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion);
diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
--- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp
@@ -29,6 +29,7 @@
 #include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymExpr.h"
 #include "clang/StaticAnalyzer/Core/PathSensitive/SymbolManager.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <string>
 
 using namespace clang;
@@ -68,6 +69,7 @@
   bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion,
                                 const MemRegion *OtherSmartPtrRegion) const;
   void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const;
+  bool handleComparisionOp(const CallEvent &Call, CheckerContext &C) const;
 
   using SmartPtrMethodHandlerFn =
       void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const;
@@ -89,18 +91,26 @@
   const auto *MethodDecl = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
   if (!MethodDecl || !MethodDecl->getParent())
     return false;
+  return isStdSmartPtr(MethodDecl->getParent());
+}
 
-  const auto *RecordDecl = MethodDecl->getParent();
-  if (!RecordDecl || !RecordDecl->getDeclContext()->isStdNamespace())
+bool isStdSmartPtr(const CXXRecordDecl *RD) {
+  if (!RD || !RD->getDeclContext()->isStdNamespace())
     return false;
 
-  if (RecordDecl->getDeclName().isIdentifier()) {
-    StringRef Name = RecordDecl->getName();
+  if (RD->getDeclName().isIdentifier()) {
+    StringRef Name = RD->getName();
     return Name == "shared_ptr" || Name == "unique_ptr" || Name == "weak_ptr";
   }
   return false;
 }
 
+bool isStdSmartPtr(const Expr *E) {
+  // CallEvent::getArgExpr() may return null for arguments that do not appear
+  // in the source, so tolerate a missing expression.
+  return E && isStdSmartPtr(E->getType()->getAsCXXRecordDecl());
+}
+
 bool isNullSmartPtr(const ProgramStateRef State, const MemRegion *ThisRegion) {
   const auto *InnerPointVal = State->get<TrackedRegionMap>(ThisRegion);
   return InnerPointVal &&
@@ -178,6 +188,16 @@
 bool SmartPtrModeling::evalCall(const CallEvent &Call,
                                 CheckerContext &C) const {
   ProgramStateRef State = C.getState();
+
+  // A two-argument call with a std smart pointer among its arguments may be
+  // one of the comparison operators, which are free functions rather than
+  // methods of the smart pointer class.
+  if (Call.getNumArgs() == 2 &&
+      (smartptr::isStdSmartPtr(Call.getArgExpr(0)) ||
+       smartptr::isStdSmartPtr(Call.getArgExpr(1))) &&
+      handleComparisionOp(Call, C))
+    return true;
+
   if (!smartptr::isStdSmartPtrCall(Call))
     return false;
 
@@ -272,6 +292,108 @@
   return C.isDifferent();
 }
 
+/// Models the relational operators (==, !=, <, <=, >, >=) on std smart
+/// pointers. Returns true if the call has been evaluated, i.e. a value has
+/// been bound to the call expression, and false if the call should be left
+/// to the default evaluation.
+bool SmartPtrModeling::handleComparisionOp(const CallEvent &Call,
+                                           CheckerContext &C) const {
+  const auto *FC = dyn_cast<SimpleFunctionCall>(&Call);
+  if (!FC)
+    return false;
+  // The callee is not known for every call, e.g. when calling through a
+  // function pointer.
+  const FunctionDecl *FD = FC->getDecl();
+  if (!FD || !FD->isOverloadedOperator())
+    return false;
+
+  // Note that OO_Equal is the assignment operator; operator== is
+  // OO_EqualEqual.
+  // TODO: Model operator<=> (OO_Spaceship) as well. Its result is not a
+  // boolean, so it can be neither assumed to be true/false nor computed as a
+  // pointer comparison below.
+  BinaryOperatorKind BOK;
+  switch (FD->getOverloadedOperator()) {
+  case OO_EqualEqual:
+    BOK = BO_EQ;
+    break;
+  case OO_ExclaimEqual:
+    BOK = BO_NE;
+    break;
+  case OO_Less:
+    BOK = BO_LT;
+    break;
+  case OO_LessEqual:
+    BOK = BO_LE;
+    break;
+  case OO_Greater:
+    BOK = BO_GT;
+    break;
+  case OO_GreaterEqual:
+    BOK = BO_GE;
+    break;
+  default:
+    return false;
+  }
+
+  SVal First = Call.getArgSVal(0);
+  SVal Second = Call.getArgSVal(1);
+
+  // There are some special cases about which we can infer about
+  // the resulting answer.
+  // For reference, there is a discussion at https://reviews.llvm.org/D104616.
+  // Also, the cppreference page is good to look at
+  // https://en.cppreference.com/w/cpp/memory/unique_ptr/operator_cmp.
+
+  ProgramStateRef State = C.getState();
+  SVal RetVal = C.getSValBuilder().conjureSymbolVal(
+      Call.getOriginExpr(), C.getLocationContext(), Call.getResultType(),
+      C.blockCount());
+
+  if (!First.isZeroConstant() && !Second.isZeroConstant()) {
+    // Neither is nullptr, so both are smart pointers (whether the smart
+    // pointers are null or not is an entirely different question).
+    const MemRegion *FirstReg = First.getAsRegion();
+    const MemRegion *SecondReg = Second.getAsRegion();
+    if (!FirstReg || !SecondReg)
+      return false;
+
+    if (FirstReg == SecondReg) {
+      // Both operands refer to the same object, which compares equal to
+      // itself: ==, <= and >= hold, while !=, < and > do not.
+      const bool Holds = BOK == BO_EQ || BOK == BO_LE || BOK == BO_GE;
+      State = State->assume(RetVal.castAs<DefinedOrUnknownSVal>(), Holds);
+    } else {
+      const auto *FirstPtr = State->get<TrackedRegionMap>(FirstReg);
+      const auto *SecondPtr = State->get<TrackedRegionMap>(SecondReg);
+      if (!FirstPtr && !SecondPtr)
+        return false;
+
+      // Now, we know the inner pointer of at least one.
+      if (FirstPtr && SecondPtr) {
+        RetVal = C.getSValBuilder().evalBinOp(State, BOK, *FirstPtr, *SecondPtr,
+                                              Call.getResultType());
+      } else {
+        // Only one inner pointer is tracked. An inner pointer is definitely
+        // null only if assuming it to be non-null is infeasible; a pointer
+        // that merely may be null tells nothing about the result.
+        const SVal &KnownPtr = FirstPtr ? *FirstPtr : *SecondPtr;
+        const bool KnownIsNull =
+            !State->assume(KnownPtr.castAs<DefinedOrUnknownSVal>(), true);
+        // A null pointer is taken to be <= any pointer (null <= x, x >= null).
+        if (KnownIsNull && BOK == (FirstPtr ? BO_LE : BO_GE))
+          State = State->assume(RetVal.castAs<DefinedOrUnknownSVal>(), true);
+      }
+    }
+  }
+
+  // TODO: Now handle all the cases with one arg as nullptr
+
+  State = State->BindExpr(Call.getOriginExpr(), C.getLocationContext(), RetVal);
+  C.addTransition(State);
+  return true;
+}
+
 void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper,
                                         CheckerContext &C) const {
   ProgramStateRef State = C.getState();
diff --git a/clang/test/Analysis/smart-ptr.cpp b/clang/test/Analysis/smart-ptr.cpp
--- a/clang/test/Analysis/smart-ptr.cpp
+++ b/clang/test/Analysis/smart-ptr.cpp
@@ -457,3 +457,21 @@
     P->foo(); // expected-warning {{Dereference of null smart pointer 'P' [alpha.cplusplus.SmartPtr]}}
   }
 }
+
+void uniquePtrComparision(std::unique_ptr<int> unknownPtr) {
+  auto ptr = std::unique_ptr<int>(new int(13));
+  auto nullPtr = std::unique_ptr<int>();
+  auto otherPtr = std::unique_ptr<int>(new int(29));
+
+  clang_analyzer_eval(ptr == ptr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(ptr != ptr); // expected-warning{{FALSE}}
+  clang_analyzer_eval(ptr > ptr);  // expected-warning{{FALSE}}
+  clang_analyzer_eval(ptr <= ptr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(ptr >= ptr); // expected-warning{{TRUE}}
+
+  clang_analyzer_eval(nullPtr <= unknownPtr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(unknownPtr >= nullPtr); // expected-warning{{TRUE}}
+
+  clang_analyzer_eval(ptr != otherPtr); // expected-warning{{TRUE}}
+  clang_analyzer_eval(ptr > nullPtr);   // expected-warning{{TRUE}}
+}