diff --git a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h --- a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h +++ b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h @@ -51,11 +51,11 @@ void replaceWithThenStatement(const ast_matchers::MatchFinder::MatchResult &Result, - const CXXBoolLiteralExpr *BoolLiteral); + const Expr *BoolLiteral); void replaceWithElseStatement(const ast_matchers::MatchFinder::MatchResult &Result, - const CXXBoolLiteralExpr *FalseConditionRemoved); + const Expr *FalseConditionRemoved); void replaceWithCondition(const ast_matchers::MatchFinder::MatchResult &Result, diff --git a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp --- a/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp +++ b/clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp @@ -58,16 +58,29 @@ const char SimplifyConditionalReturnDiagnostic[] = "redundant boolean literal in conditional return statement"; -const CXXBoolLiteralExpr *getBoolLiteral(const MatchFinder::MatchResult &Result, - StringRef Id) { - const auto *Literal = Result.Nodes.getNodeAs(Id); - return (Literal && Literal->getBeginLoc().isMacroID()) ? nullptr : Literal; +const Expr *getBoolLiteral(const MatchFinder::MatchResult &Result, + StringRef Id) { + if (const Expr *Literal = Result.Nodes.getNodeAs(Id)) + return Literal->getBeginLoc().isMacroID() ? nullptr : Literal; + if (const auto *Negated = Result.Nodes.getNodeAs(Id)) { + // Look through parens like literalOrNegatedBool() so `!(true)` matches. + if (Negated->getOpcode() == UO_LNot && + isa(Negated->getSubExpr()->IgnoreParenImpCasts())) + return Negated->getBeginLoc().isMacroID() ? 
nullptr : Negated; + } + return nullptr; +} + +internal::BindableMatcher literalOrNegatedBool(bool Value) { + return expr(anyOf(cxxBoolLiteral(equals(Value)), + unaryOperator(hasUnaryOperand(ignoringParenImpCasts( + cxxBoolLiteral(equals(!Value)))), + hasOperatorName("!")))); } internal::Matcher returnsBool(bool Value, StringRef Id = "ignored") { - auto SimpleReturnsBool = - returnStmt(has(cxxBoolLiteral(equals(Value)).bind(Id))) - .bind("returns-bool"); + auto SimpleReturnsBool = returnStmt(has(literalOrNegatedBool(Value).bind(Id))) + .bind("returns-bool"); return anyOf(SimpleReturnsBool, compoundStmt(statementCountIs(1), has(SimpleReturnsBool))); } @@ -269,16 +282,26 @@ return asBool(getText(Result, *E), NeedsStaticCast); } -const CXXBoolLiteralExpr *stmtReturnsBool(const ReturnStmt *Ret, bool Negated) { +const Expr *stmtReturnsBool(const ReturnStmt *Ret, bool Negated) { if (const auto *Bool = dyn_cast(Ret->getRetValue())) { if (Bool->getValue() == !Negated) return Bool; } + if (const auto *Unary = dyn_cast(Ret->getRetValue())) { + if (Unary->getOpcode() == UO_LNot) { + // Look through parens so `return !(false);` agrees with the matchers. + if (const auto *Bool = dyn_cast( + Unary->getSubExpr()->IgnoreParenImpCasts())) { + if (Bool->getValue() == Negated) + return Bool; + } + } + } return nullptr; } -const CXXBoolLiteralExpr *stmtReturnsBool(const IfStmt *IfRet, bool Negated) { +const Expr *stmtReturnsBool(const IfStmt *IfRet, bool Negated) { if (IfRet->getElse() != nullptr) return nullptr; @@ -423,7 +446,7 @@ StringRef BooleanId) { Finder->addMatcher( ifStmt(unless(isInTemplateInstantiation()), - hasCondition(cxxBoolLiteral(equals(Value)).bind(BooleanId))) + hasCondition(literalOrNegatedBool(Value).bind(BooleanId))) .bind(IfStmtId), this); } @@ -433,8 +456,8 @@ StringRef TernaryId) { Finder->addMatcher( conditionalOperator(unless(isInTemplateInstantiation()), - hasTrueExpression(cxxBoolLiteral(equals(Value))), - hasFalseExpression(cxxBoolLiteral(equals(!Value)))) + 
hasTrueExpression(literalOrNegatedBool(Value)), + 
hasFalseExpression(literalOrNegatedBool(!Value))) .bind(TernaryId), this); } @@ -465,12 +488,12 @@ auto SimpleThen = binaryOperator(hasOperatorName("="), hasLHS(anyOf(VarAssign, MemAssign)), hasLHS(expr().bind(IfAssignVariableId)), - hasRHS(cxxBoolLiteral(equals(Value)).bind(IfAssignLocId))); + hasRHS(literalOrNegatedBool(Value).bind(IfAssignLocId))); auto Then = anyOf(SimpleThen, compoundStmt(statementCountIs(1), hasAnySubstatement(SimpleThen))); auto SimpleElse = binaryOperator(hasOperatorName("="), hasLHS(anyOf(VarRef, MemRef)), - hasRHS(cxxBoolLiteral(equals(!Value)))); + hasRHS(literalOrNegatedBool(!Value))); auto Else = anyOf(SimpleElse, compoundStmt(statementCountIs(1), hasAnySubstatement(SimpleElse))); if (ChainedConditionalAssignment) @@ -495,7 +518,7 @@ hasAnySubstatement( ifStmt(hasThen(returnsBool(Value)), unless(hasElse(stmt())))), hasAnySubstatement(returnStmt(has(ignoringParenImpCasts( - cxxBoolLiteral(equals(!Value))))) + literalOrNegatedBool(!Value)))) .bind(CompoundReturnId))) .bind(Id), this); @@ -529,10 +552,10 @@ void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) { if (Result.Nodes.getNodeAs("top")) Visitor(this, Result).TraverseAST(*Result.Context); - else if (const CXXBoolLiteralExpr *TrueConditionRemoved = + else if (const Expr *TrueConditionRemoved = getBoolLiteral(Result, ConditionThenStmtId)) replaceWithThenStatement(Result, TrueConditionRemoved); - else if (const CXXBoolLiteralExpr *FalseConditionRemoved = + else if (const Expr *FalseConditionRemoved = getBoolLiteral(Result, ConditionElseStmtId)) replaceWithElseStatement(Result, FalseConditionRemoved); else if (const auto *Ternary = @@ -574,8 +597,7 @@ } void SimplifyBooleanExprCheck::replaceWithThenStatement( - const MatchFinder::MatchResult &Result, - const CXXBoolLiteralExpr *TrueConditionRemoved) { + const MatchFinder::MatchResult &Result, const Expr *TrueConditionRemoved) { const auto *IfStatement = Result.Nodes.getNodeAs(IfStmtId); issueDiag(Result, 
TrueConditionRemoved->getBeginLoc(), SimplifyConditionDiagnostic, IfStatement->getSourceRange(), @@ -583,8 +605,7 @@ } void SimplifyBooleanExprCheck::replaceWithElseStatement( - const MatchFinder::MatchResult &Result, - const CXXBoolLiteralExpr *FalseConditionRemoved) { + const MatchFinder::MatchResult &Result, const Expr *FalseConditionRemoved) { const auto *IfStatement = Result.Nodes.getNodeAs(IfStmtId); const Stmt *ElseStatement = IfStatement->getElse(); issueDiag(Result, FalseConditionRemoved->getBeginLoc(), @@ -631,7 +652,7 @@ for (++After; After != Compound->body_end() && *Current != Ret; ++Current, ++After) { if (const auto *If = dyn_cast(*Current)) { - if (const CXXBoolLiteralExpr *Lit = stmtReturnsBool(If, Negated)) { + if (const Expr *Lit = stmtReturnsBool(If, Negated)) { if (*After == Ret) { if (!ChainedConditionalReturn && BeforeIf) continue; diff --git a/clang-tools-extra/test/clang-tidy/checkers/readability-simplify-bool-expr.cpp b/clang-tools-extra/test/clang-tidy/checkers/readability-simplify-bool-expr.cpp --- a/clang-tools-extra/test/clang-tidy/checkers/readability-simplify-bool-expr.cpp +++ b/clang-tools-extra/test/clang-tidy/checkers/readability-simplify-bool-expr.cpp @@ -98,6 +98,59 @@ // CHECK-FIXES-NEXT: {{^ i = 11;$}} } +void if_with_negated_bool_condition() { + int i = 10; + if (!true) { + i = 11; + } else { + i = 12; + } + i = 13; + // CHECK-MESSAGES: :[[@LINE-6]]:7: warning: {{.*}} in if statement condition + // CHECK-FIXES: {{^ int i = 10;$}} + // CHECK-FIXES-NEXT: {{^ {$}} + // CHECK-FIXES-NEXT: {{^ i = 12;$}} + // CHECK-FIXES-NEXT: {{^ }$}} + // CHECK-FIXES-NEXT: {{^ i = 13;$}} + + i = 14; + if (!false) { + i = 15; + } else { + i = 16; + } + i = 17; + // CHECK-MESSAGES: :[[@LINE-6]]:7: warning: {{.*}} in if statement condition + // CHECK-FIXES: {{^ i = 14;$}} + // CHECK-FIXES-NEXT: {{^ {$}} + // CHECK-FIXES-NEXT: {{^ i = 15;$}} + // CHECK-FIXES-NEXT: {{^ }$}} + // CHECK-FIXES-NEXT: {{^ i = 17;$}} + + i = 18; + if (!true) { + i = 19; + } 
+ i = 20; + // CHECK-MESSAGES: :[[@LINE-4]]:7: warning: {{.*}} in if statement condition + // CHECK-FIXES: {{^ i = 18;$}} + // CHECK-FIXES-NEXT: {{^ $}} + // CHECK-FIXES-NEXT: {{^ i = 20;$}} + + i = 21; + if (!(false)) { + i = 22; + } + i = 23; + // CHECK-MESSAGES: :[[@LINE-4]]:7: warning: {{.*}} in if statement condition + // CHECK-FIXES: {{^ i = 21;$}} + // CHECK-FIXES-NEXT: {{^ {$}} + // CHECK-FIXES-NEXT: {{^ i = 22;$}} + // CHECK-FIXES-NEXT: {{^ }$}} + // CHECK-FIXES-NEXT: {{^ i = 23;$}} +} + void operator_equals() { int i = 0; bool b1 = (i > 2);