diff --git a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
--- a/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
+++ b/clang/tools/clang-linker-wrapper/ClangLinkerWrapper.cpp
@@ -1479,7 +1479,17 @@
   auto FileOrErr = wrapDeviceImages(LinkedImages);
   if (!FileOrErr)
     return reportError(FileOrErr.takeError());
-  LinkerArgs.append(*FileOrErr);
+
+  // We need to insert the new files next to the old ones to make sure they're
+  // linked with the same libraries / arguments. If no existing input file is
+  // found, append them instead; advancing the end iterator is not valid.
+  auto FirstInput = llvm::find_if(LinkerArgs, [](StringRef Str) {
+    return sys::fs::exists(Str) && !sys::fs::is_directory(Str) &&
+           Str != ExecutableName;
+  });
+  if (FirstInput != LinkerArgs.end())
+    ++FirstInput;
+  LinkerArgs.insert(FirstInput, FileOrErr->begin(), FileOrErr->end());
 
   // Run the host linking job.
   if (Error Err = runLinker(LinkerUserPath, LinkerArgs))
diff --git a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
--- a/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
+++ b/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp
@@ -20,6 +20,8 @@
 using namespace llvm;
 
 namespace {
+/// Magic number that begins the section containing the CUDA fatbinary.
+constexpr unsigned CudaFatMagic = 0x466243b1;
 
 IntegerType *getSizeTTy(Module &M) {
   LLVMContext &C = M.getContext();
@@ -255,6 +257,275 @@
   appendToGlobalDtors(M, Func, /*Priority*/ 1);
 }
 
+// struct fatbin_wrapper {
+//   int32_t magic;
+//   int32_t version;
+//   void *image;
+//   void *reserved;
+//};
+StructType *getFatbinWrapperTy(Module &M) {
+  LLVMContext &C = M.getContext();
+  StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
+  if (!FatbinTy)
+    FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
+                                  Type::getInt32Ty(C), Type::getInt8PtrTy(C),
+                                  Type::getInt8PtrTy(C));
+  return FatbinTy;
+}
+
+/// Embed the image \p Image into the module \p M so it can be found by the
+/// runtime.
+GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image) {
+  LLVMContext &C = M.getContext();
+  llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
+  llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
+
+  // Create the global string containing the fatbinary.
+  StringRef FatbinConstantSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin";
+  auto *Data = ConstantDataArray::get(C, Image);
+  auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
+                                    GlobalVariable::InternalLinkage, Data,
+                                    ".fatbin_image");
+  Fatbin->setSection(FatbinConstantSection);
+
+  // Create the fatbinary wrapper
+  StringRef FatbinWrapperSection =
+      Triple.isMacOSX() ? "__NV_CUDA,__fatbin" : ".nvFatBinSegment";
+  Constant *FatbinWrapper[] = {
+      ConstantInt::get(Type::getInt32Ty(C), CudaFatMagic),
+      ConstantInt::get(Type::getInt32Ty(C), 1),
+      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
+      ConstantPointerNull::get(Type::getInt8PtrTy(C))};
+
+  Constant *FatbinInitializer =
+      ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
+
+  auto *FatbinDesc =
+      new GlobalVariable(M, getFatbinWrapperTy(M),
+                         /*isConstant*/ true, GlobalValue::InternalLinkage,
+                         FatbinInitializer, ".fatbin_wrapper");
+  FatbinDesc->setSection(FatbinWrapperSection);
+  FatbinDesc->setAlignment(Align(8));
+
+  // We create a zero-sized dummy entry to ensure the linker will define the
+  // begin / end symbols even if there are no other entries. Being zero-sized,
+  // it is not expected to add an entry that gets registered with the runtime.
+  auto *DummyInit =
+      ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
+  auto *DummyEntry = new GlobalVariable(
+      M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
+      "__dummy.cuda_offloading.entry");
+  DummyEntry->setSection("cuda_offloading_entries");
+  DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
+
+  return FatbinDesc;
+}
+
+/// Create the register globals function. We will iterate all of the offloading
+/// entries stored at the begin / end symbols and register them according to
+/// their type. This creates the following function in IR:
+///
+/// extern struct __tgt_offload_entry __start_cuda_offloading_entries[];
+/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries[];
+///
+/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
+///                                    void *, void *, void *, void *, int *);
+/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
+///                               int64_t, int32_t, int32_t);
+///
+/// void __cudaRegisterTest(void **fatbinHandle) {
+///   for (struct __tgt_offload_entry *entry = __start_cuda_offloading_entries;
+///        entry != __stop_cuda_offloading_entries; ++entry) {
+///     if (!entry->size)
+///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
+///                              entry->name, -1, 0, 0, 0, 0, 0);
+///     else
+///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
+///                         0, entry->size, 0, 0);
+///   }
+/// }
+///
+/// TODO: This only registers functions and variables. Additional support is
+/// required for texture / surface / managed variables.
+Function *createRegisterGlobalsFunction(Module &M) {
+  LLVMContext &C = M.getContext();
+  // Get the __cudaRegisterFunction function declaration.
+  auto *RegFuncTy = FunctionType::get(
+      Type::getInt32Ty(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegFunc =
+      M.getOrInsertFunction("__cudaRegisterFunction", RegFuncTy);
+
+  // Get the __cudaRegisterVar function declaration.
+  auto *RegVarTy = FunctionType::get(
+      Type::getVoidTy(C),
+      {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
+       Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
+       getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
+      /*isVarArg*/ false);
+  FunctionCallee RegVar = M.getOrInsertFunction("__cudaRegisterVar", RegVarTy);
+
+  // Create the references to the start / stop symbols defined by the linker.
+  // They are declared as zero-sized arrays so LLVM cannot assume they are
+  // distinct objects and constant-fold the comparisons between them.
+  auto *EntriesArrayTy = ArrayType::get(getEntryTy(M), 0u);
+  auto *EntriesB = new GlobalVariable(
+      M, EntriesArrayTy, /*isConstant*/ true, GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__start_cuda_offloading_entries");
+  EntriesB->setVisibility(GlobalValue::HiddenVisibility);
+  auto *EntriesE = new GlobalVariable(
+      M, EntriesArrayTy, /*isConstant*/ true, GlobalValue::ExternalLinkage,
+      /*Initializer*/ nullptr, "__stop_cuda_offloading_entries");
+  EntriesE->setVisibility(GlobalValue::HiddenVisibility);
+  Constant *EntriesBegin = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+      EntriesB, getEntryPtrTy(M));
+  Constant *EntriesEnd = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+      EntriesE, getEntryPtrTy(M));
+
+  auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
+                                         Type::getInt8PtrTy(C)->getPointerTo(),
+                                         /*isVarArg*/ false);
+  auto *RegGlobalsFn = Function::Create(
+      RegGlobalsTy, GlobalValue::InternalLinkage, ".cuda.globals_reg", &M);
+  RegGlobalsFn->setSection(".text.startup");
+
+  // Create the loop to register all the entries.
+  IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
+  auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
+  auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
+  auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
+  auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
+  auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
+
+  // Skip the loop if the section is empty; the loop below would otherwise read
+  // past the end of the section and never reach the stop symbol.
+  auto *HasEntries = Builder.CreateICmpNE(EntriesBegin, EntriesEnd);
+  Builder.CreateCondBr(HasEntries, EntryBB, ExitBB);
+  Builder.SetInsertPoint(EntryBB);
+  auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
+  auto *AddrPtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 0)});
+  auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
+  auto *NamePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 1)});
+  auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
+  auto *SizePtr =
+      Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
+                                {ConstantInt::get(getSizeTTy(M), 0),
+                                 ConstantInt::get(Type::getInt32Ty(C), 2)});
+  auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
+  auto *FnCond =
+      Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
+  Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
+  Builder.SetInsertPoint(IfThenBB);
+  Builder.CreateCall(RegFunc,
+                     {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                      ConstantInt::get(Type::getInt32Ty(C), -1),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt8PtrTy(C)),
+                      ConstantPointerNull::get(Type::getInt32PtrTy(C))});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfElseBB);
+  Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
+                              ConstantInt::get(Type::getInt32Ty(C), 0), Size,
+                              ConstantInt::get(Type::getInt32Ty(C), 0),
+                              ConstantInt::get(Type::getInt32Ty(C), 0)});
+  Builder.CreateBr(IfEndBB);
+  Builder.SetInsertPoint(IfEndBB);
+  auto *NewEntry = Builder.CreateInBoundsGEP(
+      getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
+  auto *Cmp = Builder.CreateICmpEQ(NewEntry, EntriesEnd);
+  Entry->addIncoming(EntriesBegin, &RegGlobalsFn->getEntryBlock());
+  Entry->addIncoming(NewEntry, IfEndBB);
+  Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
+  Builder.SetInsertPoint(ExitBB);
+  Builder.CreateRetVoid();
+
+  return RegGlobalsFn;
+}
+
+// Create the constructor and destructor to register the fatbinary with the CUDA
+// runtime.
+void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc) {
+  LLVMContext &C = M.getContext();
+  auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *CtorFunc = Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_reg", &M);
+  CtorFunc->setSection(".text.startup");
+
+  auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
+  auto *DtorFunc = Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
+                                    ".cuda.fatbin_unreg", &M);
+  DtorFunc->setSection(".text.startup");
+
+  // Get the __cudaRegisterFatBinary function declaration.
+  auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
+                                     Type::getInt8PtrTy(C),
+                                     /*isVarArg*/ false);
+  FunctionCallee RegFatbin =
+      M.getOrInsertFunction("__cudaRegisterFatBinary", RegFatTy);
+  // Get the __cudaRegisterFatBinaryEnd function declaration.
+  auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
+                                        Type::getInt8PtrTy(C)->getPointerTo(),
+                                        /*isVarArg*/ false);
+  FunctionCallee RegFatbinEnd =
+      M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
+  // Get the __cudaUnregisterFatBinary function declaration.
+  auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
+                                       Type::getInt8PtrTy(C)->getPointerTo(),
+                                       /*isVarArg*/ false);
+  FunctionCallee UnregFatbin =
+      M.getOrInsertFunction("__cudaUnregisterFatBinary", UnregFatTy);
+
+  auto *AtExitTy =
+      FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
+                        /*isVarArg*/ false);
+  FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
+
+  auto *BinaryHandleGlobal = new llvm::GlobalVariable(
+      M, Type::getInt8PtrTy(C)->getPointerTo(), false,
+      llvm::GlobalValue::InternalLinkage,
+      llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
+      ".cuda.binary_handle");
+
+  // Create the constructor to register this image with the runtime.
+  IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
+  CallInst *Handle = CtorBuilder.CreateCall(
+      RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+                     FatbinDesc, Type::getInt8PtrTy(C)));
+  CtorBuilder.CreateAlignedStore(
+      Handle, BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  CtorBuilder.CreateCall(createRegisterGlobalsFunction(M), Handle);
+  CtorBuilder.CreateCall(RegFatbinEnd, Handle);
+  CtorBuilder.CreateCall(AtExit, DtorFunc);
+  CtorBuilder.CreateRetVoid();
+
+  // Create the destructor to unregister the image with the runtime. We cannot
+  // use a standard global destructor after CUDA 9.2 so this must be called by
+  // `atexit()` instead.
+  IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
+  LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
+      Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
+      Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
+  DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
+  DtorBuilder.CreateRetVoid();
+
+  // Add this function to constructors.
+  appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
+}
+
 } // namespace
 
 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
@@ -267,7 +538,12 @@
   return Error::success();
 }
 
-llvm::Error wrapCudaBinary(llvm::Module &M, llvm::ArrayRef<char> Images) {
-  return createStringError(inconvertibleErrorCode(),
-                           "Cuda wrapping is not yet supported.");
+Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
+  GlobalVariable *Desc = createFatbinDesc(M, Image);
+  if (!Desc)
+    return createStringError(inconvertibleErrorCode(),
+                             "No fatbinary section created.");
+
+  createRegisterFatbinFunction(M, Desc);
+  return Error::success();
 }