diff --git a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp --- a/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/SmartPtrModeling.cpp @@ -35,11 +35,19 @@ using namespace ento; namespace { + +enum class MakeUniqueKind { + Regular, // i.e., std::make_unique + ForOverwrite, // i.e., std::make_unique_for_overwrite + None // i.e., neither of the above two +}; + class SmartPtrModeling : public Checker { + check::LiveSymbols, check::Bind> { bool isBoolConversionMethod(const CallEvent &Call) const; + MakeUniqueKind isStdMakeUniqueCall(const CallEvent &Call) const; public: // Whether the checker should model for null dereferences of smart pointers. @@ -56,6 +64,7 @@ void printState(raw_ostream &Out, ProgramStateRef State, const char *NL, const char *Sep) const override; void checkLiveSymbols(ProgramStateRef State, SymbolReaper &SR) const; + void checkBind(SVal L, SVal V, const Stmt *S, CheckerContext &C) const; private: void handleReset(const CallEvent &Call, CheckerContext &C) const; @@ -68,6 +77,8 @@ bool updateMovedSmartPointers(CheckerContext &C, const MemRegion *ThisRegion, const MemRegion *OtherSmartPtrRegion) const; void handleBoolConversion(const CallEvent &Call, CheckerContext &C) const; + void handleStdMakeUnique(const CallEvent &Call, CheckerContext &C, + MakeUniqueKind Kind) const; using SmartPtrMethodHandlerFn = void (SmartPtrModeling::*)(const CallEvent &Call, CheckerContext &) const; @@ -79,7 +90,24 @@ }; + +/// Wraps a MakeUniqueKind in a class that provides Profile(), so that it can +/// be the element type of the MakeUniqueKindList program-state list. The +/// constructor is implicit so that a MakeUniqueKind can be added directly. +class MakeUniqueKindWrapper { + const MakeUniqueKind Kind; + +public: + MakeUniqueKindWrapper(MakeUniqueKind Kind) : Kind(Kind) {} + MakeUniqueKind get() const { return Kind; } + void Profile(llvm::FoldingSetNodeID &ID) const { + ID.AddInteger(static_cast(Kind)); + } + bool operator==(MakeUniqueKind RHS) const { return Kind == RHS; } + bool operator!=(MakeUniqueKind RHS) const { return Kind != RHS; } +}; } // end of anonymous namespace REGISTER_MAP_WITH_PROGRAMSTATE(TrackedRegionMap, const MemRegion *, SVal) +REGISTER_LIST_WITH_PROGRAMSTATE(MakeUniqueKindList, MakeUniqueKindWrapper) // Define the inter-checker API. namespace clang {
@@ -175,8 +203,39 @@ return CD && CD->getConversionType()->isBooleanType(); } +/// Classifies \p Call as a call to std::make_unique, to +/// std::make_unique_for_overwrite, or neither (MakeUniqueKind::None). +MakeUniqueKind +SmartPtrModeling::isStdMakeUniqueCall(const CallEvent &Call) const { + if (Call.getKind() != CallEventKind::CE_Function) + return MakeUniqueKind::None; + const auto *D = Call.getDecl(); + if (!D) + return MakeUniqueKind::None; + // Only functions declared in namespace std are modeled; a user-defined + // function that happens to be named make_unique is left alone. + const auto *FD = llvm::dyn_cast<FunctionDecl>(D); + if (!FD || !FD->isInStdNamespace()) + return MakeUniqueKind::None; + if (FD->getDeclName().isIdentifier()) { + StringRef Name = FD->getName(); + if (Name == "make_unique") + return MakeUniqueKind::Regular; + if (Name == "make_unique_for_overwrite") + return MakeUniqueKind::ForOverwrite; + } + return MakeUniqueKind::None; +} + bool SmartPtrModeling::evalCall(const CallEvent &Call, CheckerContext &C) const { + + MakeUniqueKind Kind = isStdMakeUniqueCall(Call); + if (Kind != MakeUniqueKind::None) { + handleStdMakeUnique(Call, C, Kind); + return true; + } + ProgramStateRef State = C.getState(); if (!smartptr::isStdSmartPtrCall(Call)) return false; @@ -272,6 +331,64 @@ return C.isDifferent(); } +/// Pushes \p Kind onto MakeUniqueKindList so that a later checkBind on a +/// std::unique_ptr region can consume it. +void SmartPtrModeling::handleStdMakeUnique(const CallEvent &Call, + CheckerContext &C, + MakeUniqueKind Kind) const { + assert(Kind != MakeUniqueKind::None && + "Call is not to make_unique or make_unique_for_overwrite"); + ProgramStateRef State = C.getState(); + State = State->add(Kind); + C.addTransition(State); +} + +/// Returns true if \p QT is a record type named unique_ptr in namespace std. +static bool isUniquePtrType(QualType QT) { + const auto *T = QT.getTypePtr(); + if (!T) + return false; + const auto *Decl = T->getAsCXXRecordDecl(); + if (!Decl || !Decl->getDeclContext()->isStdNamespace()) + return false; + const IdentifierInfo *ID = Decl->getIdentifier(); + if (!ID) + return false; + const StringRef Name = ID->getName(); + return Name == "unique_ptr"; +} +
+/// Pops the kind at the head of MakeUniqueKindList when a value is bound to +/// a std::unique_ptr region. +void SmartPtrModeling::checkBind(SVal L, SVal V, const Stmt *S, + CheckerContext &C) const { + const auto *TVR = dyn_cast_or_null(L.getAsRegion()); + if (!TVR) + return; + if (!isUniquePtrType(TVR->getValueType())) + return; + + ProgramStateRef State = C.getState(); + auto KindList = State->get(); + // Not every bind to a unique_ptr region comes from a modeled make_unique + // call, so there may be no pending kind to consume. + if (KindList.isEmpty()) + return; + assert(KindList.getHead() != MakeUniqueKind::None && + "Bind is not to make_unique or make_unique_for_overwrite"); + // Both std::make_unique and std::make_unique_for_overwrite return a + // unique_ptr owning a newly allocated object, so the inner pointer must not + // be recorded as null here; that would let the checker report null + // dereferences that cannot happen. + // TODO: Encode information that the inner pointer of a unique_ptr made by + // either function is *not* null. + // NOTE(review): a kind pushed by a make_unique call whose result is never + // bound to a unique_ptr region stays on the list and may be consumed by an + // unrelated later bind; consider tying it to the call's result region. + State = State->set(KindList.getTail()); + C.addTransition(State); +} + void SmartPtrModeling::checkDeadSymbols(SymbolReaper &SymReaper, CheckerContext &C) const { ProgramStateRef State = C.getState(); diff --git a/polly/include/polly/CodeGen/IslAst.h b/polly/include/polly/CodeGen/IslAst.h --- a/polly/include/polly/CodeGen/IslAst.h +++ b/polly/include/polly/CodeGen/IslAst.h @@ -142,7 +142,7 @@ static bool isInnermost(const isl::ast_node &Node); /// Is this loop a parallel loop? - static bool isParallel(__isl_keep isl_ast_node *Node); + static bool isParallel(const isl::ast_node &Node); /// Is this loop an outermost parallel loop? static bool isOutermostParallel(const isl::ast_node &Node); @@ -151,20 +151,19 @@ static bool isInnermostParallel(const isl::ast_node &Node); /// Is this loop a reduction parallel loop? - static bool isReductionParallel(__isl_keep isl_ast_node *Node); + static bool isReductionParallel(const isl::ast_node &Node); /// Will the loop be run as thread parallel? - static bool isExecutedInParallel(__isl_keep isl_ast_node *Node); + static bool isExecutedInParallel(const isl::ast_node &Node); - /// Get the nodes schedule or a nullptr if not available. + /// Get the node's schedule, or a null isl::union_map if not available.
- static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node); + static isl::union_map getSchedule(const isl::ast_node &Node); - /// Get minimal dependence distance or nullptr if not available. - static __isl_give isl_pw_aff * - getMinimalDependenceDistance(__isl_keep isl_ast_node *Node); + /// Get minimal dependence distance, or a null isl::pw_aff if not available. + static isl::pw_aff getMinimalDependenceDistance(const isl::ast_node &Node); /// Get the nodes broken reductions or a nullptr if not available. - static MemoryAccessSet *getBrokenReductions(__isl_keep isl_ast_node *Node); + static MemoryAccessSet *getBrokenReductions(const isl::ast_node &Node); /// Get the nodes build context or a nullptr if not available. static __isl_give isl_ast_build *getBuild(__isl_keep isl_ast_node *Node); diff --git a/polly/include/polly/CodeGen/IslNodeBuilder.h b/polly/include/polly/CodeGen/IslNodeBuilder.h --- a/polly/include/polly/CodeGen/IslNodeBuilder.h +++ b/polly/include/polly/CodeGen/IslNodeBuilder.h @@ -248,7 +248,7 @@ /// this subtree. /// @param Loops A vector that will be filled with the Loops referenced in /// this subtree. - void getReferencesInSubtree(__isl_keep isl_ast_node *For, + void getReferencesInSubtree(const isl::ast_node &For, SetVector &Values, SetVector &Loops); @@ -398,8 +398,7 @@ /// below this ast node to the scheduling vectors used to enumerate /// them. /// - virtual __isl_give isl_union_map * - getScheduleForAstNode(__isl_take isl_ast_node *Node); + virtual isl::union_map getScheduleForAstNode(const isl::ast_node &Node); private: /// Create code for a copy statement. diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -140,7 +140,7 @@ } /// Return all broken reductions as a string of clauses (OpenMP style).
-static const std::string getBrokenReductionsStr(__isl_keep isl_ast_node *Node) { +static const std::string getBrokenReductionsStr(const isl::ast_node &Node) { IslAstInfo::MemoryAccessSet *BrokenReductions; std::string str; @@ -171,25 +171,26 @@ static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, __isl_take isl_ast_print_options *Options, __isl_keep isl_ast_node *Node, void *) { - isl_pw_aff *DD = IslAstInfo::getMinimalDependenceDistance(Node); - const std::string BrokenReductionsStr = getBrokenReductionsStr(Node); + // Wrap Node once; the IslAstInfo queries below take an isl::ast_node. + isl::ast_node AstNode = isl::manage_copy(Node); + isl::pw_aff DD = IslAstInfo::getMinimalDependenceDistance(AstNode); + const std::string BrokenReductionsStr = getBrokenReductionsStr(AstNode); const std::string KnownParallelStr = "#pragma known-parallel"; const std::string DepDisPragmaStr = "#pragma minimal dependence distance: "; const std::string SimdPragmaStr = "#pragma simd"; const std::string OmpPragmaStr = "#pragma omp parallel for"; - if (DD) - Printer = printLine(Printer, DepDisPragmaStr, DD); + if (!DD.is_null()) + Printer = printLine(Printer, DepDisPragmaStr, DD.get()); - if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node))) + if (IslAstInfo::isInnermostParallel(AstNode)) Printer = printLine(Printer, SimdPragmaStr + BrokenReductionsStr); - if (IslAstInfo::isExecutedInParallel(Node)) + if (IslAstInfo::isExecutedInParallel(AstNode)) Printer = printLine(Printer, OmpPragmaStr); - else if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node))) + else if (IslAstInfo::isOutermostParallel(AstNode)) Printer = printLine(Printer, KnownParallelStr + BrokenReductionsStr); - isl_pw_aff_free(DD); return isl_ast_node_for_print(Node, Printer, Options); } @@ -472,15 +473,15 @@ switch (isl_ast_node_get_type(Node)) { case isl_ast_node_for: NumForLoops++; - if (IslAstInfo::isParallel(Node)) + if (IslAstInfo::isParallel(isl::manage_copy(Node))) NumParallel++; if (IslAstInfo::isInnermostParallel(isl::manage_copy(Node))) NumInnermostParallel++; if (IslAstInfo::isOutermostParallel(isl::manage_copy(Node))) NumOutermostParallel++; - if
(IslAstInfo::isReductionParallel(Node)) + if (IslAstInfo::isReductionParallel(isl::manage_copy(Node))) NumReductionParallel++; - if (IslAstInfo::isExecutedInParallel(Node)) + if (IslAstInfo::isExecutedInParallel(isl::manage_copy(Node))) NumExecutedInParallel++; break; @@ -593,9 +594,9 @@ return Payload && Payload->IsInnermost; } -bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { - return IslAstInfo::isInnermostParallel(isl::manage_copy(Node)) || - IslAstInfo::isOutermostParallel(isl::manage_copy(Node)); +bool IslAstInfo::isParallel(const isl::ast_node &Node) { + return IslAstInfo::isInnermostParallel(Node) || + IslAstInfo::isOutermostParallel(Node); } bool IslAstInfo::isInnermostParallel(const isl::ast_node &Node) { @@ -608,12 +609,12 @@ return Payload && Payload->IsOutermostParallel; } -bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); +bool IslAstInfo::isReductionParallel(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); return Payload && Payload->IsReductionParallel; } -bool IslAstInfo::isExecutedInParallel(__isl_keep isl_ast_node *Node) { +bool IslAstInfo::isExecutedInParallel(const isl::ast_node &Node) { if (!PollyParallel) return false; @@ -626,28 +627,30 @@ // executed. This can possibly require run-time checks, which again // raises the question of both run-time check overhead and code size // costs. - if (!PollyParallelForce && isInnermost(isl::manage_copy(Node))) + if (!PollyParallelForce && isInnermost(Node)) return false; - return isOutermostParallel(isl::manage_copy(Node)) && - !isReductionParallel(Node); + return isOutermostParallel(Node) && !isReductionParallel(Node); } -__isl_give isl_union_map * -IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); - return Payload ?
isl_ast_build_get_schedule(Payload->Build) : nullptr; +isl::union_map IslAstInfo::getSchedule(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + if (!Payload) + return isl::union_map(); + + isl::ast_build Build = isl::manage_copy(Payload->Build); + return Build.get_schedule(); } -__isl_give isl_pw_aff * -IslAstInfo::getMinimalDependenceDistance(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); - return Payload ? Payload->MinimalDependenceDistance.copy() : nullptr; +isl::pw_aff +IslAstInfo::getMinimalDependenceDistance(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload ? Payload->MinimalDependenceDistance : isl::pw_aff(); } IslAstInfo::MemoryAccessSet * -IslAstInfo::getBrokenReductions(__isl_keep isl_ast_node *Node) { - IslAstUserPayload *Payload = getNodePayload(isl::manage_copy(Node)); +IslAstInfo::getBrokenReductions(const isl::ast_node &Node) { + IslAstUserPayload *Payload = getNodePayload(Node); return Payload ?
&Payload->BrokenReductions : nullptr; } diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp --- a/polly/lib/CodeGen/IslNodeBuilder.cpp +++ b/polly/lib/CodeGen/IslNodeBuilder.cpp @@ -300,12 +300,12 @@ addReferencesFromStmtSet(Set, &References); } -__isl_give isl_union_map * -IslNodeBuilder::getScheduleForAstNode(__isl_keep isl_ast_node *For) { - return IslAstInfo::getSchedule(For); +isl::union_map +IslNodeBuilder::getScheduleForAstNode(const isl::ast_node &Node) { + return IslAstInfo::getSchedule(Node); } -void IslNodeBuilder::getReferencesInSubtree(__isl_keep isl_ast_node *For, +void IslNodeBuilder::getReferencesInSubtree(const isl::ast_node &For, SetVector &Values, SetVector &Loops) { SetVector SCEVs; @@ -319,8 +319,7 @@ for (const auto &I : OutsideLoopIterations) Values.insert(cast(I.second)->getValue()); - isl::union_set Schedule = - isl::manage(isl_union_map_domain(getScheduleForAstNode(For))); + isl::union_set Schedule = getScheduleForAstNode(For).domain(); addReferencesFromStmtUnionSet(Schedule, References); for (const SCEV *Expr : SCEVs) { @@ -476,22 +475,22 @@ for (int i = 1; i < VectorWidth; i++) IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv"); - isl_union_map *Schedule = getScheduleForAstNode(For); - assert(Schedule && "For statement annotation does not contain its schedule"); + isl::union_map Schedule = getScheduleForAstNode(isl::manage_copy(For)); + assert(!Schedule.is_null() && + "For statement annotation does not contain its schedule"); IDToValue[IteratorID] = ValueLB; switch (isl_ast_node_get_type(Body)) { case isl_ast_node_user: - createUserVector(Body, IVS, isl_id_copy(IteratorID), - isl_union_map_copy(Schedule)); + createUserVector(Body, IVS, isl_id_copy(IteratorID), Schedule.copy()); break; case isl_ast_node_block: { isl_ast_node_list *List = isl_ast_node_block_get_children(Body); for (int i = 0; i < isl_ast_node_list_n_ast_node(List); ++i) createUserVector(isl_ast_node_list_get_ast_node(List,
i), IVS, - isl_id_copy(IteratorID), isl_union_map_copy(Schedule)); + isl_id_copy(IteratorID), Schedule.copy()); isl_ast_node_free(Body); isl_ast_node_list_free(List); @@ -504,7 +503,6 @@ IDToValue.erase(IDToValue.find(IteratorID)); isl_id_free(IteratorID); - isl_union_map_free(Schedule); isl_ast_node_free(For); isl_ast_expr_free(Iterator); @@ -685,7 +683,7 @@ SetVector SubtreeValues; SetVector Loops; - getReferencesInSubtree(For, SubtreeValues, Loops); + getReferencesInSubtree(isl::manage_copy(For), SubtreeValues, Loops); // Create for all loops we depend on values that contain the current loop // iteration. These values are necessary to generate code for SCEVs that @@ -783,7 +781,7 @@ bool Vector = PollyVectorizerChoice == VECTORIZER_POLLY; if (Vector && IslAstInfo::isInnermostParallel(isl::manage_copy(For)) && - !IslAstInfo::isReductionParallel(For)) { + !IslAstInfo::isReductionParallel(isl::manage_copy(For))) { int VectorWidth = getNumberOfIterations(isl::manage_copy(For)); if (1 < VectorWidth && VectorWidth <= 16 && !hasPartialAccesses(For)) { createForVector(For, VectorWidth); @@ -791,12 +789,12 @@ } } - if (IslAstInfo::isExecutedInParallel(For)) { + if (IslAstInfo::isExecutedInParallel(isl::manage_copy(For))) { createForParallel(For); return; } - bool Parallel = - (IslAstInfo::isParallel(For) && !IslAstInfo::isReductionParallel(For)); + bool Parallel = (IslAstInfo::isParallel(isl::manage_copy(For)) && + !IslAstInfo::isReductionParallel(isl::manage_copy(For))); createForSequential(isl::manage(For), Parallel); }