diff --git a/clang-tools-extra/clang-rewrite/ClangRewrite.cpp b/clang-tools-extra/clang-rewrite/ClangRewrite.cpp
--- a/clang-tools-extra/clang-rewrite/ClangRewrite.cpp
+++ b/clang-tools-extra/clang-rewrite/ClangRewrite.cpp
@@ -107,13 +107,11 @@
   InsertPostmatchCallback postmatch_callback;
   ReplaceCallback replace_callback;
   MatcherGenCallback matcher_callback;
-  // ReplaceCallback2 r2d2;
 
   inst_finder.addMatcher(insert_before_match, &prematch_callback);
   inst_finder.addMatcher(insert_after_match, &postmatch_callback);
   inst_finder.addMatcher(replace_match, &replace_callback);
   inst_finder.addMatcher(matcher, &matcher_callback);
-  // inst_finder.addMatcher(replace2, &r2d2);
 
   // MatcherWrapper* m = new MatcherWrapper(rettest, "returns_test",
   //     "test",
diff --git a/clang-tools-extra/clang-rewrite/MatcherGenCallback.h b/clang-tools-extra/clang-rewrite/MatcherGenCallback.h
--- a/clang-tools-extra/clang-rewrite/MatcherGenCallback.h
+++ b/clang-tools-extra/clang-rewrite/MatcherGenCallback.h
@@ -99,7 +99,7 @@
     if (root == nullptr) {
       root = temp;
       current = root;
-      // bind_to("match");
+      bind_to("match");
     }
     else {
       current->add_child(current, temp);
@@ -127,7 +127,10 @@
     functionDecl(allOf(
         hasAttr(attr::Matcher),
         hasBody(compoundStmt(
-            hasAnySubstatement(compoundStmt(anything()).bind("body"))
+            hasAnySubstatement(attributedStmt(allOf(
+                isAttr(attr::MatcherBlock),
+                hasSubStmt(compoundStmt(anything()).bind("body"))
+            )))
         ))
     )).bind("matcher");
 
diff --git a/clang-tools-extra/clang-rewrite/NewCodeCallback.h b/clang-tools-extra/clang-rewrite/NewCodeCallback.h
--- a/clang-tools-extra/clang-rewrite/NewCodeCallback.h
+++ b/clang-tools-extra/clang-rewrite/NewCodeCallback.h
@@ -57,13 +57,38 @@
 // .bind("replace");
 
 DeclarationMatcher insert_before_match =
-    functionDecl(hasAttr(attr::InsertCodeBefore)).bind("insert_before_match");
+    functionDecl(allOf(
+        hasAttr(attr::InsertCodeBefore),
+        hasBody(compoundStmt(
+            hasAnySubstatement(attributedStmt(allOf(
+                isAttr(attr::MatcherBlock),
+                hasSubStmt(compoundStmt(anything()).bind("body"))
+            )))
+        ))
+    )).bind("insert_before_match");
 
 DeclarationMatcher insert_after_match =
-    functionDecl(hasAttr(attr::InsertCodeAfter)).bind("insert_after_match");
+    functionDecl(allOf(
+        hasAttr(attr::InsertCodeAfter),
+        hasBody(compoundStmt(
+            hasAnySubstatement(attributedStmt(allOf(
+                isAttr(attr::MatcherBlock),
+                hasSubStmt(compoundStmt(anything()).bind("body"))
+            )))
+        ))
+    )).bind("insert_after_match");
 
 DeclarationMatcher replace_match =
-    functionDecl(hasAttr(attr::ReplaceCode)).bind("replace");
+    functionDecl(allOf(
+        hasAttr(attr::ReplaceCode),
+        hasBody(compoundStmt(
+            hasAnySubstatement(attributedStmt(allOf(
+                isAttr(attr::MatcherBlock),
+                hasSubStmt(compoundStmt(anything()).bind("body"))
+            )))
+        ))
+    )).bind("replace");
+
 std::vector<CodeAction *> all_actions;
 
 
@@ -129,70 +154,84 @@
     }
 
     // grab function body as new code
-    Stmt* new_code = nullptr;
-    if (func->hasBody()) {
-      new_code = func->getBody();
-      printf("function body!!!\n");
-      new_code->dump();
+    // The matchers bind "body" to the compound statement marked with
+    // [[clang::matcher_block]]; only that block is copied as new code.
+    const CompoundStmt *body = result.Nodes.getNodeAs<CompoundStmt>("body");
+    if (!body ||
+        !context->getSourceManager().isWrittenInMainFile(body->getBeginLoc())) {
+      printf("ERROR: invalid body\n");
+      return;
+    }
+    printf("function body\n");
+    body->dump();
+
+    const SourceManager &sm = context->getSourceManager();
+    const LangOptions &lang_opts = context->getLangOpts();
+    FullSourceLoc body_begin;
+    FullSourceLoc body_end;
+    if (!body->body_empty()) {
+      body_begin = context->getFullLoc(body->body_front()->getBeginLoc());
+
+      // A statement's source range stops before its terminating ';', so lex
+      // forward from the last token of the final statement to find it.
+      // findNextToken() expects the *start* of a token, and the scan stops at
+      // the block's closing brace so a final statement with no ';' (a nested
+      // block, an if statement, ...) cannot run past the block.
+      const Stmt *last = body->body_back();
+      SourceLocation eol =
+          Lexer::getLocForEndOfToken(last->getEndLoc(), 0, sm, lang_opts);
+      Optional<Token> tok =
+          Lexer::findNextToken(last->getEndLoc(), sm, lang_opts);
+      while (tok.hasValue() && tok->isNot(clang::tok::semi) &&
+             sm.isBeforeInTranslationUnit(tok->getLocation(),
+                                          body->getRBracLoc())) {
+        tok = Lexer::findNextToken(tok->getLocation(), sm, lang_opts);
+      }
+      if (tok.hasValue() && tok->is(clang::tok::semi)) {
+        // TODO: this is a hack and we should be smarter about semicolons.
+        // Replace stops before the ';'; inserts keep it (grab semicolon).
+        eol = (kind != Replace) ? tok->getEndLoc() : tok->getLocation();
+      }
+      if (eol.isInvalid()) {
+        printf("ERROR: could not find end of body\n");
+        return;
+      }
+      body_end = context->getFullLoc(eol);
     }
-    else {
-      printf("WARNING: code modification empty\n");
+    else { // empty body: no new code, use an empty range at the '}'
+      body_begin = context->getFullLoc(body->getRBracLoc());
+      body_end = body_begin;
     }
 
-    // FullSourceLoc body_begin;
-    // FullSourceLoc body_end;
-    // if (!new_code->body_empty()) {
-    //   body_begin = context->getFullLoc(new_code->body_front()->getBeginLoc());
-    //
-    //   // go to end of line; stmts don't work, gotta lex to the end of the line
-    //   SourceLocation eol = Lexer::getLocForEndOfToken(
-    //       new_code->body_back()->getBeginLoc(), 0, context->getSourceManager(),
-    //       context->getLangOpts());
-    //   Optional<Token> tok = Lexer::findNextToken(
-    //       eol, context->getSourceManager(), context->getLangOpts());
-    //   while (tok.hasValue() && tok->isNot(clang::tok::semi)) {
-    //     tok = Lexer::findNextToken(eol, context->getSourceManager(),
-    //                                context->getLangOpts());
-    //     eol = tok->getLocation();
-    //   }
-    //   // TODO: this is a hack and we should be smarter about semicolons
-    //   if (kind != Replace) {
-    //     eol = tok->getEndLoc(); // grab semicolon
-    //   }
-    //   body_end = context->getFullLoc(eol);
-    // } else { // empty body just use brackets
-    //   body_begin = context->getFullLoc(new_code->getLBracLoc());
-    //   body_end = context->getFullLoc(new_code->getRBracLoc());
-    // }
-    //
-    // FileID fid = body_begin.getFileID();
-    // unsigned int begin_offset = body_begin.getFileOffset();
-    // unsigned int end_offset = body_end.getFileOffset();
-    //
-    // printf("begin offset %u\n", begin_offset);
-    // printf("end offset %u\n", end_offset);
-    // printf("array length %u\n", end_offset - begin_offset);
-    //
-    // llvm::Optional<llvm::MemoryBufferRef> buff =
-    //     context->getSourceManager().getBufferOrNone(fid);
-    //
-    // char *code = new char[end_offset - begin_offset + 1];
-    // if (buff.hasValue()) {
-    //   memcpy(code, &(buff->getBufferStart()[begin_offset]),
-    //          (end_offset - begin_offset + 1) * sizeof(char));
-    //   code[end_offset - begin_offset] =
-    //       '\0'; // force null terminated for Reasons
-    //   printf("code??? %s\n", code);
-    // } else {
-    //   printf("no buffer :<\n");
-    // }
-    //
-    // // make action, put in vector of actions
-    // CodeAction *act =
-    //     new CodeAction(kind, matcher_names, std::string(code), action_name);
-    // all_actions.push_back(act);
-    //
-    // delete[] code;
+    FileID fid = body_begin.getFileID();
+    unsigned int begin_offset = body_begin.getFileOffset();
+    unsigned int end_offset = body_end.getFileOffset();
+    // Both ends must be in the same buffer and ordered; otherwise the length
+    // computed below would underflow.
+    if (body_end.getFileID() != fid || end_offset < begin_offset) {
+      printf("ERROR: invalid body range\n");
+      return;
+    }
+
+    printf("begin offset %u\n", begin_offset);
+    printf("end offset %u\n", end_offset);
+    printf("array length %u\n", end_offset - begin_offset);
+
+    llvm::Optional<llvm::MemoryBufferRef> buff = sm.getBufferOrNone(fid);
+    if (!buff.hasValue() || end_offset > buff->getBufferSize()) {
+      // Nothing trustworthy to copy, so do not record an action.
+      printf("no buffer :<\n");
+      return;
+    }
+
+    // Copy exactly the source text in [begin_offset, end_offset).
+    std::string code(buff->getBufferStart() + begin_offset,
+                     end_offset - begin_offset);
+    printf("code??? %s\n", code.c_str());
+
+    // make action, put in vector of actions
+    CodeAction *act = new CodeAction(kind, matcher_names, code, action_name);
+    all_actions.push_back(act);
   }
 
 private:
@@ -233,12 +272,5 @@
   }
 };
 
-class ReplaceCallback2 : public NewCodeCallback {
-public:
-  ReplaceCallback2() {
-    kind = Replace;
-    kind_name = "replace";
-  }
-};
 
 #endif
diff --git a/clang-tools-extra/clang-rewrite/tests/new.cpp b/clang-tools-extra/clang-rewrite/tests/new.cpp
--- a/clang-tools-extra/clang-rewrite/tests/new.cpp
+++ b/clang-tools-extra/clang-rewrite/tests/new.cpp
@@ -6,14 +6,26 @@
   for (int i = 0; i < 3; i++) {
     x = i;
   }
+  [[likely]]
+  {
+    printf("not a matcher\n");
+  }
+  [[clang::matcher_block]]
   {
     return x;
   }
+}
+constexpr double pow(double x, long long n) noexcept {
+  if (n > 0) [[likely]]
+    return x * pow(x, n - 1);
+  else [[unlikely]]
+    return 1;
 }
 
 
 // [[clang::matcher("cuda_kernel")]]
 // auto kern() {
+//   [[clang::matcher_block]]
 //   {
 //     kernel<<<numblocks, numthreads>>>(arg1, arg2, ...);
 //   }
@@ -21,25 +33,37 @@
 //
 // [[clang::replace("cuda_kernel")]]
 // auto hip() {
-//   {
-//     hip_launch(kernkel, numblocks, numthreads, 0, 0, arg1, arg2, ...);
+//   if (kernel == "gaussian") {
+//     [[clang::matcher_block]]
+//     {
+//       hip_launch(kernel, numblocks, numthreads, 0, 0, arg1, arg2, ...);
+//     }
 //   }
 // }
 
 
 [[clang::replace("returns")]]
 auto return42() {
-  return 42;
+  [[clang::matcher_block]]
+  {
+    return 42;
+  }
 }
 
 [[clang::insert_before("returns", "thencode")]]
 auto foobar() {
-  printf("returning\n");;
+  [[clang::matcher_block]]
+  {
+    printf("returning\n");
+  }
 }
 
 [[clang::insert_after("thencode")]]
 auto helloworld() {
-  printf("hello world\n");
+  [[clang::matcher_block]]
+  {
+    printf("hello world\n");
+  }
 }
 
 int main() {
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -2737,6 +2737,23 @@
   return false;
 }
 
+/// Matches an \c AttributedStmt whose sub-statement (the statement the
+/// attributes are attached to) matches \p InnerMatcher.
+///
+/// Given
+/// \code
+///   else [[unlikely]]
+///     return 1;
+/// \endcode
+/// attributedStmt(hasSubStmt(returnStmt()))
+///   matches the attributed statement wrapping \c return 1;.
+AST_MATCHER_P(AttributedStmt, hasSubStmt, internal::Matcher<Stmt>,
+              InnerMatcher) {
+  const Stmt *const Statement = Node.getSubStmt();
+  return (Statement != nullptr &&
+          InnerMatcher.matches(*Statement, Finder, Builder));
+}
+
 /// Matches \c QualTypes in the clang AST.
 extern const internal::VariadicAllOfMatcher<QualType> qualType;
 
diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
--- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -356,6 +356,7 @@
   REGISTER_MATCHER(hasSpecializedTemplate);
   REGISTER_MATCHER(hasStaticStorageDuration);
   REGISTER_MATCHER(hasStructuredBlock);
+  REGISTER_MATCHER(hasSubStmt);
   REGISTER_MATCHER(hasSyntacticForm);
   REGISTER_MATCHER(hasTargetDecl);
   REGISTER_MATCHER(hasTemplateArgument);