diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -8,6 +8,8 @@
 
 if (MLIR_ENABLE_ROCM_CONVERSIONS)
   set(AMDGPU_LIBS
+    IRReader
+    linker
     MCParser
     AMDGPUAsmParser
     AMDGPUCodeGen
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp
@@ -19,6 +19,11 @@
 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
 
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Linker/Linker.h"
+
 #include "llvm/MC/MCAsmBackend.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCCodeEmitter.h"
@@ -138,6 +143,217 @@
   gpu::SerializeToBlobPass::getDependentDialects(registry);
 }
 
+/// Lazily loads each bitcode file named in `libraries` from the directory
+/// `path`.
+///
+/// `path` is used as scratch space: each file name is appended to it and the
+/// directory is restored afterwards. Returns `llvm::None`, after printing a
+/// message to `llvm::dbgs()`, if `path` is not a directory or any library
+/// fails to load.
+static Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>>
+loadLibraries(SmallVectorImpl<char> &path, ArrayRef<StringRef> libraries,
+              llvm::LLVMContext &context) {
+  SmallVector<std::unique_ptr<llvm::Module>, 3> ret;
+  size_t dirLength = path.size();
+
+  if (!llvm::sys::fs::is_directory(path)) {
+    llvm::dbgs() << "Bitcode path: " << path
+                 << " does not exist or is not a directory\n";
+    return llvm::None;
+  }
+
+  for (StringRef file : libraries) {
+    llvm::SMDiagnostic error;
+    llvm::sys::path::append(path, file);
+    llvm::StringRef pathRef(path.data(), path.size());
+    std::unique_ptr<llvm::Module> library =
+        llvm::getLazyIRFileModule(pathRef, error, context);
+    // Restore `path` to the bare directory before using it again.
+    path.resize(dirLength);
+    if (!library) {
+      llvm::dbgs() << "Failed to load library " << file << " from " << path
+                   << "\n";
+      error.print("[MLIR backend]", llvm::dbgs());
+      return llvm::None;
+    }
+    // Some ROCm builds don't strip this like they should.
+    if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version")) {
+      library->eraseNamedMetadata(openclVersion);
+    }
+    // Drop `llvm.ident` so the libraries' clang version strings are not
+    // merged into the output module.
+    if (auto *ident = library->getNamedMetadata("llvm.ident")) {
+      library->eraseNamedMetadata(ident);
+    }
+    ret.push_back(std::move(library));
+  }
+
+  return ret;
+}
+
+/// Translates the GPU module to LLVM IR and links in the ROCm device libraries
+/// (ocml, ockl and, when `printf` is called, opencl) that it depends on.
+///
+/// The libraries are taken from `$HIP_DEVICE_LIB_PATH` when that variable is
+/// set and no ROCm path was given explicitly, otherwise (or if that fails)
+/// from `<rocm path>/amdgcn/bitcode`. If they cannot be loaded, the unlinked
+/// module is returned after a warning; if linking fails, nullptr is returned.
+std::unique_ptr<llvm::Module>
+SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
+  // MLIR -> LLVM translation
+  std::unique_ptr<llvm::Module> ret =
+      gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext);
+
+  if (!ret) {
+    llvm::dbgs() << "Module creation failed\n";
+    return ret;
+  }
+  // Walk the LLVM module in order to determine if we need to link in device
+  // libs: look at external functions that lack an exact definition.
+  bool needOpenCl = false;
+  bool needOckl = false;
+  bool needOcml = false;
+  for (auto &f : ret->functions()) {
+    if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
+      StringRef funcName = f.getName();
+      if ("printf" == funcName) {
+        needOpenCl = true;
+      }
+      if (funcName.startswith("__ockl_")) {
+        needOckl = true;
+      }
+      if (funcName.startswith("__ocml_")) {
+        needOcml = true;
+      }
+    }
+  }
+
+  // opencl.bc is always linked together with ocml and ockl.
+  if (needOpenCl) {
+    needOcml = needOckl = true;
+  }
+
+  // No libraries needed (the typical case)
+  if (!(needOpenCl || needOcml || needOckl)) {
+    return ret;
+  }
+
+  // Adds a constant `linkonce_odr` global of the given bit width in address
+  // space 4, unless a global named `name` already exists in the module.
+  auto addControlConstant = [&module = *ret](StringRef name, uint32_t value,
+                                             uint32_t bitwidth) {
+    using llvm::GlobalVariable;
+    if (module.getNamedGlobal(name)) {
+      return;
+    }
+    llvm::IntegerType *type =
+        llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
+    auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false);
+    auto *constant = new GlobalVariable(
+        module, type,
+        /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage,
+        initializer, name,
+        /*before=*/nullptr,
+        /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal,
+        /*addressSpace=*/4);
+    constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local);
+    constant->setVisibility(
+        GlobalVariable::VisibilityTypes::ProtectedVisibility);
+    constant->setAlignment(llvm::MaybeAlign(bitwidth / 8));
+  };
+
+  // Set up control variables in the module instead of linking in tiny bitcode
+  if (needOcml) {
+    // TODO(kdrewnia): Enable math optimizations once we have support for
+    // `-ffast-math`-like options
+    addControlConstant("__oclc_finite_only_opt", 0, 8);
+    addControlConstant("__oclc_daz_opt", 0, 8);
+    addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8);
+    addControlConstant("__oclc_unsafe_math_opt", 0, 8);
+  }
+  if (needOcml || needOckl) {
+    addControlConstant("__oclc_wavefrontsize64", 1, 8);
+    // A chip name gfx<major><minor><stepping>, where minor is one decimal
+    // digit and stepping one hex digit (gfx908, gfx90a, gfx1030), maps to
+    // 1000 * major + 100 * minor + stepping (9008, 9010, 10300).
+    StringRef chipName = this->chip.getValue();
+    StringRef chipSet = chipName;
+    if (chipSet.startswith("gfx")) {
+      chipSet = chipSet.substr(3);
+    }
+    uint32_t major = 0;
+    uint32_t minor = 0;
+    uint32_t stepping = 0;
+    // getAsInteger returns true on failure; the length check keeps
+    // drop_back() from asserting on malformed chip names.
+    if (chipSet.size() < 3 || chipSet.drop_back(2).getAsInteger(10, major) ||
+        chipSet.take_back(2).take_front(1).getAsInteger(10, minor) ||
+        chipSet.take_back(1).getAsInteger(16, stepping)) {
+      llvm::WithColor::warning(llvm::errs())
+          << "could not determine the ISA version of chip '" << chipName
+          << "'\n";
+    } else {
+      addControlConstant("__oclc_ISA_version",
+                         1000 * major + 100 * minor + stepping, 32);
+    }
+  }
+
+  // Determine libraries we need to link
+  llvm::SmallVector<StringRef, 3> libraries;
+  if (needOpenCl) {
+    libraries.push_back("opencl.bc");
+  }
+  if (needOcml) {
+    libraries.push_back("ocml.bc");
+  }
+  if (needOckl) {
+    libraries.push_back("ockl.bc");
+  }
+
+  // Try the legacy override variable first so that, when it works, the
+  // default location is not also loaded only to be discarded.
+  Optional<SmallVector<std::unique_ptr<llvm::Module>, 3>> mbModules;
+  auto env = llvm::sys::Process::GetEnv("HIP_DEVICE_LIB_PATH");
+  if (env && (rocmPath.getNumOccurrences() == 0)) {
+    llvm::SmallString<32> overrideValue(env.getValue());
+    mbModules = loadLibraries(overrideValue, libraries, llvmContext);
+  }
+  if (!mbModules) {
+    auto theRocmPath = getRocmPath();
+    llvm::SmallString<32> bitcodePath(theRocmPath);
+    llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode");
+    mbModules = loadLibraries(bitcodePath, libraries, llvmContext);
+  }
+
+  if (!mbModules) {
+    llvm::WithColor::warning(llvm::errs())
+        << "could not load required device libraries\n";
+    llvm::WithColor::note(llvm::errs())
+        << "this will probably cause link-time or run-time failures\n";
+    // Deliberately not fatal: continue with the unlinked module.
+    return ret;
+  }
+
+  llvm::Linker linker(*ret);
+  for (auto &libModule : mbModules.getValue()) {
+    // Link only the definitions `ret` needs, then internalize the names the
+    // linker hands to the callback. linkInModule returns true on error.
+    bool err = linker.linkInModule(
+        std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded,
+        [](llvm::Module &m, const StringSet<> &gvs) {
+          llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) {
+            return !gv.hasName() || (gvs.count(gv.getName()) == 0);
+          });
+        });
+    if (err) {
+      llvm::errs() << "Error: Failure in library bitcode linking\n";
+      // We have no guarantees about the state of `ret`, so bail
+      return nullptr;
+    }
+  }
+  return ret;
+}
+
 LogicalResult
 SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule,
                                    llvm::TargetMachine &targetMachine) {