diff --git a/clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp --- a/clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/GenericTaintChecker.cpp @@ -36,7 +36,8 @@ namespace { class GenericTaintChecker - : public Checker, check::PreStmt> { + : public Checker, check::PreStmt, + check::PostStmt> { public: static void *getTag() { static int Tag; @@ -47,6 +48,10 @@ void checkPreStmt(const CallExpr *CE, CheckerContext &C) const; + /// Heuristic to cleanse taint from symbolic expressions that are used in + /// comparison expressions. + void checkPostStmt(const BinaryOperator *BinOp, CheckerContext &Ctx) const; + void printState(raw_ostream &Out, ProgramStateRef State, const char *NL, const char *Sep) const override; @@ -514,6 +519,75 @@ return TaintPropagationRule(); } +static void +collectAllTaintedSymbolsRecursively(SymbolRef Sym, ProgramStateRef State, + SmallVector &result) { + switch (Sym->getKind()) { + case SymExpr::IntSymExprKind: { + const auto *IntSym = cast(Sym); + collectAllTaintedSymbolsRecursively(IntSym->getRHS(), State, result); + break; + } + case SymExpr::SymIntExprKind: { + const auto *SymInt = cast(Sym); + collectAllTaintedSymbolsRecursively(SymInt->getLHS(), State, result); + break; + } + case SymExpr::SymSymExprKind: { + const auto *SymSym = cast(Sym); + collectAllTaintedSymbolsRecursively(SymSym->getLHS(), State, result); + collectAllTaintedSymbolsRecursively(SymSym->getRHS(), State, result); + break; + } + default: + if (taint::isTainted(State, Sym)) + result.push_back(Sym); + } +} + +/// If a comparison operator has exactly one tainted operand, +/// remove the taint from all symbols that the operand depends on. +/// Ignores (in)equality operator calls checking against NULL. +void GenericTaintChecker::checkPostStmt(const BinaryOperator *BinOp, + CheckerContext &Ctx) const { + // Handle only (<,<=,>,>=,==,!=) operators. 
+ if (!BinOp->isComparisonOp()) + return; + + SymbolRef SymLHS = Ctx.getSVal(BinOp->getLHS()).getAsSymExpr(); + SymbolRef SymRHS = Ctx.getSVal(BinOp->getRHS()).getAsSymExpr(); + + ProgramStateRef State = Ctx.getState(); + const bool TaintedLHS = taint::isTainted(State, SymLHS); + const bool TaintedRHS = taint::isTainted(State, SymRHS); + + // Do nothing if both operands are tainted. + if (TaintedLHS && TaintedRHS) + return; + + // Do nothing if none of the operands are tainted. + if (!TaintedLHS && !TaintedRHS) + return; + + // Ignore comparisons (==,!=) of tainted pointers and NULL. + if (BinOp->isEqualityOp()) { + const Expr *OtherArgument = TaintedLHS ? BinOp->getRHS() : BinOp->getLHS(); + const bool IsOtherNullExpr = OtherArgument->isNullPointerConstant( + Ctx.getASTContext(), Expr::NPC_ValueDependentIsNotNull); + if (IsOtherNullExpr) + return; + } + + // Remove taint. + SmallVector TaintedSubsymbols; + collectAllTaintedSymbolsRecursively((TaintedLHS ? SymLHS : SymRHS), State, + TaintedSubsymbols); + for (SymbolRef Sym : TaintedSubsymbols) + State = taint::removeTaint(State, Sym); + + Ctx.addTransition(State); +} + void GenericTaintChecker::checkPreStmt(const CallExpr *CE, CheckerContext &C) const { Optional FData = FunctionData::create(CE, C); diff --git a/clang/lib/StaticAnalyzer/Checkers/Taint.h b/clang/lib/StaticAnalyzer/Checkers/Taint.h --- a/clang/lib/StaticAnalyzer/Checkers/Taint.h +++ b/clang/lib/StaticAnalyzer/Checkers/Taint.h @@ -45,6 +45,9 @@ const MemRegion *R, TaintTagType Kind = TaintTagGeneric); +LLVM_NODISCARD ProgramStateRef removeTaint(ProgramStateRef State, const Stmt *S, + const LocationContext *LCtx); + LLVM_NODISCARD ProgramStateRef removeTaint(ProgramStateRef State, SVal V); LLVM_NODISCARD ProgramStateRef removeTaint(ProgramStateRef State, diff --git a/clang/lib/StaticAnalyzer/Checkers/Taint.cpp b/clang/lib/StaticAnalyzer/Checkers/Taint.cpp --- a/clang/lib/StaticAnalyzer/Checkers/Taint.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/Taint.cpp 
@@ -92,6 +92,11 @@ return NewState; } +ProgramStateRef taint::removeTaint(ProgramStateRef State, const Stmt *S, + const LocationContext *LCtx) { + return taint::removeTaint(State, State->getSVal(S, LCtx)); +} + ProgramStateRef taint::removeTaint(ProgramStateRef State, SVal V) { SymbolRef Sym = V.getAsSymbol(); if (Sym) diff --git a/clang/test/Analysis/taint-tester.c b/clang/test/Analysis/taint-tester.c --- a/clang/test/Analysis/taint-tester.c +++ b/clang/test/Analysis/taint-tester.c @@ -67,11 +67,14 @@ int y = (in << (x << in)) * 5;// expected-warning + {{tainted}} // The next line tests integer to integer cast. int z = y & inn; // expected-warning + {{tainted}} - if (y == 5) // expected-warning + {{tainted}} - m = z | z;// expected-warning + {{tainted}} - else - m = inn; - int mm = m; // expected-warning + {{tainted}} + if (y == 5) { // expected-warning + {{tainted}} + // Since the only tainted symbol y depended on was the value of x, the + // check on y in the condition marked the value of x not tainted anymore. + m = z | z; // no warning + } else { + m = inn; // no warning + } + int mm = m; // no warning } // Test getenv. @@ -168,6 +171,43 @@ free(line); // expected-warning + {{tainted}} } +int conditionRemovesTaintTest() { + int idx; + scanf("%d", &idx); // The value of idx becomes tainted. + // Relational operators comparing a tainted value to a non-tainted one will + // remove taint. + if (idx < 0 || 42 < idx) { // expected-warning + {{tainted}} + int idx2 = idx; // no warning + return -1; + } + // Not tainted now, since it appeared in the condition previously. + return idx; // no warning +} + +int conditionDoesNotRemoveTaintTest() { + int idx1, idx2; + scanf("%d %d", &idx1, &idx2); + + // Both operands of the comparison are tainted. + // Taint won't be removed. 
+ if (idx1 < idx2) { // expected-warning + {{tainted}} + int tmp = idx1; // expected-warning + {{tainted}} + return -1; + } + + + int sum = idx1 + idx2; // expected-warning + {{tainted}} + + // Relational operator removes taint from all dependent symbolic expressions. + if (0 <= sum && sum < 42) { // expected-warning {{tainted}} + int tmp1 = idx1; // no warning + int tmp2 = idx2; // no warning + int tmp3 = sum; // no warning + } + + return idx1 + idx2 + sum; // no warning +} + // Test propagation functions - the ones that propagate taint from arguments to // return value, ptr arguments.