Taint sub-regions of lazy compound values through derived symbols.

Summary of the hunks below:
  * ProgramState.h / TaintManager.h: add the TaintedSymRegions set types, a
    per-state TSRFactory, and the DerivedSymTaint GDM map from a derived
    symbol's parent symbol to the set of regions tainted under that parent.
  * GenericTaintChecker.cpp: getLCVSymbol() now returns a derived symbol for
    an LCV that covers only part of its base region instead of nullptr.
  * ProgramState.cpp: tainting a SymbolDerived also records its region in
    DerivedSymTaint; the taint query treats a SymbolDerived as tainted when
    its region equals, or is a sub-region of, a recorded region.
  * RegionStore.cpp: getDefaultBinding(Store, R) looks up R->getBaseRegion().
  * taint-generic.c: testStructArray now expects warnings; testUnion added.

NOTE(review): this patch text was recovered from a copy whose line breaks,
indentation and template argument lists had been stripped. Line structure was
rebuilt to match every hunk header's line counts. Template arguments on
context lines (StoreManager, SVal, TaintMap) are presumed from surrounding
usage -- confirm against the target tree. Intra-line spacing of context lines
is normalized, so apply with whitespace-insensitive matching (e.g. patch -l).

Index: include/clang/StaticAnalyzer/Core/PathSensitive/ProgramState.h
===================================================================
--- include/clang/StaticAnalyzer/Core/PathSensitive/ProgramState.h
+++ include/clang/StaticAnalyzer/Core/PathSensitive/ProgramState.h
@@ -43,6 +43,8 @@
     ProgramStateManager &, SubEngine *);
 typedef std::unique_ptr<StoreManager>(*StoreManagerCreator)(
     ProgramStateManager &);
+typedef llvm::ImmutableSet<const TypedValueRegion *> TaintedSymRegions;
+typedef llvm::ImmutableSetRef<const TypedValueRegion *> TaintedSymRegionsRef;
 
 //===----------------------------------------------------------------------===//
 // ProgramStateTrait - Traits used by the Generic Data Map of a ProgramState.
@@ -87,6 +89,7 @@
   Store store; // Maps a location to its current value.
   GenericDataMap GDM; // Custom data stored by a client of this class.
   unsigned refCount;
+  TaintedSymRegions::Factory TSRFactory;
 
   /// makeWithStore - Return a ProgramState with the same values as the current
   /// state with the exception of using the specified Store.
Index: include/clang/StaticAnalyzer/Core/PathSensitive/TaintManager.h
===================================================================
--- include/clang/StaticAnalyzer/Core/PathSensitive/TaintManager.h
+++ include/clang/StaticAnalyzer/Core/PathSensitive/TaintManager.h
@@ -35,6 +35,16 @@
   static void *GDMIndex() { static int index = 0; return &index; }
 };
 
+/// The GDM component mapping derived symbols' parent symbols to their
+/// underlying regions. This is used to efficiently check whether a symbol is
+/// tainted when it represents a sub-region of a tainted symbol.
+struct DerivedSymTaint {};
+typedef llvm::ImmutableMap<SymbolRef, TaintedSymRegionsRef> DerivedSymTaintImpl;
+template<> struct ProgramStateTrait<DerivedSymTaint>
+    : public ProgramStatePartialTrait<DerivedSymTaintImpl> {
+  static void *GDMIndex() { static int index = 0; return &index; }
+};
+
 class TaintManager {
 
   TaintManager() {}
Index: lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
===================================================================
--- lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
+++ lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp
@@ -72,8 +72,6 @@
   /// covers the entire region, e.g. we avoid false positives by not returning
   /// a default bindingc for an entire struct if the symbol for only a single
   /// field or element within it is requested.
-  // TODO: Return an appropriate symbol for sub-fields/elements of an LCV so
-  // that they are also appropriately tainted.
   static SymbolRef getLCVSymbol(CheckerContext &C,
                                 nonloc::LazyCompoundVal &LCV);
 
@@ -479,19 +477,21 @@
 
   // getLCVSymbol() is reached in a PostStmt so we can always expect a default
   // binding to exist if one is present.
-  if (Optional<SVal> binding = StoreMgr.getDefaultBinding(LCV)) {
-    SymbolRef Sym = binding->getAsSymbol();
-    if (!Sym)
-      return nullptr;
-
-    // If the LCV covers an entire base region return the default conjured symbol.
-    if (LCV.getRegion() == LCV.getRegion()->getBaseRegion())
-      return Sym;
-  }
+  Optional<SVal> binding = StoreMgr.getDefaultBinding(LCV);
+  if (!binding)
+    return nullptr;
 
-  // Otherwise, return a nullptr as there's not yet a functional way to taint
-  // sub-regions of LCVs.
-  return nullptr;
+  SymbolRef Sym = binding->getAsSymbol();
+  if (!Sym)
+    return nullptr;
+
+  // If the LCV covers an entire base region return the default conjured symbol.
+  if (LCV.getRegion() == LCV.getRegion()->getBaseRegion())
+    return Sym;
+
+  // Otherwise, return a derived symbol indicating only a sub-region is tainted
+  SymbolManager &SM = C.getSymbolManager();
+  return SM.getDerivedSymbol(Sym, LCV.getRegion());
 }
 
 SymbolRef GenericTaintChecker::getPointedToSymbol(CheckerContext &C,
Index: lib/StaticAnalyzer/Core/ProgramState.cpp
===================================================================
--- lib/StaticAnalyzer/Core/ProgramState.cpp
+++ lib/StaticAnalyzer/Core/ProgramState.cpp
@@ -671,6 +671,19 @@
 
   ProgramStateRef NewState = set<TaintMap>(Sym, Kind);
   assert(NewState);
+
+  if (const SymbolDerived *SD = dyn_cast<SymbolDerived>(Sym)) {
+    TaintedSymRegionsRef SymRegions(0, TSRFactory.getTreeFactory());
+
+    if (const TaintedSymRegionsRef *SavedRegions =
+            get<DerivedSymTaint>(SD->getParentSymbol()))
+      SymRegions = *SavedRegions;
+
+    SymRegions = SymRegions.add(SD->getRegion());
+    NewState = NewState->set<DerivedSymTaint>(SD->getParentSymbol(), SymRegions);
+    assert(NewState);
+  }
+
   return NewState;
 }
 
@@ -723,15 +736,34 @@
     const TaintTagType *Tag = get<TaintMap>(*SI);
     Tainted = (Tag && *Tag == Kind);
 
-    // If this is a SymbolDerived with a tainted parent, it's also tainted.
-    if (const SymbolDerived *SD = dyn_cast<SymbolDerived>(*SI))
+    if (const SymbolDerived *SD = dyn_cast<SymbolDerived>(*SI)) {
+      // If this is a SymbolDerived with a tainted parent, it's also tainted.
       Tainted = Tainted || isTainted(SD->getParentSymbol(), Kind);
 
+      // If this is a SymbolDerived with the same parent symbol as another
+      // tainted SymbolDerived and a region that's a sub-region of that tainted
+      // symbol, it's also tainted.
+      if (const TaintedSymRegionsRef *SymRegions =
+              get<DerivedSymTaint>(SD->getParentSymbol())) {
+        const TypedValueRegion *R = SD->getRegion();
+        for (TaintedSymRegionsRef::iterator I = SymRegions->begin(),
+                                            E = SymRegions->end();
+             I != E; ++I) {
+          // FIXME: The logic to identify tainted regions could be more
+          // complete. For example, this would not currently identify
+          // overlapping fields in a union as tainted. To identify this we can
+          // check for overlapping/nested byte offsets.
+          if (R == *I || R->isSubRegionOf(*I))
+            return true;
+        }
+      }
+    }
+
     // If memory region is tainted, data is also tainted.
     if (const SymbolRegionValue *SRV = dyn_cast<SymbolRegionValue>(*SI))
       Tainted = Tainted || isTainted(SRV->getRegion(), Kind);
 
-    // If If this is a SymbolCast from a tainted value, it's also tainted.
+    // If this is a SymbolCast from a tainted value, it's also tainted.
     if (const SymbolCast *SC = dyn_cast<SymbolCast>(*SI))
       Tainted = Tainted || isTainted(SC->getOperand(), Kind);
 
Index: lib/StaticAnalyzer/Core/RegionStore.cpp
===================================================================
--- lib/StaticAnalyzer/Core/RegionStore.cpp
+++ lib/StaticAnalyzer/Core/RegionStore.cpp
@@ -496,7 +496,10 @@
 
   Optional<SVal> getDefaultBinding(Store S, const MemRegion *R) override {
     RegionBindingsRef B = getRegionBindings(S);
-    return B.getDefaultBinding(R);
+    // Default bindings are always applied over a base region so look up the
+    // base region's default binding, otherwise the lookup will fail when R
+    // is at an offset from R->getBaseRegion().
+    return B.getDefaultBinding(R->getBaseRegion());
   }
 
   SVal getBinding(RegionBindingsConstRef B, Loc L, QualType T = QualType());
Index: test/Analysis/taint-generic.c
===================================================================
--- test/Analysis/taint-generic.c
+++ test/Analysis/taint-generic.c
@@ -192,20 +192,41 @@
 
 void testStructArray() {
   struct {
-    char buf[16];
-    struct {
-      int length;
-    } st[1];
-  } tainted;
+    int length;
+  } tainted[4];
 
-  char buffer[16];
+  char dstbuf[16], srcbuf[16];
   int sock;
 
   sock = socket(AF_INET, SOCK_STREAM, 0);
-  read(sock, &tainted.buf[0], sizeof(tainted.buf));
-  read(sock, &tainted.st[0], sizeof(tainted.st));
-  // FIXME: tainted.st[0].length should be marked tainted
-  __builtin_memcpy(buffer, tainted.buf, tainted.st[0].length); // no-warning
+  __builtin_memset(srcbuf, 0, sizeof(srcbuf));
+
+  read(sock, &tainted[0], sizeof(tainted));
+  __builtin_memcpy(dstbuf, srcbuf, tainted[0].length); // expected-warning {{Untrusted data is used to specify the buffer size}}
+
+  __builtin_memset(&tainted, 0, sizeof(tainted));
+  read(sock, &tainted, sizeof(tainted));
+  __builtin_memcpy(dstbuf, srcbuf, tainted[0].length); // expected-warning {{Untrusted data is used to specify the buffer size}}
+
+  __builtin_memset(&tainted, 0, sizeof(tainted));
+  // If we taint element 1, we should not raise an alert on taint for element 0 or element 2
+  read(sock, &tainted[1], sizeof(tainted));
+  __builtin_memcpy(dstbuf, srcbuf, tainted[0].length); // no-warning
+  __builtin_memcpy(dstbuf, srcbuf, tainted[2].length); // no-warning
+}
+
+void testUnion() {
+  union {
+    int x;
+    char y[4];
+  } tainted;
+
+  char buffer[4];
+
+  int sock = socket(AF_INET, SOCK_STREAM, 0);
+  read(sock, &tainted.y, sizeof(tainted.y));
+  // FIXME: overlapping regions aren't detected by isTainted yet
+  __builtin_memcpy(buffer, tainted.y, tainted.x);
 }
 
 int testDivByZero() {