diff --git a/clang/lib/Driver/ToolChains/Cuda.cpp b/clang/lib/Driver/ToolChains/Cuda.cpp
--- a/clang/lib/Driver/ToolChains/Cuda.cpp
+++ b/clang/lib/Driver/ToolChains/Cuda.cpp
@@ -613,6 +613,12 @@
   AddStaticDeviceLibsLinking(C, *this, JA, Inputs, Args, CmdArgs, "nvptx",
                              GPUArch, false, false);
 
+  // Find nvlink and pass its path to clang-nvlink-wrapper via --nvlink-path=.
+  // If nvlink cannot be located, GetProgramPath returns the bare name and
+  // clang-nvlink-wrapper looks that name up in PATH itself.
+  CmdArgs.push_back(Args.MakeArgString(
+      "--nvlink-path=" + getToolChain().GetProgramPath("nvlink")));
+
   const char *Exec =
       Args.MakeArgString(getToolChain().GetProgramPath("clang-nvlink-wrapper"));
   C.addCommand(std::make_unique<Command>(
diff --git a/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
--- a/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
+++ b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
@@ -25,6 +25,7 @@
 /// 2. nvlink -o a.out-openmp-nvptx64 /tmp/a.cubin /tmp/b.cubin
 //===---------------------------------------------------------------------===//
 
+#include "clang/Basic/Version.h"
 #include "llvm/Object/Archive.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Errc.h"
@@ -41,6 +42,20 @@
 
 static cl::opt<bool> Help("h", cl::desc("Alias for -help"), cl::Hidden);
 
+// Mark all our options with this category, everything else (except for -help)
+// will be hidden.
+static cl::OptionCategory
+    ClangNvlinkWrapperCategory("clang-nvlink-wrapper options");
+
+static cl::opt<std::string> NvlinkUserPath("nvlink-path",
+                                           cl::desc("Path of nvlink binary"),
+                                           cl::cat(ClangNvlinkWrapperCategory));
+
+// Do not parse nvlink options: every argument not recognized above is
+// collected here verbatim and forwarded to nvlink.
+static cl::list<std::string>
+    NVArgs(cl::Sink, cl::desc("<options to be passed to nvlink>..."));
+
 static Error runNVLink(std::string NVLinkPath,
                        SmallVectorImpl<std::string> &Args) {
   std::vector<StringRef> NVLArgs;
@@ -119,8 +134,20 @@
   return Error::success();
 }
 
+static void PrintVersion(raw_ostream &OS) {
+  OS << clang::getClangToolFullVersion("clang-nvlink-wrapper") << '\n';
+}
+
 int main(int argc, const char **argv) {
   sys::PrintStackTraceOnErrorSignal(argv[0]);
+  cl::SetVersionPrinter(PrintVersion);
+  cl::HideUnrelatedOptions(ClangNvlinkWrapperCategory);
+  cl::ParseCommandLineOptions(
+      argc, argv,
+      "A wrapper tool over the nvlink program. It transparently passes every\n"
+      "input option and object file to nvlink except archive files and the\n"
+      "path of the nvlink binary. It reads each input archive file to\n"
+      "extract archived cubin files as temporary files.\n");
 
   if (Help) {
     cl::PrintHelpMessage();
@@ -132,21 +159,24 @@
     exit(1);
   };
 
-  ErrorOr<std::string> NvlinkPath = sys::findProgramByName("nvlink");
+  // Use the nvlink given via --nvlink-path, or search PATH for "nvlink". A
+  // name without a directory separator is itself looked up in PATH, so a bare
+  // "nvlink" passed by a driver that could not locate it fails here with a
+  // clear diagnostic instead of at execution time.
+  StringRef NvlinkName =
+      NvlinkUserPath.empty() ? StringRef("nvlink") : StringRef(NvlinkUserPath);
+  ErrorOr<std::string> NvlinkPath = sys::findProgramByName(NvlinkName);
   if (!NvlinkPath) {
     reportError(createStringError(NvlinkPath.getError(),
                                   "unable to find 'nvlink' in path"));
   }
 
-  SmallVector<const char *, 0> Argv(argv, argv + argc);
   SmallVector<std::string, 0> ArgvSubst;
   SmallVector<std::string, 0> TmpFiles;
-  BumpPtrAllocator Alloc;
-  StringSaver Saver(Alloc);
-  cl::ExpandResponseFiles(Saver, cl::TokenizeGNUCommandLine, Argv);
 
-  for (size_t i = 1; i < Argv.size(); ++i) {
-    std::string Arg = Argv[i];
+  // ParseCommandLineOptions already expanded response files and consumed this
+  // tool's own options, so NVArgs holds only the arguments meant for nvlink.
+  for (const std::string &Arg : NVArgs) {
     if (sys::path::extension(Arg) == ".a") {
       if (Error Err = extractArchiveFiles(Arg, ArgvSubst, TmpFiles))
         reportError(std::move(Err));