diff --git a/clang/lib/Interpreter/IncrementalParser.h b/clang/lib/Interpreter/IncrementalParser.h --- a/clang/lib/Interpreter/IncrementalParser.h +++ b/clang/lib/Interpreter/IncrementalParser.h @@ -57,6 +57,9 @@ /// of code. std::list PTUs; + /// Callbacks executed with the parsed result of every incremental input. + std::vector> ASTHandlers; + IncrementalParser(); public: diff --git a/clang/lib/Interpreter/IncrementalParser.cpp b/clang/lib/Interpreter/IncrementalParser.cpp --- a/clang/lib/Interpreter/IncrementalParser.cpp +++ b/clang/lib/Interpreter/IncrementalParser.cpp @@ -192,10 +192,15 @@ Sema::ModuleImportState ImportState; for (bool AtEOF = P->ParseFirstTopLevelDecl(ADecl, ImportState); !AtEOF; AtEOF = P->ParseTopLevelDecl(ADecl, ImportState)) { - if (ADecl && !Consumer->HandleTopLevelDecl(ADecl.get())) - return llvm::make_error("Parsing failed. " - "The consumer rejected a decl", - std::error_code()); + if (ADecl) { + auto DeclGroup = ADecl.get(); + for (auto handler : ASTHandlers) { + handler(DeclGroup); + } + if (!Consumer->HandleTopLevelDecl(DeclGroup)) + return llvm::make_error( + "Parsing failed. 
The consumer rejected a decl", std::error_code()); + } } DiagnosticsEngine &Diags = getCI()->getDiagnostics(); diff --git a/clang/lib/Interpreter/Offload.cpp b/clang/lib/Interpreter/Offload.cpp --- a/clang/lib/Interpreter/Offload.cpp +++ b/clang/lib/Interpreter/Offload.cpp @@ -20,6 +20,17 @@ namespace clang { +static void inlineCudaDeviceFunctions(DeclGroupRef DG) { + for (auto *Decl : DG) { + if (Decl->isFunctionOrFunctionTemplate()) { + auto *Func = Decl->getAsFunction(); + if (Func->hasAttr()) { + Func->setInlineSpecified(true); + } + } + } +} + IncrementalCUDADeviceParser::IncrementalCUDADeviceParser( std::unique_ptr Instance, llvm::LLVMContext &LLVMCtx, llvm::StringRef Arch, llvm::StringRef FatbinFile, llvm::Error &Err) @@ -27,6 +38,8 @@ if (Err) return; + ASTHandlers.push_back(inlineCudaDeviceFunctions); + if (!Arch.starts_with("sm_") || Arch.substr(3).getAsInteger(10, SMVersion)) { llvm::errs() << Arch.substr(3) << SMVersion << '\n';