diff --git a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h --- a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h +++ b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h @@ -373,7 +373,7 @@ class FalsePositiveRefutationBRVisitor final : public BugReporterVisitor { private: /// Holds the constraints in a given path - ConstraintRangeTy Constraints; + ConstraintMap Constraints; public: FalsePositiveRefutationBRVisitor(); @@ -388,7 +388,6 @@ PathSensitiveBugReport &BR) override; }; - /// The visitor detects NoteTags and displays the event notes they contain. class TagVisitor : public BugReporterVisitor { public: diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h --- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h +++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/RangedConstraintManager.h @@ -136,14 +136,10 @@ } }; -class ConstraintRange {}; -using ConstraintRangeTy = llvm::ImmutableMap; - -template <> -struct ProgramStateTrait - : public ProgramStatePartialTrait { - static void *GDMIndex(); -}; +using ConstraintMap = llvm::ImmutableMap; +/// Return a map from one representative member of every equivalence class +/// to the range constraint associated with that class in \p State. +ConstraintMap getConstraintMap(ProgramStateRef State); class RangedConstraintManager : public SimpleConstraintManager { public: @@ -222,4 +218,6 @@ } // namespace ento } // namespace clang +REGISTER_FACTORY_WITH_PROGRAMSTATE(ConstraintMap); + #endif diff --git a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp --- a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp +++ b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp @@ -2813,7 +2813,7 @@ //===----------------------------------------------------------------------===//
FalsePositiveRefutationBRVisitor::FalsePositiveRefutationBRVisitor() - : Constraints(ConstraintRangeTy::Factory().getEmptyMap()) {} + : Constraints(ConstraintMap::Factory().getEmptyMap()) {} void FalsePositiveRefutationBRVisitor::finalizeVisitor( BugReporterContext &BRC, const ExplodedNode *EndPathNode, @@ -2855,9 +2855,8 @@ PathDiagnosticPieceRef FalsePositiveRefutationBRVisitor::VisitNode( const ExplodedNode *N, BugReporterContext &, PathSensitiveBugReport &) { // Collect new constraints - const ConstraintRangeTy &NewCs = N->getState()->get(); - ConstraintRangeTy::Factory &CF = - N->getState()->get_context(); + ConstraintMap NewCs = getConstraintMap(N->getState()); + ConstraintMap::Factory &CF = N->getState()->get_context(); // Add constraints if we don't have them yet for (auto const &C : NewCs) { diff --git a/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp b/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp --- a/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp +++ b/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp @@ -391,7 +391,191 @@ os << " }"; } +REGISTER_SET_FACTORY_WITH_PROGRAMSTATE(SymbolSet, SymbolRef) + +namespace { +class EquivalenceClass; +} // end anonymous namespace + +REGISTER_MAP_WITH_PROGRAMSTATE(ClassMap, SymbolRef, EquivalenceClass) +REGISTER_MAP_WITH_PROGRAMSTATE(ClassMembers, EquivalenceClass, SymbolSet) +REGISTER_MAP_WITH_PROGRAMSTATE(ConstraintRange, EquivalenceClass, RangeSet) + namespace { +/// This class encapsulates a set of symbols equal to each other. +/// +/// The main idea of the approach requiring such classes is in narrowing +/// and sharing constraints between symbols within the class. Also we can +/// conclude that there is no practical need for storing constraints for +/// every member of the class separately. +/// +/// Main terminology: +/// +/// * "Equivalence class" is an object of this class, which can be efficiently +/// compared to other classes.
It represents the whole class without +/// storing the actual members in it. The members of the class however can be +/// retrieved from the state. +/// +/// * "Class members" are the symbols corresponding to the class. This means +/// that A == B for all member symbols A and B from the class. Members of +/// each class are stored in the state. +/// +/// * "Trivial class" is a class that has only ever had one symbol. +/// +/// * "Merge (or Union) operation" merges two classes into one. It is the +/// main operation to produce non-trivial classes. +/// If, at some point, we can assume that two symbols from two distinct +/// classes are equal, we can merge these classes. +class EquivalenceClass : public llvm::FoldingSetNode { +public: + /// Find equivalence class for the given symbol in the given state. + static inline EquivalenceClass find(ProgramStateRef State, SymbolRef Sym); + + /// Merge classes for the given symbols and return a new state. + static inline ProgramStateRef merge(BasicValueFactory &BV, + RangeSet::Factory &F, + ProgramStateRef State, SymbolRef First, + SymbolRef Second); + /// Merge this class with the given class and return a new state. + inline ProgramStateRef merge(BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, EquivalenceClass Other); + + /// Return a set of class members for the given state. + inline SymbolSet getClassMembers(ProgramStateRef State); + /// Return true if the current class is trivial in the given state. + inline bool isTrivial(ProgramStateRef State); + /// Return true if the current class is trivial and its only member is dead.
+ inline bool isTriviallyDead(ProgramStateRef State, SymbolReaper &Reaper); + + EquivalenceClass() = delete; + EquivalenceClass(const EquivalenceClass &) = default; + EquivalenceClass &operator=(const EquivalenceClass &) = default; + EquivalenceClass(EquivalenceClass &&) = default; + EquivalenceClass &operator=(EquivalenceClass &&) = default; + + bool operator==(const EquivalenceClass &Other) const { + return ID == Other.ID; + } + bool operator<(const EquivalenceClass &Other) const { return ID < Other.ID; } + bool operator!=(const EquivalenceClass &Other) const { + return !operator==(Other); + } + + static void Profile(llvm::FoldingSetNodeID &ID, uintptr_t CID) { + ID.AddInteger(CID); + } + + void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, this->ID); } + +private: + /* implicit */ EquivalenceClass(SymbolRef Sym) + : ID(reinterpret_cast(Sym)) {} + + /// This function is intended to be used ONLY within the class. + /// The fact that ID is a pointer to a symbol is an implementation detail + /// and should stay that way. + /// In the current implementation, we use it to retrieve the only member + /// of the trivial class. + SymbolRef getRepresentativeSymbol() const { + return reinterpret_cast(ID); + } + static inline SymbolSet::Factory &getMembersFactory(ProgramStateRef State); + + inline ProgramStateRef mergeImpl(BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, SymbolSet Members, + EquivalenceClass Other, + SymbolSet OtherMembers); + + /// This is a unique identifier of the class.
+ uintptr_t ID; +}; + +inline bool isZero(const llvm::APSInt &Int) { + APSIntType Type(Int); + return Int == Type.getZeroValue(); +} + +//===----------------------------------------------------------------------===// +// Constraint functions +//===----------------------------------------------------------------------===// + +LLVM_NODISCARD inline ProgramStateRef setConstraint(ProgramStateRef State, + EquivalenceClass Class, + RangeSet Constraint) { + return State->set(Class, Constraint); +} + +LLVM_NODISCARD inline ProgramStateRef +setConstraint(ProgramStateRef State, SymbolRef Sym, RangeSet Constraint) { + return setConstraint(State, EquivalenceClass::find(State, Sym), Constraint); +} + +LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State, + EquivalenceClass Class) { + return State->get(Class); +} + +LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State, + SymbolRef Sym) { + return getConstraint(State, EquivalenceClass::find(State, Sym)); +} + +//===----------------------------------------------------------------------===// +// Equality tracker +//===----------------------------------------------------------------------===// + +/// A small helper structure representing symbolic equality. +/// +/// Equality check can have different forms (like a == b or a - b) and this +/// class encapsulates those away if the only thing the user wants to check - +/// whether it's equality/disequality or not and have easy access to the +/// compared symbols. +struct EqualityInfo { +public: + SymbolRef Left, Right; + // true for equality and false for disequality. + bool IsEquality = true; + + void invert() { IsEquality = !IsEquality; } + /// Extract equality information from the given symbol and the constants. + /// + /// This function assumes the following expression Sym + Adjustment != Int. + /// It is a default because the most widespread case of the equality check + /// is (A == B) + 0 != 0.
+ static Optional extract(SymbolRef Sym, const llvm::APSInt &Int, + const llvm::APSInt &Adjustment) { + // As of now, the only equality form supported is Sym + 0 != 0. + if (!isZero(Int) || !isZero(Adjustment)) + return llvm::None; + + return extract(Sym); + } + /// Extract equality information from the given symbol. + static Optional extract(SymbolRef Sym) { + return EqualityExtractor().Visit(Sym); + } + +private: + class EqualityExtractor + : public SymExprVisitor> { + public: + Optional VisitSymSymExpr(const SymSymExpr *Sym) const { + switch (Sym->getOpcode()) { + case BO_Sub: + // This case is: A - B != 0 -> disequality check. + return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false}; + case BO_EQ: + // This case is: A == B != 0 -> equality check. + return EqualityInfo{Sym->getLHS(), Sym->getRHS(), true}; + case BO_NE: + // This case is: A != B != 0 -> disequality check. + return EqualityInfo{Sym->getLHS(), Sym->getRHS(), false}; + default: + return llvm::None; + } + } + }; +}; //===----------------------------------------------------------------------===// // Intersection functions @@ -556,15 +740,16 @@ RangeSet infer(SymbolRef Sym) { if (Optional ConstraintBasedRange = intersect( - ValueFactory, RangeFactory, State->get(Sym), + ValueFactory, RangeFactory, getConstraint(State, Sym), // If Sym is a difference of symbols A - B, then maybe we have range // set stored for B - A. // // If we have range set stored for both A - B and B - A then // calculate the effective range set by intersecting the range set // for A - B and the negated range set of B - A. - getRangeForInvertedSub(Sym))) + getRangeForInvertedSub(Sym), getRangeForEqualities(Sym))) { return *ConstraintBasedRange; + } // If Sym is a comparison expression (except <=>), // find any other comparisons with the same operands.
@@ -745,8 +930,7 @@ SymbolRef NegatedSym = SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), T); - if (const RangeSet *NegatedRange = - State->get(NegatedSym)) { + if (const RangeSet *NegatedRange = getConstraint(State, NegatedSym)) { return NegatedRange->Negate(ValueFactory, RangeFactory); } } @@ -792,7 +976,7 @@ // Let's find an expression e.g. (x < y). BinaryOperatorKind QueriedOP = OperatorRelationsTable::getOpFromIndex(i); const SymSymExpr *SymSym = SymMgr.getSymSymExpr(LHS, QueriedOP, RHS, T); - const RangeSet *QueriedRangeSet = State->get(SymSym); + const RangeSet *QueriedRangeSet = getConstraint(State, SymSym); // If ranges were not previously found, // try to find a reversed expression (y > x). @@ -800,7 +984,7 @@ const BinaryOperatorKind ROP = BinaryOperator::reverseComparisonOp(QueriedOP); SymSym = SymMgr.getSymSymExpr(RHS, ROP, LHS, T); - QueriedRangeSet = State->get(SymSym); + QueriedRangeSet = getConstraint(State, SymSym); } if (!QueriedRangeSet || QueriedRangeSet->isEmpty()) @@ -838,6 +1022,27 @@ return llvm::None; } + Optional getRangeForEqualities(SymbolRef Sym) { + Optional Equality = EqualityInfo::extract(Sym); + + if (!Equality) + return llvm::None; + + EquivalenceClass LHS = EquivalenceClass::find(State, Equality->Left); + EquivalenceClass RHS = EquivalenceClass::find(State, Equality->Right); + + if (LHS != RHS) + // Can't really say anything at this point. + // We can add more logic here if we track disequalities as well. + return llvm::None; + + // At this point, operands of the equality operation are known to be equal.
+ if (Equality->IsEquality) { + return getTrueRange(Sym->getType()); + } + return getFalseRange(Sym->getType()); + } + RangeSet getTrueRange(QualType T) { RangeSet TypeRange = infer(T); return assumeNonZero(TypeRange, T); @@ -1032,7 +1237,11 @@ bool haveEqualConstraints(ProgramStateRef S1, ProgramStateRef S2) const override { - return S1->get() == S2->get(); + // NOTE: ClassMembers are as simple as back pointers for ClassMap, + // so comparing constraint ranges and class maps should be + // sufficient. + return S1->get() == S2->get() && + S1->get() == S2->get(); } bool canReasonAbout(SVal X) const override; @@ -1104,6 +1313,49 @@ RangeSet getSymGERange(ProgramStateRef St, SymbolRef Sym, const llvm::APSInt &Int, const llvm::APSInt &Adjustment); + + //===--------------------------------------------------------------------===// + // Equality tracking implementation + //===--------------------------------------------------------------------===// + + ProgramStateRef trackEQ(ProgramStateRef State, SymbolRef Sym, + const llvm::APSInt &Int, + const llvm::APSInt &Adjustment) { + if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) { + // Extract function assumes that we gave it Sym + Adjustment != Int, + // so the result should be opposite.
+ Equality->invert(); + return track(State, *Equality); + } + + return State; + } + + ProgramStateRef trackNE(ProgramStateRef State, SymbolRef Sym, + const llvm::APSInt &Int, + const llvm::APSInt &Adjustment) { + if (auto Equality = EqualityInfo::extract(Sym, Int, Adjustment)) { + return track(State, *Equality); + } + + return State; + } + + ProgramStateRef track(ProgramStateRef State, EqualityInfo ToTrack) { + if (ToTrack.IsEquality) { + return trackEquality(State, ToTrack.Left, ToTrack.Right); + } + return trackDisequality(State, ToTrack.Left, ToTrack.Right); + } + + ProgramStateRef trackDisequality(ProgramStateRef State, SymbolRef LHS, + SymbolRef RHS) { + // TODO: track disequalities + return State; + } + + ProgramStateRef trackEquality(ProgramStateRef State, SymbolRef LHS, + SymbolRef RHS); }; } // end anonymous namespace @@ -1114,6 +1366,153 @@ return std::make_unique(Eng, StMgr.getSValBuilder()); } +ConstraintMap ento::getConstraintMap(ProgramStateRef State) { + ConstraintMap::Factory &F = State->get_context(); + ConstraintMap Result = F.getEmptyMap(); + + ConstraintRangeTy Constraints = State->get(); + for (std::pair ClassConstraint : Constraints) { + EquivalenceClass Class = ClassConstraint.first; + SymbolSet ClassMembers = Class.getClassMembers(State); + assert(!ClassMembers.isEmpty() && + "Class must always have at least one member!"); + + SymbolRef Representative = *ClassMembers.begin(); + Result = F.add(Result, Representative, ClassConstraint.second); + } + + return Result; +} + +//===----------------------------------------------------------------------===// +// EquivalenceClass implementation details +//===----------------------------------------------------------------------===// + +inline EquivalenceClass EquivalenceClass::find(ProgramStateRef State, + SymbolRef Sym) { + if (const EquivalenceClass *NontrivialClass = State->get(Sym)) + return *NontrivialClass; + + // This is a trivial class of Sym.
+ return Sym; +} + +inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV, + RangeSet::Factory &F, + ProgramStateRef State, + SymbolRef First, + SymbolRef Second) { + EquivalenceClass FirstClass = find(State, First); + EquivalenceClass SecondClass = find(State, Second); + + return FirstClass.merge(BV, F, State, SecondClass); +} + +inline ProgramStateRef EquivalenceClass::merge(BasicValueFactory &BV, + RangeSet::Factory &F, + ProgramStateRef State, + EquivalenceClass Other) { + // It is already the same class. + if (*this == Other) + return State; + + SymbolSet Members = getClassMembers(State); + SymbolSet OtherMembers = Other.getClassMembers(State); + + // We estimate the size of the class by the height of tree containing + // its members. Merging is not a trivial operation, so it's easier to + // merge the smaller class into the bigger one. + if (Members.getHeight() >= OtherMembers.getHeight()) { + return mergeImpl(BV, F, State, Members, Other, OtherMembers); + } else { + return Other.mergeImpl(BV, F, State, OtherMembers, *this, Members); + } +} + +inline ProgramStateRef +EquivalenceClass::mergeImpl(BasicValueFactory &ValueFactory, + RangeSet::Factory &RangeFactory, + ProgramStateRef State, SymbolSet MyMembers, + EquivalenceClass Other, SymbolSet OtherMembers) { + // 1. Get ALL constraint- and equivalence-related maps + ClassMapTy Classes = State->get(); + ClassMapTy::Factory &CF = State->get_context(); + + ClassMembersTy Members = State->get(); + ClassMembersTy::Factory &MF = State->get_context(); + + ConstraintRangeTy Constraints = State->get(); + ConstraintRangeTy::Factory &CRF = State->get_context(); + + SymbolSet::Factory &F = getMembersFactory(State); + + // 2. Merge members of the Other class into the current class. + SymbolSet NewClassMembers = MyMembers; + for (SymbolRef Sym : OtherMembers) { + NewClassMembers = F.add(NewClassMembers, Sym); + // *this is now the class for all these new symbols.
+ Classes = CF.add(Classes, Sym, *this); + } + + // 3. Adjust member mapping. + // + // No need in tracking members of a now-dissolved class. + Members = MF.remove(Members, Other); + // Now only the current class is mapped to all the symbols. + Members = MF.add(Members, *this, NewClassMembers); + + // 4. Update the state + State = State->set(Classes); + State = State->set(Members); + + // 5. If the merged classes have any constraints associated with them, we + // need to transfer them to the class we have left. + // + // Intersection here makes perfect sense because both of these constraints + // must hold for the whole new class. + if (Optional NewClassConstraint = + intersect(ValueFactory, RangeFactory, getConstraint(State, *this), + getConstraint(State, Other))) { + // NOTE: Essentially, NewClassConstraint should NEVER be infeasible, but due + // to imperfections in the solver it is possible; report such a state as + // infeasible (null) instead of storing an empty constraint for the class. + if (NewClassConstraint->isEmpty()) + return nullptr; + // No need in tracking constraints of a now-dissolved class. + Constraints = CRF.remove(Constraints, Other); + // Assign new constraints for this class. + Constraints = CRF.add(Constraints, *this, *NewClassConstraint); + + State = State->set(Constraints); + } + + return State; +} + +inline SymbolSet::Factory & +EquivalenceClass::getMembersFactory(ProgramStateRef State) { + return State->get_context(); +} + +SymbolSet EquivalenceClass::getClassMembers(ProgramStateRef State) { + if (const SymbolSet *Members = State->get(*this)) + return *Members; + + // This class is trivial, so we need to construct a set + // with just that one symbol from the class.
+ SymbolSet::Factory &F = getMembersFactory(State); + return F.add(F.getEmptySet(), getRepresentativeSymbol()); +} + +bool EquivalenceClass::isTrivial(ProgramStateRef State) { + return State->get(*this) == nullptr; +} + +bool EquivalenceClass::isTriviallyDead(ProgramStateRef State, + SymbolReaper &Reaper) { + return isTrivial(State) && Reaper.isDead(getRepresentativeSymbol()); +} + //===----------------------------------------------------------------------===// // RangeConstraintManager implementation //===----------------------------------------------------------------------===// @@ -1166,7 +1565,7 @@ ConditionTruthVal RangeConstraintManager::checkNull(ProgramStateRef State, SymbolRef Sym) { - const RangeSet *Ranges = State->get(Sym); + const RangeSet *Ranges = getConstraint(State, Sym); // If we don't have any information about this symbol, it's underconstrained. if (!Ranges) @@ -1190,7 +1589,7 @@ const llvm::APSInt *RangeConstraintManager::getSymVal(ProgramStateRef St, SymbolRef Sym) const { - const ConstraintRangeTy::data_type *T = St->get(Sym); + const RangeSet *T = getConstraint(St, Sym); return T ?
T->getConcreteValue() : nullptr; } @@ -1203,19 +1602,94 @@ ProgramStateRef RangeConstraintManager::removeDeadBindings(ProgramStateRef State, SymbolReaper &SymReaper) { - bool Changed = false; - ConstraintRangeTy CR = State->get(); - ConstraintRangeTy::Factory &CRFactory = State->get_context(); + ClassMembersTy ClassMembersMap = State->get(); + ClassMembersTy NewClassMembersMap = ClassMembersMap; + ClassMembersTy::Factory &EMFactory = State->get_context(); + SymbolSet::Factory &SetFactory = State->get_context(); + + ConstraintRangeTy Constraints = State->get(); + ConstraintRangeTy NewConstraints = Constraints; + ConstraintRangeTy::Factory &ConstraintFactory = + State->get_context(); + + ClassMapTy Map = State->get(); + ClassMapTy NewMap = Map; + ClassMapTy::Factory &ClassFactory = State->get_context(); + + bool ClassMapChanged = false; + bool MembersMapChanged = false; + bool ConstraintMapChanged = false; + + // 1. Let's see if dead symbols are trivial and have associated constraints. + for (std::pair ClassConstraintPair : + Constraints) { + EquivalenceClass Class = ClassConstraintPair.first; + if (Class.isTriviallyDead(State, SymReaper)) { + // Trivial and dead: drop its constraints from the copy, not the map we iterate. + NewConstraints = ConstraintFactory.remove(NewConstraints, Class); + ConstraintMapChanged = true; + } + } + + // 2. We don't need to track classes for dead symbols. + for (std::pair SymbolClassPair : Map) { + SymbolRef Sym = SymbolClassPair.first; - for (ConstraintRangeTy::iterator I = CR.begin(), E = CR.end(); I != E; ++I) { - SymbolRef Sym = I.getKey(); if (SymReaper.isDead(Sym)) { - Changed = true; - CR = CRFactory.remove(CR, Sym); + ClassMapChanged = true; + NewMap = ClassFactory.remove(NewMap, Sym); + } + } + + // 3. Remove dead members from classes and remove dead non-trivial classes + // and their constraints.
+ for (std::pair ClassMembersPair : + ClassMembersMap) { + SymbolSet LiveMembers = ClassMembersPair.second; + bool MembersChanged = false; + + for (SymbolRef Member : ClassMembersPair.second) { + if (SymReaper.isDead(Member)) { + MembersChanged = true; + LiveMembers = SetFactory.remove(LiveMembers, Member); + } + } + + // Check if the class changed. + if (!MembersChanged) + continue; + + MembersMapChanged = true; + + if (LiveMembers.isEmpty()) { + // The class is dead now, we need to wipe it out of the members map... + NewClassMembersMap = + EMFactory.remove(NewClassMembersMap, ClassMembersPair.first); + + // ...and remove all of its constraints. + NewConstraints = + ConstraintFactory.remove(NewConstraints, ClassMembersPair.first); + ConstraintMapChanged = true; + } else { + // We need to change the members associated with the class. + NewClassMembersMap = EMFactory.add(NewClassMembersMap, + ClassMembersPair.first, LiveMembers); + } + } - return Changed ? State->set(CR) : State; + // 4. Update the state with new maps. + // + // Here we try to be humble and update a map only if it really changed. + if (ClassMapChanged) + State = State->set(NewMap); + + if (MembersMapChanged) + State = State->set(NewClassMembersMap); + + if (ConstraintMapChanged) + State = State->set(NewConstraints); + + return State; } RangeSet RangeConstraintManager::getRange(ProgramStateRef State, @@ -1247,7 +1721,13 @@ llvm::APSInt Point = AdjustmentType.convert(Int) - Adjustment; RangeSet New = getRange(St, Sym).Delete(getBasicVals(), F, Point); - return New.isEmpty() ?
nullptr : St->set(Sym, New); + + if (New.isEmpty()) + // This is an infeasible assumption. + return nullptr; + + ProgramStateRef NewState = setConstraint(St, Sym, New); + return trackNE(NewState, Sym, Int, Adjustment); } ProgramStateRef @@ -1262,7 +1742,13 @@ // [Int-Adjustment, Int-Adjustment] llvm::APSInt AdjInt = AdjustmentType.convert(Int) - Adjustment; RangeSet New = getRange(St, Sym).Intersect(getBasicVals(), F, AdjInt, AdjInt); - return New.isEmpty() ? nullptr : St->set(Sym, New); + + if (New.isEmpty()) + // This is an infeasible assumption. + return nullptr; + + ProgramStateRef NewState = setConstraint(St, Sym, New); + return trackEQ(NewState, Sym, Int, Adjustment); } RangeSet RangeConstraintManager::getSymLTRange(ProgramStateRef St, @@ -1298,7 +1784,7 @@ const llvm::APSInt &Int, const llvm::APSInt &Adjustment) { RangeSet New = getSymLTRange(St, Sym, Int, Adjustment); - return New.isEmpty() ? nullptr : St->set(Sym, New); + return New.isEmpty() ? nullptr : setConstraint(St, Sym, New); } RangeSet RangeConstraintManager::getSymGTRange(ProgramStateRef St, @@ -1334,7 +1820,7 @@ const llvm::APSInt &Int, const llvm::APSInt &Adjustment) { RangeSet New = getSymGTRange(St, Sym, Int, Adjustment); - return New.isEmpty() ? nullptr : St->set(Sym, New); + return New.isEmpty() ? nullptr : setConstraint(St, Sym, New); } RangeSet RangeConstraintManager::getSymGERange(ProgramStateRef St, @@ -1370,13 +1856,13 @@ const llvm::APSInt &Int, const llvm::APSInt &Adjustment) { RangeSet New = getSymGERange(St, Sym, Int, Adjustment); - return New.isEmpty() ? nullptr : St->set(Sym, New); + return New.isEmpty() ? nullptr : setConstraint(St, Sym, New); } -RangeSet RangeConstraintManager::getSymLERange( - llvm::function_ref RS, - const llvm::APSInt &Int, - const llvm::APSInt &Adjustment) { +RangeSet +RangeConstraintManager::getSymLERange(llvm::function_ref RS, + const llvm::APSInt &Int, + const llvm::APSInt &Adjustment) { // Before we do any real work, see if the value can even show up.
APSIntType AdjustmentType(Adjustment); switch (AdjustmentType.testInRange(Int, true)) { @@ -1413,7 +1899,7 @@ const llvm::APSInt &Int, const llvm::APSInt &Adjustment) { RangeSet New = getSymLERange(St, Sym, Int, Adjustment); - return New.isEmpty() ? nullptr : St->set(Sym, New); + return New.isEmpty() ? nullptr : setConstraint(St, Sym, New); } ProgramStateRef RangeConstraintManager::assumeSymWithinInclusiveRange( @@ -1423,7 +1909,7 @@ if (New.isEmpty()) return nullptr; RangeSet Out = getSymLERange([&] { return New; }, To, Adjustment); - return Out.isEmpty() ? nullptr : State->set(Sym, Out); + return Out.isEmpty() ? nullptr : setConstraint(State, Sym, Out); } ProgramStateRef RangeConstraintManager::assumeSymOutsideInclusiveRange( @@ -1432,7 +1918,13 @@ RangeSet RangeLT = getSymLTRange(State, Sym, From, Adjustment); RangeSet RangeGT = getSymGTRange(State, Sym, To, Adjustment); RangeSet New(RangeLT.addRange(F, RangeGT)); - return New.isEmpty() ? nullptr : State->set(Sym, New); + return New.isEmpty() ? nullptr : setConstraint(State, Sym, New); +} + +ProgramStateRef RangeConstraintManager::trackEquality(ProgramStateRef State, + SymbolRef LHS, + SymbolRef RHS) { + return EquivalenceClass::merge(getBasicVals(), F, State, LHS, RHS); } //===----------------------------------------------------------------------===// @@ -1452,17 +1944,25 @@ ++Space; Out << '[' << NL; - for (ConstraintRangeTy::iterator I = Constraints.begin(); - I != Constraints.end(); ++I) { - Indent(Out, Space, IsDot) - << "{ \"symbol\": \"" << I.getKey() << "\", \"range\": \""; - I.getData().print(Out); - Out << "\" }"; - - if (std::next(I) != Constraints.end()) - Out << ','; - Out << NL; + bool First = true; + for (std::pair P : Constraints) { + SymbolSet ClassMembers = P.first.getClassMembers(State); + + // We can print the same constraint for every class member.
+ for (SymbolRef ClassMember : ClassMembers) { + if (First) { + First = false; + } else { + Out << ','; + Out << NL; + } + Indent(Out, Space, IsDot) + << "{ \"symbol\": \"" << ClassMember << "\", \"range\": \""; + P.second.print(Out); + Out << "\" }"; + } } + Out << NL; --Space; Indent(Out, Space, IsDot) << "]," << NL; diff --git a/clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp b/clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp --- a/clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp +++ b/clang/lib/StaticAnalyzer/Core/RangedConstraintManager.cpp @@ -40,19 +40,20 @@ } } else if (const SymSymExpr *SSE = dyn_cast(Sym)) { - // Translate "a != b" to "(b - a) != 0". - // We invert the order of the operands as a heuristic for how loop - // conditions are usually written ("begin != end") as compared to length - // calculations ("end - begin"). The more correct thing to do would be to - // canonicalize "a - b" and "b - a", which would allow us to treat - // "a != b" and "b != a" the same. - SymbolManager &SymMgr = getSymbolManager(); BinaryOperator::Opcode Op = SSE->getOpcode(); assert(BinaryOperator::isComparisonOp(Op)); - // For now, we only support comparing pointers. + // We convert equality operations for pointers only. if (Loc::isLocType(SSE->getLHS()->getType()) && Loc::isLocType(SSE->getRHS()->getType())) { + // Translate "a != b" to "(b - a) != 0". + // We invert the order of the operands as a heuristic for how loop + // conditions are usually written ("begin != end") as compared to length + // calculations ("end - begin"). The more correct thing to do would be to + // canonicalize "a - b" and "b - a", which would allow us to treat + // "a != b" and "b != a" the same.
+ + SymbolManager &SymMgr = getSymbolManager(); QualType DiffTy = SymMgr.getContext().getPointerDiffType(); SymbolRef Subtraction = SymMgr.getSymSymExpr(SSE->getRHS(), BO_Sub, SSE->getLHS(), DiffTy); @@ -63,6 +64,25 @@ Op = BinaryOperator::negateComparisonOp(Op); return assumeSymRel(State, Subtraction, Op, Zero); } + + if (BinaryOperator::isEqualityOp(Op)) { + SymbolManager &SymMgr = getSymbolManager(); + + QualType ExprType = SSE->getType(); + SymbolRef CanonicalEquality = + SymMgr.getSymSymExpr(SSE->getLHS(), BO_EQ, SSE->getRHS(), ExprType); + + bool WasEqual = SSE->getOpcode() == BO_EQ; + bool IsExpectedEqual = WasEqual == Assumption; + + const llvm::APSInt &Zero = getBasicVals().getValue(0, ExprType); + + if (IsExpectedEqual) { + return assumeSymNE(State, CanonicalEquality, Zero, Zero); + } + + return assumeSymEQ(State, CanonicalEquality, Zero, Zero); + } } // If we get here, there's nothing else we can do but treat the symbol as @@ -199,11 +219,6 @@ } } -void *ProgramStateTrait::GDMIndex() { - static int Index; - return &Index; -} - } // end of namespace ento } // end of namespace clang diff --git a/clang/test/Analysis/equality_tracking.c b/clang/test/Analysis/equality_tracking.c new file mode 100644 --- /dev/null +++ b/clang/test/Analysis/equality_tracking.c @@ -0,0 +1,132 @@ +// RUN: %clang_analyze_cc1 -verify %s \ +// RUN: -analyzer-checker=core,debug.ExprInspection \ +// RUN: -analyzer-config eagerly-assume=false + +#define NULL (void *)0 + +#define UCHAR_MAX (unsigned char)(~0U) +#define CHAR_MAX (char)(UCHAR_MAX & (UCHAR_MAX >> 1)) +#define CHAR_MIN (char)(UCHAR_MAX & ~(UCHAR_MAX >> 1)) + +void clang_analyzer_eval(int); + +int getInt(); + +void zeroImpliesEquality(int a, int b) { + clang_analyzer_eval((a - b) == 0); // expected-warning{{UNKNOWN}} + if ((a - b) == 0) { + clang_analyzer_eval(b != a); // expected-warning{{FALSE}} + clang_analyzer_eval(b == a); // expected-warning{{TRUE}} + clang_analyzer_eval(!(a != b)); // expected-warning{{TRUE}}
clang_analyzer_eval(!(b == a)); // expected-warning{{FALSE}} + return; + } + clang_analyzer_eval((a - b) == 0); // expected-warning{{FALSE}} + // FIXME: we should track disequality information as well + clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}} +} + +void zeroImpliesReversedEqual(int a, int b) { + clang_analyzer_eval((b - a) == 0); // expected-warning{{UNKNOWN}} + if ((b - a) == 0) { + clang_analyzer_eval(b != a); // expected-warning{{FALSE}} + clang_analyzer_eval(b == a); // expected-warning{{TRUE}} + return; + } + clang_analyzer_eval((b - a) == 0); // expected-warning{{FALSE}} + // FIXME: we should track disequality information as well + clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}} +} + +void canonicalEqual(int a, int b) { + clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}} + if (a == b) { + clang_analyzer_eval(b == a); // expected-warning{{TRUE}} + return; + } + clang_analyzer_eval(a == b); // expected-warning{{FALSE}} + clang_analyzer_eval(b == a); // expected-warning{{FALSE}} +} + +void test(int a, int b, int c, int d) { + if (a == b && c == d) { + if (a == 0 && b == d) { + clang_analyzer_eval(c == 0); // expected-warning{{TRUE}} + } + c = 10; + if (b == d) { + clang_analyzer_eval(c == 10); // expected-warning{{TRUE}} + clang_analyzer_eval(d == 10); // expected-warning{{UNKNOWN}} + // expected-warning@-1{{FALSE}} + clang_analyzer_eval(b == a); // expected-warning{{TRUE}} + clang_analyzer_eval(a == d); // expected-warning{{TRUE}} + + b = getInt(); + clang_analyzer_eval(a == d); // expected-warning{{TRUE}} + clang_analyzer_eval(a == b); // expected-warning{{UNKNOWN}} + } + } + + if (a != b && b == c) { + if (c == 42) { + clang_analyzer_eval(b == 42); // expected-warning{{TRUE}} + // FIXME: we should track disequality information as well + clang_analyzer_eval(a != 42); // expected-warning{{UNKNOWN}}
+ } + } +} + +void testIntersection(int a, int b, int c) { + if (a < 42 && b > 15 && c >= 25 && c <= 30) { + if (a != b) + return; + + clang_analyzer_eval(a > 15); // expected-warning{{TRUE}} + clang_analyzer_eval(b < 42); // expected-warning{{TRUE}} + clang_analyzer_eval(a <= 30); // expected-warning{{UNKNOWN}} + + if (c == b) { + // For all equal symbols, we should track the minimal common range. + // + // Also, it should be noted that c is dead at this point, but the + // constraint initially associated with c is still around. + clang_analyzer_eval(a >= 25 && a <= 30); // expected-warning{{TRUE}} + clang_analyzer_eval(b >= 25 && b <= 30); // expected-warning{{TRUE}} + } + } +} + +void testPromotion(int a, char b) { + if (b > 10) { + if (a == b) { + clang_analyzer_eval(a > 10); // expected-warning{{TRUE}} + clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{TRUE}} + } + } +} + +void testPromotionOnlyTypes(int a, char b) { + if (a == b) { + // FIXME: even when b doesn't have any constraints we still + // should understand that b has a smaller type and assign + // constraints correspondingly + clang_analyzer_eval(a <= CHAR_MAX); // expected-warning{{UNKNOWN}} + } +} + +void testPointers(int *a, int *b, int *c, int *d) { + if (a == b && c == d) { + if (a == NULL && b == d) { + clang_analyzer_eval(c == NULL); // expected-warning{{TRUE}} + } + } + + if (a != b && b == c) { + if (c == NULL) { + // FIXME: we should track disequality information as well + clang_analyzer_eval(a != NULL); // expected-warning{{UNKNOWN}} + } + } +}