// Patch summary: the clang driver now tells clang-nvlink-wrapper which
// directory contains nvlink, and the wrapper only searches for nvlink by name
// when no such directory was supplied.
//
// clang/lib/Driver/ToolChains/Cuda.cpp
//   - Takes the parent directory of getToolChain().GetProgramPath("nvlink")
//     and appends it to CmdArgs as one "--path=<dir>" argument, just before
//     the clang-nvlink-wrapper executable path is resolved and the command is
//     added.
//
// clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp
//   - Adds the ClangNvlinkWrapperCategory option category and a "path" option
//     (NvlinkUserPath) registered under that category.
//   - main() no longer looks nvlink up before scanning its arguments. Each
//     argument is classified in this order:
//       1. extension is ".a"      -> expanded via extractArchiveFiles();
//       2. starts with "--path="  -> (compared case-insensitively) NvlinkPath
//                                    becomes the text after the prefix plus
//                                    "/nvlink"; the argument is not forwarded;
//       3. anything else          -> appended to ArgvSubst unchanged.
//   - Only if NvlinkPath is still empty after the scan does main() fall back
//     to sys::findProgramByName("nvlink"), reporting
//     "unable to find 'nvlink' in path" when that lookup fails.
//
// NOTE(review): the ".a" test runs before the "--path=" test, so a directory
//   whose name ends in ".a" (e.g. --path=/opt/cuda.a) is handed to
//   extractArchiveFiles() instead of being recorded as the nvlink location.
// NOTE(review): a bare "--path=" produces "/nvlink", which is non-empty and
//   therefore skips the findProgramByName fallback. This looks reachable if
//   GetProgramPath("nvlink") can return a name with no directory component;
//   confirm what it returns when nvlink is not found.
// NOTE(review): substr(7) hard-codes the length of "--path=", and the prefix
//   match is case-insensitive, so "--PATH=" is consumed as well; confirm both
//   are intended.
// NOTE(review): within the visible hunks main() reads "--path=" straight from
//   argv and NvlinkUserPath is never referenced; check whether code outside
//   these hunks uses it.
// NOTE(review): several template argument lists appear to be missing from the
//   text below (e.g. "cl::opt Help", "std::make_unique(", "SmallVector Argv").
//   This looks like an extraction artifact; compare against the original
//   patch before applying.
diff --git a/clang/lib/Driver/ToolChains/Cuda.cpp b/clang/lib/Driver/ToolChains/Cuda.cpp --- a/clang/lib/Driver/ToolChains/Cuda.cpp +++ b/clang/lib/Driver/ToolChains/Cuda.cpp @@ -613,6 +613,13 @@ AddStaticDeviceLibsLinking(C, *this, JA, Inputs, Args, CmdArgs, "nvptx", GPUArch, false, false); + // Find nvlink and pass it as "--path=" argument of clang-nvlink-wrapper. + auto NvlinkDir = + llvm::sys::path::parent_path(getToolChain().GetProgramPath("nvlink")) + .str(); + const char *NvlinkPath = Args.MakeArgString(Twine("--path=" + NvlinkDir)); + CmdArgs.push_back(NvlinkPath); + const char *Exec = Args.MakeArgString(getToolChain().GetProgramPath("clang-nvlink-wrapper")); C.addCommand(std::make_unique( diff --git a/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp --- a/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp +++ b/clang/tools/clang-nvlink-wrapper/ClangNvlinkWrapper.cpp @@ -41,6 +41,15 @@ static cl::opt Help("h", cl::desc("Alias for -help"), cl::Hidden); +// Mark all our options with this category, everything else (except for -help) +// will be hidden. 
+static cl::OptionCategory + ClangNvlinkWrapperCategory("clang-nvlink-wrapper options"); + +static cl::opt + NvlinkUserPath("path", cl::desc("path of directory containing nvlink"), + cl::cat(ClangNvlinkWrapperCategory)); + static Error runNVLink(std::string NVLinkPath, SmallVectorImpl &Args) { std::vector NVLArgs; @@ -121,7 +130,6 @@ int main(int argc, const char **argv) { sys::PrintStackTraceOnErrorSignal(argv[0]); - if (Help) { cl::PrintHelpMessage(); return 0; @@ -132,12 +140,7 @@ exit(1); }; - ErrorOr NvlinkPath = sys::findProgramByName("nvlink"); - if (!NvlinkPath) { - reportError(createStringError(NvlinkPath.getError(), - "unable to find 'nvlink' in path")); - } - + std::string NvlinkPath; SmallVector Argv(argv, argv + argc); SmallVector ArgvSubst; SmallVector TmpFiles; @@ -147,15 +150,28 @@ for (size_t i = 1; i < Argv.size(); ++i) { std::string Arg = Argv[i]; + StringRef ArgRef(Arg); + auto NvlPath = ArgRef.startswith_insensitive("--path="); if (sys::path::extension(Arg) == ".a") { if (Error Err = extractArchiveFiles(Arg, ArgvSubst, TmpFiles)) reportError(std::move(Err)); + } else if (NvlPath) { + NvlinkPath = ArgRef.substr(7).str().append("/nvlink"); } else { ArgvSubst.push_back(Arg); } } - if (Error Err = runNVLink(NvlinkPath.get(), ArgvSubst)) + if (NvlinkPath.empty()) { + ErrorOr NvlinkPathErr = sys::findProgramByName("nvlink"); + if (!NvlinkPathErr) { + reportError(createStringError(NvlinkPathErr.getError(), + "unable to find 'nvlink' in path")); + } + NvlinkPath = NvlinkPathErr.get(); + } + + if (Error Err = runNVLink(NvlinkPath, ArgvSubst)) reportError(std::move(Err)); if (Error Err = cleanupTmpFiles(TmpFiles)) reportError(std::move(Err));