diff --git a/clang/tools/CMakeLists.txt b/clang/tools/CMakeLists.txt
--- a/clang/tools/CMakeLists.txt
+++ b/clang/tools/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_clang_subdirectory(clang-format-vs)
 add_clang_subdirectory(clang-fuzzer)
 add_clang_subdirectory(clang-import-test)
+add_clang_subdirectory(clang-nvlink-wrapper)
 add_clang_subdirectory(clang-offload-bundler)
 add_clang_subdirectory(clang-offload-wrapper)
 add_clang_subdirectory(clang-scan-deps)
diff --git a/clang/tools/clang-nvlink-wrapper/CMakeLists.txt b/clang/tools/clang-nvlink-wrapper/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/clang/tools/clang-nvlink-wrapper/CMakeLists.txt
@@ -0,0 +1,25 @@
+set(LLVM_LINK_COMPONENTS BitWriter Core Object Support)
+
+if(NOT CLANG_BUILT_STANDALONE)
+  set(tablegen_deps intrinsics_gen)
+endif()
+
+add_clang_executable(clang-nvlink-wrapper
+  ClangNvlinkWrapper.cpp
+
+  DEPENDS
+  ${tablegen_deps}
+  )
+
+set(CLANG_NVLINK_WRAPPER_LIB_DEPS
+  clangBasic
+  )
+
+add_dependencies(clang clang-nvlink-wrapper)
+
+target_link_libraries(clang-nvlink-wrapper
+  PRIVATE
+  ${CLANG_NVLINK_WRAPPER_LIB_DEPS}
+  )
+
+install(TARGETS clang-nvlink-wrapper RUNTIME DESTINATION bin)
diff --git a/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
new file mode 100644
--- /dev/null
+++ b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
@@ -0,0 +1,248 @@
+//===-- clang-nvlink-wrapper/ClangNvlinkWrapper.cpp - wrapper over nvlink-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements a wrapper over nvlink for cubin files extracted from
+/// device specific archive libraries. nvlink doesn't accept archive files as
+/// input, so every `.a` argument is unpacked into temporary files that are
+/// passed to nvlink individually; all other arguments are forwarded as-is.
+///
+//===---------------------------------------------------------------------===//
+
+#include "llvm/Object/Archive.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/WithColor.h"
+#include "llvm/Support/raw_ostream.h"
+
+#if !defined(_MSC_VER) && !defined(__MINGW32__)
+#include <unistd.h>
+#endif
+
+using namespace llvm;
+
+// The name this program was invoked as.
+static StringRef ToolName;
+
+static const char NVLWHelp[] = R"(
+OVERVIEW: Clang Nvlink Wrapper
+
+USAGE: clang-nvlink-wrapper [options]
+
+For descriptions of the options please run 'nvlink --help'
+The wrapper extracts any archive objects and calls nvlink with the
+individual files instead, plus any other options/objects.
+
+)";
+
+static void printHelpMessage() { outs() << NVLWHelp; }
+
+// Show the error message and the usage text, then exit with status 1.
+static void fail(Twine Error) {
+  WithColor::error(errs(), ToolName) << Error << ".\n";
+  printHelpMessage();
+  exit(1);
+}
+
+// Exit through fail() if EC holds an error, prefixed with Context if given.
+static void failIfError(std::error_code EC, Twine Context = "") {
+  if (!EC)
+    return;
+
+  if (Context.str().empty())
+    fail(EC.message());
+  else
+    fail(Context + ": " + EC.message());
+}
+
+// Returns true if Arg names an archive (".a") input rather than an option.
+static bool isArchiveFile(StringRef Arg) {
+  return !Arg.startswith("-") && sys::path::extension(Arg) == ".a";
+}
+
+// Every parsed archive and the buffer backing it stay alive until exit:
+// the archive objects reference the buffer's memory.
+static std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
+static std::vector<std::unique_ptr<object::Archive>> Archives;
+
+// Parses Buf as an archive; exits through fail() if it cannot be parsed.
+static object::Archive &readArchive(std::unique_ptr<MemoryBuffer> Buf) {
+  ArchiveBuffers.push_back(std::move(Buf));
+  auto LibOrErr =
+      object::Archive::create(ArchiveBuffers.back()->getMemBufferRef());
+  failIfError(errorToErrorCode(LibOrErr.takeError()),
+              "Could not parse library");
+  Archives.push_back(std::move(*LibOrErr));
+  return *Archives.back();
+}
+
+// Prints a non-fatal diagnostic to stderr.
+static void reportError(Twine Error) { errs() << "ERROR: " << Error << "\n"; }
+
+// Reports EC once, prefixed with Context if given. Returns true on error.
+static bool reportIfError(std::error_code EC, Twine Context = "") {
+  if (!EC)
+    return false;
+
+  if (Context.str().empty())
+    reportError(EC.message());
+  else
+    reportError(Context + ": " + EC.message());
+  return true;
+}
+
+// Reports and consumes every error in E, each prefixed with Context if
+// given. Returns true if E held an error.
+static bool reportIfError(llvm::Error E, Twine Context = "") {
+  if (!E)
+    return false;
+
+  std::string ContextStr = Context.str();
+  handleAllErrors(std::move(E), [&](const llvm::ErrorInfoBase &EIB) {
+    if (ContextStr.empty())
+      reportError(EIB.message());
+    else
+      reportError(ContextStr + ": " + EIB.message());
+  });
+  return true;
+}
+
+// Echoes the nvlink command line to stderr.
+static void printNVLinkCommand(const std::vector<StringRef> &Command) {
+  for (StringRef Arg : Command)
+    errs() << Arg << " ";
+  errs() << "\n";
+}
+
+// Runs the nvlink found on PATH with Args. Returns the status to exit with:
+// nvlink's exit code, or 1 if nvlink could not be found or executed.
+static int runNVLink(const SmallVectorImpl<std::string> &Args) {
+  ErrorOr<std::string> NVLinkPath = sys::findProgramByName("nvlink");
+  if (!NVLinkPath) {
+    reportError("nvlink program not found");
+    return 1;
+  }
+
+  std::vector<StringRef> NVLArgs;
+  NVLArgs.push_back(*NVLinkPath);
+  for (const std::string &Arg : Args)
+    NVLArgs.push_back(Arg);
+  printNVLinkCommand(NVLArgs);
+
+  // ExecuteAndWait returns nvlink's exit code, or a negative value if nvlink
+  // could not be executed, crashed or timed out.
+  int ExecResult = sys::ExecuteAndWait(*NVLinkPath, NVLArgs);
+  if (ExecResult == 0)
+    return 0;
+  reportError("nvlink encountered a problem");
+  return ExecResult < 0 ? 1 : ExecResult;
+}
+
+// Extracts every member of the archive Filename into its own temporary file.
+// Each temporary file is recorded in TmpFiles as soon as it exists, so it is
+// removed later even if writing it fails, and is appended to Args only once
+// it has been written completely. Returns false if anything went wrong.
+static bool getArchiveFiles(StringRef Filename,
+                            SmallVectorImpl<std::string> &Args,
+                            SmallVectorImpl<std::string> &TmpFiles) {
+  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
+      MemoryBuffer::getFileOrSTDIN(Filename, -1, false);
+  if (reportIfError(BufOrErr.getError(), "Can't open file " + Filename))
+    return false;
+
+  object::Archive &Lib = readArchive(std::move(*BufOrErr));
+  bool Success = true;
+
+  // The child iterator is fallible: iteration stops on a malformed member
+  // and the cause is stored in Err, which must be checked after the loop.
+  Error Err = Error::success();
+  for (const object::Archive::Child &Child : Lib.children(Err)) {
+    Expected<StringRef> ChildNameOrErr = Child.getName();
+    if (reportIfError(ChildNameOrErr.takeError(), "No Child Name")) {
+      Success = false;
+      continue;
+    }
+    StringRef ChildName = sys::path::filename(*ChildNameOrErr);
+
+    Expected<MemoryBufferRef> ChildBufferOrErr = Child.getMemoryBufferRef();
+    if (reportIfError(ChildBufferOrErr.takeError(), "No Child Mem Buf")) {
+      Success = false;
+      continue;
+    }
+
+    // Reuse the member's stem and extension for the temporary file name.
+    std::pair<StringRef, StringRef> NameParts = ChildName.split('.');
+    SmallString<128> Path;
+    int FileDesc;
+    std::error_code EC = sys::fs::createTemporaryFile(
+        NameParts.first, NameParts.second, FileDesc, Path);
+    if (reportIfError(EC, "Unable to create temporary file")) {
+      Success = false;
+      continue;
+    }
+    std::string TmpFileName(Path.str());
+    TmpFiles.push_back(TmpFileName);
+
+    // Write through the descriptor createTemporaryFile() opened; the stream
+    // owns it from here on and closes it.
+    raw_fd_ostream OS(FileDesc, /*shouldClose=*/true);
+    OS << ChildBufferOrErr->getBuffer();
+    OS.close();
+    if (OS.has_error()) {
+      reportIfError(OS.error(), "Unable to write to temporary file");
+      // Clear the error so the stream's destructor does not treat it as fatal.
+      OS.clear_error();
+      Success = false;
+      continue;
+    }
+    Args.push_back(TmpFileName);
+  }
+  if (reportIfError(std::move(Err)))
+    Success = false;
+  return Success;
+}
+
+// Deletes the temporary files created by getArchiveFiles().
+static void cleanupTmpFiles(const SmallVectorImpl<std::string> &TmpFiles) {
+  for (const std::string &TmpFile : TmpFiles)
+    reportIfError(sys::fs::remove(TmpFile),
+                  "Unable to delete temporary file " + TmpFile);
+}
+
+int main(int argc, char **argv) {
+  ToolName = argv[0];
+  SmallVector<const char *, 256> Argv(argv, argv + argc);
+  BumpPtrAllocator Alloc;
+  StringSaver Saver(Alloc);
+  cl::ExpandResponseFiles(Saver, cl::TokenizeGNUCommandLine, Argv);
+
+  if (Argv.size() < 2)
+    fail("no arguments given");
+
+  // Arguments for nvlink: each archive is replaced by its extracted members.
+  SmallVector<std::string, 64> ArgvSubst;
+  SmallVector<std::string, 8> TmpFiles;
+  bool ExtractionFailed = false;
+  for (size_t I = 1, E = Argv.size(); I != E; ++I) {
+    StringRef Arg = Argv[I];
+    if (!isArchiveFile(Arg))
+      ArgvSubst.push_back(Arg.str());
+    else if (!getArchiveFiles(Arg, ArgvSubst, TmpFiles))
+      ExtractionFailed = true;
+  }
+
+  // Don't link with a partial input set: a member that failed to extract
+  // would otherwise be silently missing from the link.
+  int Result = ExtractionFailed ? 1 : runNVLink(ArgvSubst);
+  cleanupTmpFiles(TmpFiles);
+  return Result;
+}