diff --git a/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp b/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp --- a/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp +++ b/clang/lib/StaticAnalyzer/Core/RangeConstraintManager.cpp @@ -401,6 +401,9 @@ REGISTER_MAP_WITH_PROGRAMSTATE(ClassMembers, EquivalenceClass, SymbolSet) REGISTER_MAP_WITH_PROGRAMSTATE(ConstraintRange, EquivalenceClass, RangeSet) +REGISTER_SET_FACTORY_WITH_PROGRAMSTATE(ClassSet, EquivalenceClass) +REGISTER_MAP_WITH_PROGRAMSTATE(DisequalityMap, EquivalenceClass, ClassSet) + namespace { /// This class encapsulates a set of symbols equal to each other. /// @@ -450,6 +453,24 @@ LLVM_NODISCARD inline bool isTriviallyDead(ProgramStateRef State, SymbolReaper &Reaper); + LLVM_NODISCARD static inline ProgramStateRef + markDisequal(BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, SymbolRef First, SymbolRef Second); + LLVM_NODISCARD static inline ProgramStateRef + markDisequal(BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, EquivalenceClass First, + EquivalenceClass Second); + LLVM_NODISCARD inline ProgramStateRef + markDisequal(BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, EquivalenceClass Other) const; + LLVM_NODISCARD static inline ClassSet + getDisequalClasses(ProgramStateRef State, SymbolRef Sym); + LLVM_NODISCARD inline ClassSet + getDisequalClasses(ProgramStateRef State) const; + + LLVM_NODISCARD static inline Optional + areEqual(ProgramStateRef State, SymbolRef First, SymbolRef Second); + /// Check equivalence data for consistency. 
LLVM_NODISCARD LLVM_ATTRIBUTE_UNUSED static bool isClassDataConsistent(ProgramStateRef State); @@ -496,6 +517,11 @@ ProgramStateRef State, SymbolSet Members, EquivalenceClass Other, SymbolSet OtherMembers); + static inline void + addToDisequalityInfo(DisequalityMapTy &Info, ConstraintRangeTy &Constraints, + BasicValueFactory &BV, RangeSet::Factory &F, + ProgramStateRef State, EquivalenceClass First, + EquivalenceClass Second); /// This is a unique identifier of the class. uintptr_t ID; @@ -510,17 +536,6 @@ // Constraint functions //===----------------------------------------------------------------------===// -LLVM_NODISCARD inline ProgramStateRef setConstraint(ProgramStateRef State, - EquivalenceClass Class, - RangeSet Constraint) { - return State->set(Class, Constraint); -} - -LLVM_NODISCARD inline ProgramStateRef -setConstraint(ProgramStateRef State, SymbolRef Sym, RangeSet Constraint) { - return setConstraint(State, EquivalenceClass::find(State, Sym), Constraint); -} - LLVM_NODISCARD inline const RangeSet *getConstraint(ProgramStateRef State, EquivalenceClass Class) { return State->get(Class); @@ -697,10 +712,11 @@ class SymbolicRangeInferrer : public SymExprVisitor { public: + template static RangeSet inferRange(BasicValueFactory &BV, RangeSet::Factory &F, - ProgramStateRef State, SymbolRef Sym) { + ProgramStateRef State, SourceType Origin) { SymbolicRangeInferrer Inferrer(BV, F, State); - return Inferrer.infer(Sym); + return Inferrer.infer(Origin); } RangeSet VisitSymExpr(SymbolRef Sym) { @@ -750,6 +766,8 @@ } RangeSet infer(SymbolRef Sym) { + RangeSet Result = RangeFactory.getEmptySet(); + if (Optional ConstraintBasedRange = intersect( ValueFactory, RangeFactory, getConstraint(State, Sym), // If Sym is a difference of symbols A - B, then maybe we have range @@ -759,16 +777,26 @@ // calculate the effective range set by intersecting the range set // for A - B and the negated range set of B - A. 
getRangeForInvertedSub(Sym), getRangeForEqualities(Sym))) { - return *ConstraintBasedRange; + Result = *ConstraintBasedRange; } - // If Sym is a comparison expression (except <=>), // find any other comparisons with the same operands. // See function description. - if (Optional CmpRangeSet = getRangeForComparisonSymbol(Sym)) - return *CmpRangeSet; + else if (Optional CmpRangeSet = + getRangeForComparisonSymbol(Sym)) { + Result = *CmpRangeSet; + } else { + Result = Visit(Sym); + } + + return Result; + } - return Visit(Sym); + RangeSet infer(EquivalenceClass Class) { + if (const RangeSet *AssociatedConstraint = getConstraint(State, Class)) + return *AssociatedConstraint; + + return infer(Class.getType()); } /// Infer range information solely from the type. @@ -1039,19 +1067,15 @@ if (!Equality) return llvm::None; - EquivalenceClass LHS = EquivalenceClass::find(State, Equality->Left); - EquivalenceClass RHS = EquivalenceClass::find(State, Equality->Right); - - if (LHS != RHS) - // Can't really say anything at this point. - // We can add more logic here if we track disequalities as well. - return llvm::None; - - // At this point, operands of the equality operation are known to be equal. 
- if (Equality->IsEquality) { - return getTrueRange(Sym->getType()); + if (Optional AreEqual = EquivalenceClass::areEqual( + State, Equality->Left, Equality->Right)) { + if (*AreEqual == Equality->IsEquality) { + return getTrueRange(Sym->getType()); + } + return getFalseRange(Sym->getType()); } - return getFalseRange(Sym->getType()); + + return llvm::None; } RangeSet getTrueRange(QualType T) { @@ -1308,6 +1332,7 @@ RangeSet::Factory F; RangeSet getRange(ProgramStateRef State, SymbolRef Sym); + RangeSet getRange(ProgramStateRef State, EquivalenceClass Class); RangeSet getSymLTRange(ProgramStateRef St, SymbolRef Sym, const llvm::APSInt &Int, @@ -1361,12 +1386,42 @@ ProgramStateRef trackDisequality(ProgramStateRef State, SymbolRef LHS, SymbolRef RHS) { - // TODO: track inequalities - return State; + return EquivalenceClass::markDisequal(getBasicVals(), F, State, LHS, RHS); } ProgramStateRef trackEquality(ProgramStateRef State, SymbolRef LHS, - SymbolRef RHS); + SymbolRef RHS) { + return EquivalenceClass::merge(getBasicVals(), F, State, LHS, RHS); + } + + LLVM_NODISCARD inline ProgramStateRef setConstraint(ProgramStateRef State, + EquivalenceClass Class, + RangeSet Constraint) { + ConstraintRangeTy Constraints = State->get(); + ConstraintRangeTy::Factory &CF = State->get_context(); + + // Add new constraint. + Constraints = CF.add(Constraints, Class, Constraint); + + // There is a chance that we might need to update constraints for the + // classes that are known to be disequal to Class. + // + // In order for this to be even possible, the new constraint should + // be simply a constant because we can't reason about range disequalities. 
+ if (const llvm::APSInt *Point = Constraint.getConcreteValue()) + for (EquivalenceClass DisequalClass : Class.getDisequalClasses(State)) { + RangeSet UpdatedConstraint = + getRange(State, DisequalClass).Delete(getBasicVals(), F, *Point); + Constraints = CF.add(Constraints, DisequalClass, UpdatedConstraint); + } + + return State->set(Constraints); + } + + LLVM_NODISCARD inline ProgramStateRef + setConstraint(ProgramStateRef State, SymbolRef Sym, RangeSet Constraint) { + return setConstraint(State, EquivalenceClass::find(State, Sym), Constraint); + } }; } // end anonymous namespace @@ -1501,11 +1556,15 @@ // 2. Get ALL equivalence-related maps ClassMapTy Classes = State->get(); - ClassMapTy::Factory &CF = State->get_context(); + ClassMapTy::Factory &CMF = State->get_context(); ClassMembersTy Members = State->get(); ClassMembersTy::Factory &MF = State->get_context(); + DisequalityMapTy DisequalityInfo = State->get(); + DisequalityMapTy::Factory &DF = State->get_context(); + + ClassSet::Factory &CF = State->get_context(); SymbolSet::Factory &F = getMembersFactory(State); // 2. Merge members of the Other class into the current class. @@ -1513,7 +1572,7 @@ for (SymbolRef Sym : OtherMembers) { NewClassMembers = F.add(NewClassMembers, Sym); // *this is now the class for all these new symbols. - Classes = CF.add(Classes, Sym, *this); + Classes = CMF.add(Classes, Sym, *this); } // 3. Adjust member mapping. @@ -1523,7 +1582,33 @@ // Now only the current class is mapped to all the symbols. Members = MF.add(Members, *this, NewClassMembers); - // 4. Update the state + // 4. 
Update disequality relations + if (const ClassSet *DisequalToOther = DisequalityInfo.lookup(Other)) { + ClassSet DisequalToThis = getDisequalClasses(State); + DisequalityInfo = DF.remove(DisequalityInfo, Other); + + for (EquivalenceClass DisequalClass : *DisequalToOther) { + DisequalToThis = CF.add(DisequalToThis, DisequalClass); + + // Disequality is a symmetric relation meaning that if + // DisequalToOther is not null then the set for DisequalClass is not + // empty and has at least Other. + ClassSet OriginalSetLinkedToOther = + *DisequalityInfo.lookup(DisequalClass); + + // Other will be eliminated and we should replace it with the bigger + // united class. + ClassSet NewSet = CF.remove(OriginalSetLinkedToOther, Other); + NewSet = CF.add(NewSet, *this); + + DisequalityInfo = DF.add(DisequalityInfo, DisequalClass, NewSet); + } + + DisequalityInfo = DF.add(DisequalityInfo, *this, DisequalToThis); + State = State->set(DisequalityInfo); + } + + // 5. Update the state State = State->set(Classes); State = State->set(Members); @@ -1554,6 +1639,114 @@ return isTrivial(State) && Reaper.isDead(getRepresentativeSymbol()); } +inline ProgramStateRef EquivalenceClass::markDisequal(BasicValueFactory &VF, + RangeSet::Factory &RF, + ProgramStateRef State, + SymbolRef First, + SymbolRef Second) { + return markDisequal(VF, RF, State, find(State, First), find(State, Second)); +} + +inline ProgramStateRef EquivalenceClass::markDisequal(BasicValueFactory &VF, + RangeSet::Factory &RF, + ProgramStateRef State, + EquivalenceClass First, + EquivalenceClass Second) { + return First.markDisequal(VF, RF, State, Second); +} + +inline ProgramStateRef +EquivalenceClass::markDisequal(BasicValueFactory &VF, RangeSet::Factory &RF, + ProgramStateRef State, + EquivalenceClass Other) const { + // If we know that two classes are equal, we can only produce an infeasible + // state. 
+ if (*this == Other) { + return nullptr; + } + + DisequalityMapTy DisequalityInfo = State->get(); + ConstraintRangeTy Constraints = State->get(); + + // Disequality is a symmetric relation, so if we mark A as disequal to B, + // we should also mark B as disequal to A. + addToDisequalityInfo(DisequalityInfo, Constraints, VF, RF, State, *this, + Other); + addToDisequalityInfo(DisequalityInfo, Constraints, VF, RF, State, Other, + *this); + + State = State->set(DisequalityInfo); + State = State->set(Constraints); + + return State; +} + +inline void EquivalenceClass::addToDisequalityInfo( + DisequalityMapTy &Info, ConstraintRangeTy &Constraints, + BasicValueFactory &VF, RangeSet::Factory &RF, ProgramStateRef State, + EquivalenceClass First, EquivalenceClass Second) { + + // 1. Get all of the required factories. + DisequalityMapTy::Factory &F = State->get_context(); + ClassSet::Factory &CF = State->get_context(); + ConstraintRangeTy::Factory &CRF = State->get_context(); + + // 2. Add Second to the set of classes disequal to First. + const ClassSet *CurrentSet = Info.lookup(First); + ClassSet NewSet = CurrentSet ? *CurrentSet : CF.getEmptySet(); + NewSet = CF.add(NewSet, Second); + + Info = F.add(Info, First, NewSet); + + // 3. If Second is known to be a constant, we can delete this point + // from the constraint associated with First. + // + // So, if Second == 10, it means that First != 10. + // At the same time, the same logic does not apply to ranges. 
+ if (const RangeSet *SecondConstraint = Constraints.lookup(Second)) + if (const llvm::APSInt *Point = SecondConstraint->getConcreteValue()) { + + RangeSet FirstConstraint = SymbolicRangeInferrer::inferRange( + VF, RF, State, First.getRepresentativeSymbol()); + + FirstConstraint = FirstConstraint.Delete(VF, RF, *Point); + Constraints = CRF.add(Constraints, First, FirstConstraint); + } +} + +inline Optional EquivalenceClass::areEqual(ProgramStateRef State, + SymbolRef FirstSym, + SymbolRef SecondSym) { + EquivalenceClass First = find(State, FirstSym); + EquivalenceClass Second = find(State, SecondSym); + + // The same equivalence class => symbols are equal. + if (First == Second) + return true; + + // Let's check if we know anything about these two classes being not equal + // to each other. + ClassSet DisequalToFirst = First.getDisequalClasses(State); + if (DisequalToFirst.contains(Second)) + return false; + + // It is not clear. + return llvm::None; +} + +inline ClassSet EquivalenceClass::getDisequalClasses(ProgramStateRef State, + SymbolRef Sym) { + return find(State, Sym).getDisequalClasses(State); +} + +inline ClassSet +EquivalenceClass::getDisequalClasses(ProgramStateRef State) const { + if (const ClassSet *DisequalClasses = State->get(*this)) + return *DisequalClasses; + + return State->get_context().getEmptySet(); +} + bool EquivalenceClass::isClassDataConsistent(ProgramStateRef State) { ClassMembersTy Members = State->get(); @@ -1568,6 +1761,28 @@ } } + DisequalityMapTy Disequalities = State->get(); + for (std::pair DisequalityInfo : Disequalities) { + EquivalenceClass Class = DisequalityInfo.first; + ClassSet DisequalClasses = DisequalityInfo.second; + + // There is no use in keeping empty sets in the map. + if (DisequalClasses.isEmpty()) + return false; + + // Disequality is symmetrical, i.e. for every Class A and B that A != B, + // B != A should also be true. 
+ for (EquivalenceClass DisequalClass : DisequalClasses) { + const ClassSet *DisequalToDisequalClasses = + Disequalities.lookup(DisequalClass); + + // It should be a set of at least one element: Class + if (!DisequalToDisequalClasses || + !DisequalToDisequalClasses->contains(Class)) + return false; + } + } + return true; } @@ -1674,9 +1889,45 @@ ClassMapTy NewMap = Map; ClassMapTy::Factory &ClassFactory = State->get_context(); + DisequalityMapTy Disequalities = State->get(); + DisequalityMapTy::Factory &DisequalityFactory = + State->get_context(); + ClassSet::Factory &ClassSetFactory = State->get_context(); + bool ClassMapChanged = false; bool MembersMapChanged = false; bool ConstraintMapChanged = false; + bool DisequalitiesChanged = false; + + auto removeDeadClass = [&](EquivalenceClass Class) { + // Remove associated constraint ranges. + Constraints = ConstraintFactory.remove(Constraints, Class); + ConstraintMapChanged = true; + + // Update disequality information to not hold any information on the + // removed class. + if (const ClassSet *DisequalClasses = Disequalities.lookup(Class)) { + for (EquivalenceClass DisequalClass : *DisequalClasses) { + const ClassSet *DisequalToDisequalSet = + Disequalities.lookup(DisequalClass); + // DisequalToDisequalSet is guaranteed to be non-null for consistent + // disequality info. + ClassSet NewSet = ClassSetFactory.remove(*DisequalToDisequalSet, Class); + + // No need in keeping an empty set. + if (NewSet.isEmpty()) { + Disequalities = + DisequalityFactory.remove(Disequalities, DisequalClass); + } else { + Disequalities = + DisequalityFactory.add(Disequalities, DisequalClass, NewSet); + } + } + // Remove the data for the class + Disequalities = DisequalityFactory.remove(Disequalities, Class); + DisequalitiesChanged = true; + } + }; // 1. Let's see if dead symbols are trivial and have associated constraints. 
for (std::pair ClassConstraintPair : @@ -1684,8 +1935,7 @@ EquivalenceClass Class = ClassConstraintPair.first; if (Class.isTriviallyDead(State, SymReaper)) { // If this class is trivial, we can remove its constraints right away. - Constraints = ConstraintFactory.remove(Constraints, Class); - ConstraintMapChanged = true; + removeDeadClass(Class); } } @@ -1703,6 +1953,7 @@ // and their constraints. for (std::pair ClassMembersPair : ClassMembersMap) { + EquivalenceClass Class = ClassMembersPair.first; SymbolSet LiveMembers = ClassMembersPair.second; bool MembersChanged = false; @@ -1721,17 +1972,14 @@ if (LiveMembers.isEmpty()) { // The class is dead now, we need to wipe it out of the members map... - NewClassMembersMap = - EMFactory.remove(NewClassMembersMap, ClassMembersPair.first); + NewClassMembersMap = EMFactory.remove(NewClassMembersMap, Class); // ...and remove all of its constraints. - Constraints = - ConstraintFactory.remove(Constraints, ClassMembersPair.first); - ConstraintMapChanged = true; + removeDeadClass(Class); } else { // We need to change the members associated with the class. - NewClassMembersMap = EMFactory.add(NewClassMembersMap, - ClassMembersPair.first, LiveMembers); + NewClassMembersMap = + EMFactory.add(NewClassMembersMap, Class, LiveMembers); } } @@ -1747,6 +1995,9 @@ if (ConstraintMapChanged) State = State->set(Constraints); + if (DisequalitiesChanged) + State = State->set(Disequalities); + assert(EquivalenceClass::isClassDataConsistent(State)); return State; @@ -1757,6 +2008,11 @@ return SymbolicRangeInferrer::inferRange(getBasicVals(), F, State, Sym); } +RangeSet RangeConstraintManager::getRange(ProgramStateRef State, + EquivalenceClass Class) { + return SymbolicRangeInferrer::inferRange(getBasicVals(), F, State, Class); +} + //===------------------------------------------------------------------------=== // assumeSymX methods: protected interface for RangeConstraintManager. 
//===------------------------------------------------------------------------===/ @@ -1981,12 +2237,6 @@ return New.isEmpty() ? nullptr : setConstraint(State, Sym, New); } -ProgramStateRef RangeConstraintManager::trackEquality(ProgramStateRef State, - SymbolRef LHS, - SymbolRef RHS) { - return EquivalenceClass::merge(getBasicVals(), F, State, LHS, RHS); -} - //===----------------------------------------------------------------------===// // Pretty-printing. //===----------------------------------------------------------------------===// diff --git a/clang/test/Analysis/equality_tracking.c b/clang/test/Analysis/equality_tracking.c --- a/clang/test/Analysis/equality_tracking.c +++ b/clang/test/Analysis/equality_tracking.c @@ -23,9 +23,8 @@ return; } clang_analyzer_eval((a - b) == 0); // expected-warning{{FALSE}} - // FIXME: we should track disequality information as well - clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}} - clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(b == a); // expected-warning{{FALSE}} + clang_analyzer_eval(b != a); // expected-warning{{TRUE}} } void zeroImpliesReversedEqual(int a, int b) { @@ -36,9 +35,8 @@ return; } clang_analyzer_eval((b - a) == 0); // expected-warning{{FALSE}} - // FIXME: we should track disequality information as well - clang_analyzer_eval(b == a); // expected-warning{{UNKNOWN}} - clang_analyzer_eval(b != a); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(b == a); // expected-warning{{FALSE}} + clang_analyzer_eval(b != a); // expected-warning{{TRUE}} } void canonicalEqual(int a, int b) { @@ -73,8 +71,7 @@ if (a != b && b == c) { if (c == 42) { clang_analyzer_eval(b == 42); // expected-warning{{TRUE}} - // FIXME: we should track disequality information as well - clang_analyzer_eval(a != 42); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(a != 42); // expected-warning{{TRUE}} } } } @@ -144,8 +141,31 @@ if (a != b && b == c) { if (c == NULL) { - // FIXME: we should track 
disequality information as well - clang_analyzer_eval(a != NULL); // expected-warning{{UNKNOWN}} + clang_analyzer_eval(a != NULL); // expected-warning{{TRUE}} + } + } +} + +void testDisequalitiesAfter(int a, int b, int c) { + if (a >= 10 && b <= 42) { + if (a == b && c == 15 && c != a) { + clang_analyzer_eval(b != c); // expected-warning{{TRUE}} + clang_analyzer_eval(a != 15); // expected-warning{{TRUE}} + clang_analyzer_eval(b != 15); // expected-warning{{TRUE}} + clang_analyzer_eval(b >= 10); // expected-warning{{TRUE}} + clang_analyzer_eval(a <= 42); // expected-warning{{TRUE}} + } + } +} + +void testDisequalitiesBefore(int a, int b, int c) { + if (a >= 10 && b <= 42 && c == 15) { + if (a == b && c != a) { + clang_analyzer_eval(b != c); // expected-warning{{TRUE}} + clang_analyzer_eval(a != 15); // expected-warning{{TRUE}} + clang_analyzer_eval(b != 15); // expected-warning{{TRUE}} + clang_analyzer_eval(b >= 10); // expected-warning{{TRUE}} + clang_analyzer_eval(a <= 42); // expected-warning{{TRUE}} } } } diff --git a/clang/test/Analysis/mutually_exclusive_null_fp.cpp b/clang/test/Analysis/mutually_exclusive_null_fp.cpp new file mode 100644 --- /dev/null +++ b/clang/test/Analysis/mutually_exclusive_null_fp.cpp @@ -0,0 +1,26 @@ +// RUN: %clang_analyze_cc1 -analyzer-checker=core -verify %s + +// rdar://problem/56586853 +// expected-no-diagnostics + +struct Data { + int x; + Data *data; +}; + +int compare(Data &a, Data &b) { + Data *aData = a.data; + Data *bData = b.data; + + // Covers the cases where both pointers are null as well as both pointing to the same buffer. + if (aData == bData) + return 0; + + if (aData && !bData) + return 1; + + if (!aData && bData) + return -1; + + return compare(*aData, *bData); +}