diff --git a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h --- a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h +++ b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h @@ -34,73 +34,40 @@ private: class Visitor; - void reportBinOp(const ast_matchers::MatchFinder::MatchResult &Result, - const BinaryOperator *Op); - - void matchBoolCondition(ast_matchers::MatchFinder *Finder, bool Value, - StringRef BooleanId); - - void matchTernaryResult(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); - - void matchIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); + void reportBinOp(const ASTContext &Context, const BinaryOperator *Op); void matchIfAssignsBool(ast_matchers::MatchFinder *Finder, bool Value, StringRef Id); - void matchCompoundIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); - - void matchCaseIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); + void replaceWithThenStatement(const ASTContext &Context, + const IfStmt *IfStatement, + const Expr *BoolLiteral); - void matchDefaultIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); + void replaceWithElseStatement(const ASTContext &Context, + const IfStmt *IfStatement, + const Expr *BoolLiteral); - void matchLabelIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value, - StringRef Id); + void replaceWithCondition(const ASTContext &Context, + const ConditionalOperator *Ternary, bool Negated); - void - replaceWithThenStatement(const ast_matchers::MatchFinder::MatchResult &Result, - const Expr *BoolLiteral); - - void - replaceWithElseStatement(const ast_matchers::MatchFinder::MatchResult &Result, - const Expr *BoolLiteral); - - void - replaceWithCondition(const ast_matchers::MatchFinder::MatchResult &Result, - const ConditionalOperator *Ternary, bool Negated); - - void replaceWithReturnCondition( - const ast_matchers::MatchFinder::MatchResult &Result, const IfStmt *If, - bool Negated); + void replaceWithReturnCondition(const ASTContext &Context, const IfStmt *If, + const Expr *BoolLiteral, bool Negated); void replaceWithAssignment(const ast_matchers::MatchFinder::MatchResult &Result, const IfStmt *If, bool Negated); - void replaceCompoundReturnWithCondition( - const ast_matchers::MatchFinder::MatchResult &Result, - const CompoundStmt *Compound, bool Negated); - - void replaceCompoundReturnWithCondition( - const ast_matchers::MatchFinder::MatchResult &Result, bool Negated, - const IfStmt *If); - - void replaceCaseCompoundReturnWithCondition( - const ast_matchers::MatchFinder::MatchResult &Result, bool Negated); - - void replaceDefaultCompoundReturnWithCondition( - const ast_matchers::MatchFinder::MatchResult &Result, bool Negated); + void replaceCompoundReturnWithCondition(const ASTContext &Context, + const CompoundStmt *Compound, + const ReturnStmt *Ret, bool Negated); - void replaceLabelCompoundReturnWithCondition( - const ast_matchers::MatchFinder::MatchResult &Result, bool Negated); + void replaceCompoundReturnWithCondition(const ASTContext &Context, + const ReturnStmt *Ret, bool Negated, + const IfStmt *If); - void issueDiag(const ast_matchers::MatchFinder::MatchResult &Result, - SourceLocation Loc, StringRef Description, - SourceRange ReplacementRange, StringRef Replacement); + void issueDiag(const ASTContext &Result, SourceLocation Loc, + StringRef Description, SourceRange ReplacementRange, + StringRef Replacement); const bool ChainedConditionalReturn; const bool ChainedConditionalAssignment; diff --git a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp --- a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp +++ b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp @@ -10,6 +10,7 @@ #include "SimplifyBooleanExprMatchers.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Lex/Lexer.h" +#include "llvm/ADT/PointerIntPair.h" #include #include @@ -22,45 +23,23 @@ namespace { -StringRef getText(const MatchFinder::MatchResult &Result, SourceRange Range) { +StringRef getText(const ASTContext &Context, SourceRange Range) { return Lexer::getSourceText(CharSourceRange::getTokenRange(Range), - *Result.SourceManager, - Result.Context->getLangOpts()); + Context.getSourceManager(), + Context.getLangOpts()); } -template -StringRef getText(const MatchFinder::MatchResult &Result, T &Node) { - return getText(Result, Node.getSourceRange()); +template StringRef getText(const ASTContext &Context, T &Node) { + return getText(Context, Node.getSourceRange()); } } // namespace -static constexpr char ConditionThenStmtId[] = "if-bool-yields-then"; -static constexpr char ConditionElseStmtId[] = "if-bool-yields-else"; -static constexpr char TernaryId[] = "ternary-bool-yields-condition"; -static constexpr char TernaryNegatedId[] = "ternary-bool-yields-not-condition"; -static constexpr char IfReturnsBoolId[] = "if-return"; -static constexpr char IfReturnsNotBoolId[] = "if-not-return"; -static constexpr char ThenLiteralId[] = "then-literal"; static constexpr char IfAssignVariableId[] = "if-assign-lvalue"; static constexpr char IfAssignLocId[] = "if-assign-loc"; static constexpr char IfAssignBoolId[] = "if-assign"; static constexpr char IfAssignNotBoolId[] = "if-assign-not"; static constexpr char IfAssignVarId[] = "if-assign-var"; -static constexpr char CompoundReturnId[] = "compound-return"; -static constexpr char CompoundIfId[] = "compound-if"; -static constexpr char CompoundBoolId[] = "compound-bool"; -static constexpr char CompoundNotBoolId[] = "compound-bool-not"; -static constexpr char CaseId[] = "case"; -static constexpr char CaseCompoundBoolId[] = "case-compound-bool"; -static constexpr char CaseCompoundNotBoolId[] = "case-compound-bool-not"; -static constexpr char DefaultId[] = "default"; -static constexpr char DefaultCompoundBoolId[] = "default-compound-bool"; -static constexpr char DefaultCompoundNotBoolId[] = "default-compound-bool-not"; -static constexpr char LabelId[] = "label"; -static constexpr char LabelCompoundBoolId[] = "label-compound-bool"; -static constexpr char LabelCompoundNotBoolId[] = "label-compound-bool-not"; -static constexpr char IfStmtId[] = "if"; static constexpr char SimplifyOperatorDiagnostic[] = "redundant boolean literal supplied to boolean operator"; @@ -69,18 +48,6 @@ static constexpr char SimplifyConditionalReturnDiagnostic[] = "redundant boolean literal in conditional return statement"; -static const Expr *getBoolLiteral(const MatchFinder::MatchResult &Result, - StringRef Id) { - if (const Expr *Literal = Result.Nodes.getNodeAs(Id)) - return Literal->getBeginLoc().isMacroID() ? nullptr : Literal; - if (const auto *Negated = Result.Nodes.getNodeAs(Id)) { - if (Negated->getOpcode() == UO_LNot && - isa(Negated->getSubExpr())) - return Negated->getBeginLoc().isMacroID() ? nullptr : Negated; - } - return nullptr; -} - static internal::BindableMatcher literalOrNegatedBool(bool Value) { return expr( anyOf(cxxBoolLiteral(equals(Value)), @@ -88,14 +55,6 @@ hasOperatorName("!")))); } -static internal::Matcher returnsBool(bool Value, - StringRef Id = "ignored") { - auto SimpleReturnsBool = returnStmt(has(literalOrNegatedBool(Value).bind(Id))) - .bind("returns-bool"); - return anyOf(SimpleReturnsBool, - compoundStmt(statementCountIs(1), has(SimpleReturnsBool))); -} - static bool needsParensAfterUnaryNegation(const Expr *E) { E = E->IgnoreImpCasts(); if (isa(E) || isa(E)) @@ -192,32 +151,29 @@ return !E->getType()->isBooleanType(); } -static std::string -compareExpressionToConstant(const MatchFinder::MatchResult &Result, - const Expr *E, bool Negated, const char *Constant) { +static std::string compareExpressionToConstant(const ASTContext &Context, + const Expr *E, bool Negated, + const char *Constant) { E = E->IgnoreImpCasts(); const std::string ExprText = - (isa(E) ? ("(" + getText(Result, *E) + ")") - : getText(Result, *E)) + (isa(E) ? ("(" + getText(Context, *E) + ")") + : getText(Context, *E)) .str(); return ExprText + " " + (Negated ? "!=" : "==") + " " + Constant; } -static std::string -compareExpressionToNullPtr(const MatchFinder::MatchResult &Result, - const Expr *E, bool Negated) { - const char *NullPtr = - Result.Context->getLangOpts().CPlusPlus11 ? "nullptr" : "NULL"; - return compareExpressionToConstant(Result, E, Negated, NullPtr); +static std::string compareExpressionToNullPtr(const ASTContext &Context, + const Expr *E, bool Negated) { + const char *NullPtr = Context.getLangOpts().CPlusPlus11 ? "nullptr" : "NULL"; + return compareExpressionToConstant(Context, E, Negated, NullPtr); } -static std::string -compareExpressionToZero(const MatchFinder::MatchResult &Result, const Expr *E, - bool Negated) { - return compareExpressionToConstant(Result, E, Negated, "0"); +static std::string compareExpressionToZero(const ASTContext &Context, + const Expr *E, bool Negated) { + return compareExpressionToConstant(Context, E, Negated, "0"); } -static std::string replacementExpression(const MatchFinder::MatchResult &Result, +static std::string replacementExpression(const ASTContext &Context, bool Negated, const Expr *E) { E = E->IgnoreParenBaseCasts(); if (const auto *EC = dyn_cast(E)) @@ -228,20 +184,20 @@ if (const auto *UnOp = dyn_cast(E)) { if (UnOp->getOpcode() == UO_LNot) { if (needsNullPtrComparison(UnOp->getSubExpr())) - return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), true); + return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), true); if (needsZeroComparison(UnOp->getSubExpr())) - return compareExpressionToZero(Result, UnOp->getSubExpr(), true); + return compareExpressionToZero(Context, UnOp->getSubExpr(), true); - return replacementExpression(Result, false, UnOp->getSubExpr()); + return replacementExpression(Context, false, UnOp->getSubExpr()); } } if (needsNullPtrComparison(E)) - return compareExpressionToNullPtr(Result, E, false); + return compareExpressionToNullPtr(Context, E, false); if (needsZeroComparison(E)) - return compareExpressionToZero(Result, E, false); + return compareExpressionToZero(Context, E, false); StringRef NegatedOperator; const Expr *LHS = nullptr; @@ -258,20 +214,20 @@ } } if (!NegatedOperator.empty() && LHS && RHS) - return (asBool((getText(Result, *LHS) + " " + NegatedOperator + " " + - getText(Result, *RHS)) + return (asBool((getText(Context, *LHS) + " " + NegatedOperator + " " + + getText(Context, *RHS)) .str(), NeedsStaticCast)); - StringRef Text = getText(Result, *E); + StringRef Text = getText(Context, *E); if (!NeedsStaticCast && needsParensAfterUnaryNegation(E)) return ("!(" + Text + ")").str(); if (needsNullPtrComparison(E)) - return compareExpressionToNullPtr(Result, E, false); + return compareExpressionToNullPtr(Context, E, false); if (needsZeroComparison(E)) - return compareExpressionToZero(Result, E, false); + return compareExpressionToZero(Context, E, false); return ("!" + asBool(Text, NeedsStaticCast)); } @@ -279,20 +235,20 @@ if (const auto *UnOp = dyn_cast(E)) { if (UnOp->getOpcode() == UO_LNot) { if (needsNullPtrComparison(UnOp->getSubExpr())) - return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), false); + return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), false); if (needsZeroComparison(UnOp->getSubExpr())) - return compareExpressionToZero(Result, UnOp->getSubExpr(), false); + return compareExpressionToZero(Context, UnOp->getSubExpr(), false); } } if (needsNullPtrComparison(E)) - return compareExpressionToNullPtr(Result, E, true); + return compareExpressionToNullPtr(Context, E, true); if (needsZeroComparison(E)) - return compareExpressionToZero(Result, E, true); + return compareExpressionToZero(Context, E, true); - return asBool(getText(Result, *E), NeedsStaticCast); + return asBool(getText(Context, *E), NeedsStaticCast); } static const Expr *stmtReturnsBool(const ReturnStmt *Ret, bool Negated) { @@ -330,14 +286,14 @@ return nullptr; } -static bool containsDiscardedTokens(const MatchFinder::MatchResult &Result, +static bool containsDiscardedTokens(const ASTContext &Context, CharSourceRange CharRange) { std::string ReplacementText = - Lexer::getSourceText(CharRange, *Result.SourceManager, - Result.Context->getLangOpts()) + Lexer::getSourceText(CharRange, Context.getSourceManager(), + Context.getLangOpts()) .str(); - Lexer Lex(CharRange.getBegin(), Result.Context->getLangOpts(), - ReplacementText.data(), ReplacementText.data(), + Lexer Lex(CharRange.getBegin(), Context.getLangOpts(), ReplacementText.data(), + ReplacementText.data(), ReplacementText.data() + ReplacementText.size()); Lex.SetCommentRetentionState(true); @@ -352,18 +308,147 @@ class SimplifyBooleanExprCheck::Visitor : public RecursiveASTVisitor { public: - Visitor(SimplifyBooleanExprCheck *Check, - const MatchFinder::MatchResult &Result) - : Check(Check), Result(Result) {} + Visitor(SimplifyBooleanExprCheck *Check, ASTContext &Context) + : Check(Check), Context(Context) {} + + bool traverse() { return TraverseAST(Context); } bool VisitBinaryOperator(const BinaryOperator *Op) const { - Check->reportBinOp(Result, Op); + Check->reportBinOp(Context, Op); + return true; + } + + static Optional getAsBoolLiteral(const Expr *E, bool FilterMacro) { + if (const auto *Bool = dyn_cast(E)) { + if (FilterMacro && Bool->getBeginLoc().isMacroID()) + return llvm::None; + return Bool->getValue(); + } + if (const auto *UOp = dyn_cast(E)) { + if (FilterMacro && UOp->getBeginLoc().isMacroID()) + return None; + if (UOp->getOpcode() == UO_LNot) + if (Optional Res = getAsBoolLiteral( + UOp->getSubExpr()->IgnoreImplicit(), FilterMacro)) + return !*Res; + } + return llvm::None; + } + + static llvm::PointerIntPair + parseReturnLiteralBool(const Stmt *S) { + const auto *RS = dyn_cast(S); + if (!RS || !RS->getRetValue()) + return {}; + if (auto Ret = + getAsBoolLiteral(RS->getRetValue()->IgnoreImplicit(), false)) { + return {RS->getRetValue(), *Ret}; + } + return {}; + } + + template + static auto checkSingleStatement(Stmt *S, Functor F) -> decltype(F(S)) { + if (auto *CS = dyn_cast(S)) { + if (CS->size() == 1) + return F(CS->body_front()); + return {}; + } + return F(S); + } + + bool doesIfHaveIfParent(const IfStmt *If) { + auto Parents = Context.getParents(*If); + if (Parents.empty()) + return false; + return Parents[0].get() != nullptr; + } + + bool VisitIfStmt(IfStmt *If) { + Expr *Cond = If->getCond()->IgnoreImplicit(); + if (auto Bool = getAsBoolLiteral(Cond, true)) { + if (*Bool) + Check->replaceWithThenStatement(Context, If, Cond); + else + Check->replaceWithElseStatement(Context, If, Cond); + } + + if (If->getElse()) { + auto ThenReturnBool = + checkSingleStatement(If->getThen(), parseReturnLiteralBool); + if (ThenReturnBool.getPointer()) { + auto ElseReturnBool = + checkSingleStatement(If->getElse(), parseReturnLiteralBool); + if (ElseReturnBool.getPointer() && + ThenReturnBool.getInt() != ElseReturnBool.getInt()) { + if (Check->ChainedConditionalReturn || !doesIfHaveIfParent(If)) { + Check->replaceWithReturnCondition(Context, If, + ThenReturnBool.getPointer(), + !ThenReturnBool.getInt()); + } + } + } + } + return true; + } + + bool VisitConditionalOperator(ConditionalOperator *Cond) { + if (auto Then = + getAsBoolLiteral(Cond->getTrueExpr()->IgnoreImplicit(), false)) { + if (auto Else = + getAsBoolLiteral(Cond->getFalseExpr()->IgnoreImplicit(), false)) { + if (*Then != *Else) + Check->replaceWithCondition(Context, Cond, *Else); + } + } + return true; + } + + bool VisitCompoundStmt(CompoundStmt *CS) { + if (CS->size() < 2) + return true; + for (auto Second = CS->body_rbegin(), First = std::next(Second), + End = CS->body_rend(); + First != End; ++Second, ++First) { + auto RetStmt = parseReturnLiteralBool(*Second); + if (!RetStmt.getPointer()) + continue; + + if (auto *If = dyn_cast(*First)) { + auto ThenReturnBool = + checkSingleStatement(If->getThen(), parseReturnLiteralBool); + if (ThenReturnBool.getPointer() && + ThenReturnBool.getInt() != RetStmt.getInt()) { + if (Check->ChainedConditionalReturn || + (If->getElse() == nullptr && !doesIfHaveIfParent(If))) { + Check->replaceCompoundReturnWithCondition( + Context, CS, cast(*Second), RetStmt.getInt()); + } + } + } else if (isa(*First)) { + Stmt *SubStmt = + isa(*First) ? cast(*First)->getSubStmt() + : isa(*First) ? cast(*First)->getSubStmt() + : cast(*First)->getSubStmt(); + if (auto *SubIf = dyn_cast(SubStmt)) { + if (!SubIf->getElse()) { + auto ThenReturnBool = + checkSingleStatement(SubIf->getThen(), parseReturnLiteralBool); + if (ThenReturnBool.getPointer() && + ThenReturnBool.getInt() != RetStmt.getInt()) { + Check->replaceCompoundReturnWithCondition( + Context, cast(*Second), RetStmt.getInt(), SubIf); + } + } + } + } + } return true; } private: SimplifyBooleanExprCheck *Check; - const MatchFinder::MatchResult &Result; + ASTContext &Context; }; SimplifyBooleanExprCheck::SimplifyBooleanExprCheck(StringRef Name, @@ -387,8 +472,8 @@ return false; } -void SimplifyBooleanExprCheck::reportBinOp( - const MatchFinder::MatchResult &Result, const BinaryOperator *Op) { +void SimplifyBooleanExprCheck::reportBinOp(const ASTContext &Context, + const BinaryOperator *Op) { const auto *LHS = Op->getLHS()->IgnoreParenImpCasts(); const auto *RHS = Op->getRHS()->IgnoreParenImpCasts(); @@ -410,12 +495,12 @@ bool BoolValue = Bool->getValue(); - auto ReplaceWithExpression = [this, &Result, LHS, RHS, + auto ReplaceWithExpression = [this, &Context, LHS, RHS, Bool](const Expr *ReplaceWith, bool Negated) { std::string Replacement = - replacementExpression(Result, Negated, ReplaceWith); + replacementExpression(Context, Negated, ReplaceWith); SourceRange Range(LHS->getBeginLoc(), RHS->getEndLoc()); - issueDiag(Result, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range, + issueDiag(Context, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range, Replacement); }; @@ -449,39 +534,6 @@ } } -void SimplifyBooleanExprCheck::matchBoolCondition(MatchFinder *Finder, - bool Value, - StringRef BooleanId) { - Finder->addMatcher( - ifStmt(hasCondition(literalOrNegatedBool(Value).bind(BooleanId))) - .bind(IfStmtId), - this); -} - -void SimplifyBooleanExprCheck::matchTernaryResult(MatchFinder *Finder, - bool Value, StringRef Id) { - Finder->addMatcher( - conditionalOperator(hasTrueExpression(literalOrNegatedBool(Value)), - hasFalseExpression(literalOrNegatedBool(!Value))) - .bind(Id), - this); -} - -void SimplifyBooleanExprCheck::matchIfReturnsBool(MatchFinder *Finder, - bool Value, StringRef Id) { - if (ChainedConditionalReturn) - Finder->addMatcher(ifStmt(hasThen(returnsBool(Value, ThenLiteralId)), - hasElse(returnsBool(!Value))) - .bind(Id), - this); - else - Finder->addMatcher(ifStmt(unless(hasParent(ifStmt())), - hasThen(returnsBool(Value, ThenLiteralId)), - hasElse(returnsBool(!Value))) - .bind(Id), - this); -} - void SimplifyBooleanExprCheck::matchIfAssignsBool(MatchFinder *Finder, bool Value, StringRef Id) { auto VarAssign = declRefExpr(hasDeclaration(decl().bind(IfAssignVarId))); @@ -508,68 +560,6 @@ this); } -static internal::Matcher ifReturnValue(bool Value) { - return ifStmt(hasThen(returnsBool(Value)), unless(hasElse(stmt()))) - .bind(CompoundIfId); -} - -static internal::Matcher returnNotValue(bool Value) { - return returnStmt(has(literalOrNegatedBool(!Value))).bind(CompoundReturnId); -} - -void SimplifyBooleanExprCheck::matchCompoundIfReturnsBool(MatchFinder *Finder, - bool Value, - StringRef Id) { - if (ChainedConditionalReturn) - Finder->addMatcher( - compoundStmt(hasSubstatementSequence(ifReturnValue(Value), - returnNotValue(Value))) - .bind(Id), - this); - else - Finder->addMatcher( - compoundStmt(hasSubstatementSequence(ifStmt(hasThen(returnsBool(Value)), - unless(hasElse(stmt())), - unless(hasParent(ifStmt()))) - .bind(CompoundIfId), - returnNotValue(Value))) - .bind(Id), - this); -} - -void SimplifyBooleanExprCheck::matchCaseIfReturnsBool(MatchFinder *Finder, - bool Value, - StringRef Id) { - internal::Matcher CaseStmt = - caseStmt(hasSubstatement(ifReturnValue(Value))).bind(CaseId); - internal::Matcher CompoundStmt = - compoundStmt(hasSubstatementSequence(CaseStmt, returnNotValue(Value))) - .bind(Id); - Finder->addMatcher(switchStmt(has(CompoundStmt)), this); -} - -void SimplifyBooleanExprCheck::matchDefaultIfReturnsBool(MatchFinder *Finder, - bool Value, - StringRef Id) { - internal::Matcher DefaultStmt = - defaultStmt(hasSubstatement(ifReturnValue(Value))).bind(DefaultId); - internal::Matcher CompoundStmt = - compoundStmt(hasSubstatementSequence(DefaultStmt, returnNotValue(Value))) - .bind(Id); - Finder->addMatcher(switchStmt(has(CompoundStmt)), this); -} - -void SimplifyBooleanExprCheck::matchLabelIfReturnsBool(MatchFinder *Finder, - bool Value, - StringRef Id) { - internal::Matcher LabelStmt = - labelStmt(hasSubstatement(ifReturnValue(Value))).bind(LabelId); - internal::Matcher CompoundStmt = - compoundStmt(hasSubstatementSequence(LabelStmt, returnNotValue(Value))) - .bind(Id); - Finder->addMatcher(CompoundStmt, this); -} - void SimplifyBooleanExprCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) { Options.store(Opts, "ChainedConditionalReturn", ChainedConditionalReturn); Options.store(Opts, "ChainedConditionalAssignment", @@ -579,135 +569,77 @@ void SimplifyBooleanExprCheck::registerMatchers(MatchFinder *Finder) { Finder->addMatcher(translationUnitDecl().bind("top"), this); - matchBoolCondition(Finder, true, ConditionThenStmtId); - matchBoolCondition(Finder, false, ConditionElseStmtId); - - matchTernaryResult(Finder, true, TernaryId); - matchTernaryResult(Finder, false, TernaryNegatedId); - - matchIfReturnsBool(Finder, true, IfReturnsBoolId); - matchIfReturnsBool(Finder, false, IfReturnsNotBoolId); - matchIfAssignsBool(Finder, true, IfAssignBoolId); matchIfAssignsBool(Finder, false, IfAssignNotBoolId); - - matchCompoundIfReturnsBool(Finder, true, CompoundBoolId); - matchCompoundIfReturnsBool(Finder, false, CompoundNotBoolId); - - matchCaseIfReturnsBool(Finder, true, CaseCompoundBoolId); - matchCaseIfReturnsBool(Finder, false, CaseCompoundNotBoolId); - - matchDefaultIfReturnsBool(Finder, true, DefaultCompoundBoolId); - matchDefaultIfReturnsBool(Finder, false, DefaultCompoundNotBoolId); - - matchLabelIfReturnsBool(Finder, true, LabelCompoundBoolId); - matchLabelIfReturnsBool(Finder, false, LabelCompoundNotBoolId); } void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) { if (Result.Nodes.getNodeAs("top")) - Visitor(this, Result).TraverseAST(*Result.Context); - else if (const Expr *TrueConditionRemoved = - getBoolLiteral(Result, ConditionThenStmtId)) - replaceWithThenStatement(Result, TrueConditionRemoved); - else if (const Expr *FalseConditionRemoved = - getBoolLiteral(Result, ConditionElseStmtId)) - replaceWithElseStatement(Result, FalseConditionRemoved); - else if (const auto *Ternary = - Result.Nodes.getNodeAs(TernaryId)) - replaceWithCondition(Result, Ternary, false); - else if (const auto *TernaryNegated = - Result.Nodes.getNodeAs(TernaryNegatedId)) - replaceWithCondition(Result, TernaryNegated, true); - else if (const auto *If = Result.Nodes.getNodeAs(IfReturnsBoolId)) - replaceWithReturnCondition(Result, If, false); - else if (const auto *IfNot = - Result.Nodes.getNodeAs(IfReturnsNotBoolId)) - replaceWithReturnCondition(Result, IfNot, true); + Visitor(this, *Result.Context).traverse(); else if (const auto *IfAssign = Result.Nodes.getNodeAs(IfAssignBoolId)) replaceWithAssignment(Result, IfAssign, false); else if (const auto *IfAssignNot = Result.Nodes.getNodeAs(IfAssignNotBoolId)) replaceWithAssignment(Result, IfAssignNot, true); - else if (const auto *Compound = - Result.Nodes.getNodeAs(CompoundBoolId)) - replaceCompoundReturnWithCondition(Result, Compound, false); - else if (const auto *CompoundNot = - Result.Nodes.getNodeAs(CompoundNotBoolId)) - replaceCompoundReturnWithCondition(Result, CompoundNot, true); - else if (Result.Nodes.getNodeAs(CaseCompoundBoolId)) - replaceCaseCompoundReturnWithCondition(Result, false); - else if (Result.Nodes.getNodeAs(CaseCompoundNotBoolId)) - replaceCaseCompoundReturnWithCondition(Result, true); - else if (Result.Nodes.getNodeAs(DefaultCompoundBoolId)) - replaceDefaultCompoundReturnWithCondition(Result, false); - else if (Result.Nodes.getNodeAs(DefaultCompoundNotBoolId)) - replaceDefaultCompoundReturnWithCondition(Result, true); - else if (Result.Nodes.getNodeAs(LabelCompoundBoolId)) - replaceLabelCompoundReturnWithCondition(Result, false); - else if (Result.Nodes.getNodeAs(LabelCompoundNotBoolId)) - replaceLabelCompoundReturnWithCondition(Result, true); - else if (const auto TU = Result.Nodes.getNodeAs("top")) - Visitor(this, Result).TraverseDecl(const_cast(TU)); -} - -void SimplifyBooleanExprCheck::issueDiag(const MatchFinder::MatchResult &Result, +} + +void SimplifyBooleanExprCheck::issueDiag(const ASTContext &Context, SourceLocation Loc, StringRef Description, SourceRange ReplacementRange, StringRef Replacement) { CharSourceRange CharRange = Lexer::makeFileCharRange(CharSourceRange::getTokenRange(ReplacementRange), - *Result.SourceManager, getLangOpts()); + Context.getSourceManager(), getLangOpts()); DiagnosticBuilder Diag = diag(Loc, Description); - if (!containsDiscardedTokens(Result, CharRange)) + if (!containsDiscardedTokens(Context, CharRange)) Diag << FixItHint::CreateReplacement(CharRange, Replacement); } void SimplifyBooleanExprCheck::replaceWithThenStatement( - const MatchFinder::MatchResult &Result, const Expr *BoolLiteral) { - const auto *IfStatement = Result.Nodes.getNodeAs(IfStmtId); - issueDiag(Result, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic, + const ASTContext &Context, const IfStmt *IfStatement, + const Expr *BoolLiteral) { + issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic, IfStatement->getSourceRange(), - getText(Result, *IfStatement->getThen())); + getText(Context, *IfStatement->getThen())); } void SimplifyBooleanExprCheck::replaceWithElseStatement( - const MatchFinder::MatchResult &Result, const Expr *BoolLiteral) { - const auto *IfStatement = Result.Nodes.getNodeAs(IfStmtId); + const ASTContext &Context, const IfStmt *IfStatement, + const Expr *BoolLiteral) { const Stmt *ElseStatement = IfStatement->getElse(); - issueDiag(Result, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic, + issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic, IfStatement->getSourceRange(), - ElseStatement ? getText(Result, *ElseStatement) : ""); + ElseStatement ? getText(Context, *ElseStatement) : ""); } void SimplifyBooleanExprCheck::replaceWithCondition( - const MatchFinder::MatchResult &Result, const ConditionalOperator *Ternary, + const ASTContext &Context, const ConditionalOperator *Ternary, bool Negated) { std::string Replacement = - replacementExpression(Result, Negated, Ternary->getCond()); - issueDiag(Result, Ternary->getTrueExpr()->getBeginLoc(), + replacementExpression(Context, Negated, Ternary->getCond()); + issueDiag(Context, Ternary->getTrueExpr()->getBeginLoc(), "redundant boolean literal in ternary expression result", Ternary->getSourceRange(), Replacement); } void SimplifyBooleanExprCheck::replaceWithReturnCondition( - const MatchFinder::MatchResult &Result, const IfStmt *If, bool Negated) { + const ASTContext &Context, const IfStmt *If, const Expr *BoolLiteral, + bool Negated) { StringRef Terminator = isa(If->getElse()) ? ";" : ""; - std::string Condition = replacementExpression(Result, Negated, If->getCond()); + std::string Condition = + replacementExpression(Context, Negated, If->getCond()); std::string Replacement = ("return " + Condition + Terminator).str(); - SourceLocation Start = - Result.Nodes.getNodeAs(ThenLiteralId)->getBeginLoc(); - issueDiag(Result, Start, SimplifyConditionalReturnDiagnostic, + SourceLocation Start = BoolLiteral->getBeginLoc(); + issueDiag(Context, Start, SimplifyConditionalReturnDiagnostic, If->getSourceRange(), Replacement); } void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition( - const MatchFinder::MatchResult &Result, const CompoundStmt *Compound, - bool Negated) { - const auto *Ret = Result.Nodes.getNodeAs(CompoundReturnId); + const ASTContext &Context, const CompoundStmt *Compound, + const ReturnStmt *Ret, bool Negated) { // Scan through the CompoundStmt to look for a chained-if construct. const IfStmt *BeforeIf = nullptr; @@ -722,9 +654,10 @@ continue; std::string Replacement = - "return " + replacementExpression(Result, Negated, If->getCond()); + "return " + + replacementExpression(Context, Negated, If->getCond()); issueDiag( - Result, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic, + Context, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic, SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement); return; } @@ -738,51 +671,29 @@ } void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition( - const MatchFinder::MatchResult &Result, bool Negated, const IfStmt *If) { + const ASTContext &Context, const ReturnStmt *Ret, bool Negated, + const IfStmt *If) { const auto *Lit = stmtReturnsBool(If, Negated); - const auto *Ret = Result.Nodes.getNodeAs(CompoundReturnId); const std::string Replacement = - "return " + replacementExpression(Result, Negated, If->getCond()); - issueDiag(Result, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic, + "return " + replacementExpression(Context, Negated, If->getCond()); + issueDiag(Context, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic, SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement); } -void SimplifyBooleanExprCheck::replaceCaseCompoundReturnWithCondition( - const MatchFinder::MatchResult &Result, bool Negated) { - const auto *CaseDefault = Result.Nodes.getNodeAs(CaseId); - const auto *If = cast(CaseDefault->getSubStmt()); - replaceCompoundReturnWithCondition(Result, Negated, If); -} - -void SimplifyBooleanExprCheck::replaceDefaultCompoundReturnWithCondition( - const MatchFinder::MatchResult &Result, bool Negated) { - const SwitchCase *CaseDefault = - Result.Nodes.getNodeAs(DefaultId); - const auto *If = cast(CaseDefault->getSubStmt()); - replaceCompoundReturnWithCondition(Result, Negated, If); -} - -void SimplifyBooleanExprCheck::replaceLabelCompoundReturnWithCondition( - const MatchFinder::MatchResult &Result, bool Negated) { - const auto *Label = Result.Nodes.getNodeAs(LabelId); - const auto *If = cast(Label->getSubStmt()); - replaceCompoundReturnWithCondition(Result, Negated, If); -} - void SimplifyBooleanExprCheck::replaceWithAssignment( const MatchFinder::MatchResult &Result, const IfStmt *IfAssign, bool Negated) { SourceRange Range = IfAssign->getSourceRange(); - StringRef VariableName = - getText(Result, *Result.Nodes.getNodeAs(IfAssignVariableId)); + StringRef VariableName = getText( + *Result.Context, *Result.Nodes.getNodeAs(IfAssignVariableId)); StringRef Terminator = isa(IfAssign->getElse()) ? ";" : ""; std::string Condition = - replacementExpression(Result, Negated, IfAssign->getCond()); + replacementExpression(*Result.Context, Negated, IfAssign->getCond()); std::string Replacement = (VariableName + " = " + Condition + Terminator).str(); SourceLocation Location = Result.Nodes.getNodeAs(IfAssignLocId)->getBeginLoc(); - issueDiag(Result, Location, + issueDiag(*Result.Context, Location, "redundant boolean literal in conditional assignment", Range, Replacement); }