Index: clang/tools/clang-offload-bundler/ClangOffloadBundler.cpp
===================================================================
--- clang/tools/clang-offload-bundler/ClangOffloadBundler.cpp
+++ clang/tools/clang-offload-bundler/ClangOffloadBundler.cpp
@@ -22,6 +22,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/Triple.h"
+#include "llvm/Object/Archive.h"
+#include "llvm/Object/ArchiveWriter.h"
 #include "llvm/Object/Binary.h"
 #include "llvm/Object/ObjectFile.h"
 #include "llvm/Support/Casting.h"
@@ -29,6 +31,7 @@
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Support/Host.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/Path.h"
@@ -80,6 +83,7 @@
                        "  bc  - llvm-bc\n"
                        "  s   - assembler\n"
                        "  o   - object\n"
+                       "  a   - archive of objects\n"
                        "  gch - precompiled-header\n"
                        "  ast - clang AST file"),
               cl::cat(ClangOffloadBundlerCategory));
@@ -118,6 +122,22 @@
   return OffloadKind == "host";
 }
 
+/// Returns the triple that getOffloadKindAndTriple() extracts from \p Target.
+static StringRef getTriple(StringRef Target) {
+  StringRef OffloadKind;
+  StringRef Triple;
+  getOffloadKindAndTriple(Target, OffloadKind, Triple);
+  return Triple;
+}
+
+/// Returns the part of \p Triple after its last '-', or \p Triple itself if
+/// it contains no '-'.
+static StringRef getDevice(StringRef Triple) {
+  if (!Triple.contains('-'))
+    return Triple;
+  return Triple.rsplit('-').second;
+}
+
 /// Generic file handler interface.
 class FileHandler {
 public:
@@ -139,7 +159,7 @@
   virtual Error ReadBundleEnd(MemoryBuffer &Input) = 0;
 
   /// Read the current bundle and write the result into the stream \a OS.
-  virtual Error ReadBundle(raw_fd_ostream &OS, MemoryBuffer &Input) = 0;
+  virtual Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;
 
   /// Write the header of the bundled file to \a OS based on the information
   /// gathered from \a Inputs.
@@ -308,7 +328,7 @@
     return Error::success();
   }
 
-  Error ReadBundle(raw_fd_ostream &OS, MemoryBuffer &Input) final {
+  Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
     assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
     StringRef FC = Input.getBuffer();
     OS.write(FC.data() + CurBundleInfo->second.Offset,
@@ -466,7 +486,7 @@
 
   Error ReadBundleEnd(MemoryBuffer &Input) final { return Error::success(); }
 
-  Error ReadBundle(raw_fd_ostream &OS, MemoryBuffer &Input) final {
+  Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
     Expected<StringRef> ContentOrErr = CurrentSection->getContents();
     if (!ContentOrErr)
       return ContentOrErr.takeError();
@@ -660,7 +680,7 @@
     return Error::success();
   }
 
-  Error ReadBundle(raw_fd_ostream &OS, MemoryBuffer &Input) final {
+  Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
     StringRef FC = Input.getBuffer();
     size_t BundleStart = ReadChars;
 
@@ -743,6 +763,8 @@
     return std::make_unique<TextFileHandler>(/*Comment=*/"#");
   if (FilesType == "o")
     return CreateObjectFileHandler(FirstInput);
+  if (FilesType == "a")
+    return CreateObjectFileHandler(FirstInput);
   if (FilesType == "gch")
     return std::make_unique<BinaryFileHandler>();
   if (FilesType == "ast")
@@ -902,6 +924,152 @@
   return Error::success();
 }
 
+/// Returns the archive format for the unbundled device archive: Darwin-style
+/// if the default target triple is a Darwin OS, GNU-style otherwise.
+static Archive::Kind getDefaultArchiveKindForHost() {
+  return Triple(sys::getDefaultTargetTriple()).isOSDarwin()
+             ? Archive::K_DARWIN
+             : Archive::K_GNU;
+}
+
+/// Returns the extension given to device code extracted for \p Device:
+/// ".bc" if \p Device contains "gfx", ".cubin" otherwise.
+static StringRef getDeviceFileExtension(StringRef Device) {
+  if (Device.contains("gfx"))
+    return ".bc";
+  return ".cubin";
+}
+
+/// Returns \p FileName without its last extension (everything before the
+/// last '.'), or \p FileName unchanged if it contains no '.'.
+static StringRef removeExtension(StringRef FileName) {
+  // rsplit() returns (FileName, "") when the separator is absent.
+  return FileName.rsplit('.').first;
+}
+
+/// Returns the archive member name for the code extracted for \p Device from
+/// the input member \p BundleFileName: its stem plus a device extension.
+static std::string getDeviceLibraryFileName(StringRef BundleFileName,
+                                            StringRef Device) {
+  std::string Result;
+  Result += removeExtension(BundleFileName);
+  Result += getDeviceFileExtension(Device);
+  return Result;
+}
+
+/// Returns true if \p OffloadArch is non-empty and equal to \p Device.
+static bool checkDeviceOptions(StringRef Device, StringRef OffloadArch) {
+  return !OffloadArch.empty() && OffloadArch == Device;
+}
+
+/// Unbundles the archive named by the single input file. For every member a
+/// file handler is created and each device bundle whose device matches the
+/// single requested target is extracted; the extracted pieces become the
+/// members of a new archive written to the single output file. Host bundles
+/// and members that contain no bundle are skipped.
+static Error UnbundleArchive() {
+  // Owns the input archive buffer and every extracted buffer; the
+  // NewArchiveMembers below only reference these buffers.
+  std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
+  std::string OffloadArch = getDevice(TargetNames.front()).str();
+  std::vector<NewArchiveMember> ArchiveMembers;
+
+  StringRef IFName = InputFileNames.front();
+  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
+      MemoryBuffer::getFileOrSTDIN(IFName, -1, false);
+  if (std::error_code EC = BufOrErr.getError())
+    return createFileError(InputFileNames.front(), EC);
+
+  ArchiveBuffers.push_back(std::move(*BufOrErr));
+  Expected<std::unique_ptr<Archive>> LibOrErr =
+      Archive::create(ArchiveBuffers.back()->getMemBufferRef());
+  if (!LibOrErr)
+    return LibOrErr.takeError();
+
+  std::unique_ptr<Archive> Lib = std::move(*LibOrErr);
+
+  // A failure while advancing to the next member ends the loop and is stored
+  // in ArchiveErr, so it must be checked after the loop.
+  Error ArchiveErr = Error::success();
+  for (const Archive::Child &Child : Lib->children(ArchiveErr)) {
+    Expected<StringRef> ChildNameOrErr = Child.getName();
+    if (!ChildNameOrErr)
+      return ChildNameOrErr.takeError();
+
+    StringRef ChildName = sys::path::filename(*ChildNameOrErr);
+
+    Expected<MemoryBufferRef> ChildBufferRefOrErr = Child.getMemoryBufferRef();
+    if (!ChildBufferRefOrErr)
+      return ChildBufferRefOrErr.takeError();
+
+    std::unique_ptr<MemoryBuffer> ChildBuffer =
+        MemoryBuffer::getMemBuffer(*ChildBufferRefOrErr, false);
+
+    Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
+        CreateFileHandler(*ChildBuffer);
+    if (!FileHandlerOrErr)
+      return FileHandlerOrErr.takeError();
+
+    std::unique_ptr<FileHandler> &Handler = *FileHandlerOrErr;
+    assert(Handler);
+
+    if (Error ReadErr = Handler->ReadHeader(*ChildBuffer))
+      return ReadErr;
+
+    Expected<Optional<StringRef>> CurTripleOrErr =
+        Handler->ReadBundleStart(*ChildBuffer);
+    if (!CurTripleOrErr)
+      return CurTripleOrErr.takeError();
+
+    // ReadBundleStart() yields None when the member holds no bundle at all;
+    // there is nothing to extract then, and None must not be dereferenced.
+    if (!*CurTripleOrErr)
+      continue;
+    StringRef CurKindTriple = **CurTripleOrErr;
+
+    while (!CurKindTriple.empty()) {
+      if (hasHostKind(CurKindTriple)) {
+        // Do nothing, we don't extract host code yet
+      } else if (checkDeviceOptions(getDevice(getTriple(CurKindTriple)),
+                                    OffloadArch)) {
+        std::string BundleData;
+        raw_string_ostream DataStream(BundleData);
+        if (Error Err = Handler->ReadBundle(DataStream, *ChildBuffer))
+          return Err;
+
+        // getMemBufferCopy() copies the buffer name into the new buffer, so
+        // a local string suffices for the member name.
+        std::string LibraryName =
+            getDeviceLibraryFileName(ChildName, OffloadArch);
+        ArchiveBuffers.push_back(
+            MemoryBuffer::getMemBufferCopy(DataStream.str(), LibraryName));
+        ArchiveMembers.push_back(
+            NewArchiveMember(ArchiveBuffers.back()->getMemBufferRef()));
+      }
+      if (Error Err = Handler->ReadBundleEnd(*ChildBuffer))
+        return Err;
+
+      Expected<Optional<StringRef>> NextTripleOrErr =
+          Handler->ReadBundleStart(*ChildBuffer);
+      if (!NextTripleOrErr)
+        return NextTripleOrErr.takeError();
+
+      CurKindTriple = (*NextTripleOrErr).hasValue() ? **NextTripleOrErr : "";
+    }
+  }
+  // Report a malformed archive instead of writing a truncated result.
+  if (ArchiveErr)
+    return ArchiveErr;
+
+  std::string FileName = OutputFileNames.front();
+  if (Error WriteErr =
+          writeArchive(FileName, ArchiveMembers, true,
+                       getDefaultArchiveKindForHost(), true, false, nullptr))
+    return WriteErr;
+
+  return Error::success();
+}
+
 static void PrintVersion(raw_ostream &OS) {
   OS << clang::getClangToolFullVersion("clang-offload-bundler") << '\n';
 }
@@ -935,7 +1103,13 @@
           errc::invalid_argument,
           "only one input file supported in unbundling mode"));
     }
-    if (OutputFileNames.size() != TargetNames.size()) {
+    if (FilesType == "a" && (OutputFileNames.size() != 1 ||
+                             TargetNames.size() != 1)) {
+      Error = true;
+      reportError(createStringError(errc::invalid_argument,
+                                    "number of output files and targets should "
+                                    "be 1 when unbundling an archive"));
+    } else if (OutputFileNames.size() != TargetNames.size()) {
       Error = true;
       reportError(createStringError(errc::invalid_argument,
                                     "number of output files and targets should "
@@ -1014,7 +1188,11 @@
   // tools.
   BundlerExecutable = sys::fs::getMainExecutable(argv[0], &BundlerExecutable);
 
-  if (llvm::Error Err = Unbundle ? UnbundleFiles() : BundleFiles()) {
+  llvm::Error Err = (Unbundle && FilesType == "a") ? UnbundleArchive()
+                    : (Unbundle)                   ? UnbundleFiles()
+                                                   : BundleFiles();
+
+  if (Err) {
     reportError(std::move(Err));
     return 1;
   }