diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -8,6 +8,8 @@ if (MLIR_ENABLE_ROCM_CONVERSIONS) set(AMDGPU_LIBS + IRReader + linker MCParser AMDGPUAsmParser AMDGPUCodeGen diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp --- a/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToHsaco.cpp @@ -21,6 +21,12 @@ #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" + #include "llvm/MC/MCAsmBackend.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCCodeEmitter.h" @@ -42,6 +48,8 @@ #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/IPO/Internalize.h" + #include "lld/Common/Driver.h" #include @@ -69,6 +77,10 @@ Option rocmPath{*this, "rocm-path", llvm::cl::desc("Path to ROCm install")}; + // Overload to allow linking in device libs + std::unique_ptr + translateToLLVMIR(llvm::LLVMContext &llvmContext) override; + /// Adds LLVM optimization passes LogicalResult optimizeLlvm(llvm::Module &llvmModule, llvm::TargetMachine &targetMachine) override; @@ -76,6 +88,12 @@ private: void getDependentDialects(DialectRegistry ®istry) const override; + // Loads LLVM bitcode libraries + Optional, 3>> + loadLibraries(SmallVectorImpl &path, + SmallVectorImpl &libraries, + llvm::LLVMContext &context); + // Serializes ROCDL to HSACO. std::unique_ptr> serializeISA(const std::string &isa) override; @@ -123,6 +141,175 @@ gpu::SerializeToBlobPass::getDependentDialects(registry); } +Optional, 3>> +SerializeToHsacoPass::loadLibraries(SmallVectorImpl &path, + SmallVectorImpl &libraries, + llvm::LLVMContext &context) { + SmallVector, 3> ret; + size_t dirLength = path.size(); + + if (!llvm::sys::fs::is_directory(path)) { + getOperation().emitRemark() << "Bitcode path: " << path + << " does not exist or is not a directory\n"; + return llvm::None; + } + + for (const StringRef file : libraries) { + llvm::SMDiagnostic error; + llvm::sys::path::append(path, file); + llvm::StringRef pathRef(path.data(), path.size()); + std::unique_ptr library = + llvm::getLazyIRFileModule(pathRef, error, context); + path.set_size(dirLength); + if (!library) { + getOperation().emitError() << "Failed to load library " << file + << " from " << path << error.getMessage(); + return llvm::None; + } + // Some ROCM builds don't strip this like they should + if (auto *openclVersion = library->getNamedMetadata("opencl.ocl.version")) + library->eraseNamedMetadata(openclVersion); + // Stop spamming us with clang version numbers + if (auto *ident = library->getNamedMetadata("llvm.ident")) + library->eraseNamedMetadata(ident); + ret.push_back(std::move(library)); + } + + return ret; +} + +std::unique_ptr +SerializeToHsacoPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) { + // MLIR -> LLVM translation + std::unique_ptr ret = + gpu::SerializeToBlobPass::translateToLLVMIR(llvmContext); + + if (!ret) { + getOperation().emitOpError("Module lowering failed"); + return ret; + } + // Walk the LLVM module in order to determine if we need to link in device + // libs + bool needOpenCl = false; + bool needOckl = false; + bool needOcml = false; + for (llvm::Function &f : ret->functions()) { + if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) { + StringRef funcName = f.getName(); + if ("printf" == funcName) + needOpenCl = true; + if (funcName.startswith("__ockl_")) + needOckl = true; + if (funcName.startswith("__ocml_")) + needOcml = true; + } + } + + if (needOpenCl) + needOcml = needOckl = true; + + // No libraries needed (the typical case) + if (!(needOpenCl || needOcml || needOckl)) + return ret; + + // Define one of the control constants the ROCm device libraries expect to be + // present These constants can either be defined in the module or can be + // imported by linking in bitcode that defines the constant. To simplify our + // logic, we define the constants into the module we are compiling + auto addControlConstant = [&module = *ret](StringRef name, uint32_t value, + uint32_t bitwidth) { + using llvm::GlobalVariable; + if (module.getNamedGlobal(name)) { + return; + } + llvm::IntegerType *type = + llvm::IntegerType::getIntNTy(module.getContext(), bitwidth); + auto *initializer = llvm::ConstantInt::get(type, value, /*isSigned=*/false); + auto *constant = new GlobalVariable( + module, type, + /*isConstant=*/true, GlobalVariable::LinkageTypes::LinkOnceODRLinkage, + initializer, name, + /*before=*/nullptr, + /*threadLocalMode=*/GlobalVariable::ThreadLocalMode::NotThreadLocal, + /*addressSpace=*/4); + constant->setUnnamedAddr(GlobalVariable::UnnamedAddr::Local); + constant->setVisibility( + GlobalVariable::VisibilityTypes::ProtectedVisibility); + constant->setAlignment(llvm::MaybeAlign(bitwidth / 8)); + }; + + if (needOcml) { + // TODO(kdrewnia): Enable math optimizations once we have support for + // `-ffast-math`-like options + addControlConstant("__oclc_finite_only_opt", 0, 8); + addControlConstant("__oclc_daz_opt", 0, 8); + addControlConstant("__oclc_correctly_rounded_sqrt32", 1, 8); + addControlConstant("__oclc_unsafe_math_opt", 0, 8); + } + if (needOcml || needOckl) { + addControlConstant("__oclc_wavefrontsize64", 1, 8); + StringRef chipSet = this->chip.getValue(); + if (chipSet.startswith("gfx")) { + chipSet = chipSet.substr(3); + } + uint32_t minor = + llvm::APInt(32, chipSet.substr(chipSet.size() - 2), 16).getZExtValue(); + uint32_t major = llvm::APInt(32, chipSet.substr(0, chipSet.size() - 2), 10) + .getZExtValue(); + uint32_t isaNumber = minor + 1000 * major; + addControlConstant("__oclc_ISA_version", isaNumber, 32); + } + + // Determine libraries we need to link - order matters due to dependencies + llvm::SmallVector libraries; + if (needOpenCl) + libraries.push_back("opencl.bc"); + if (needOcml) + libraries.push_back("ocml.bc"); + if (needOckl) + libraries.push_back("ockl.bc"); + + Optional, 3>> mbModules; + std::string theRocmPath = getRocmPath(); + llvm::SmallString<32> bitcodePath(std::move(theRocmPath)); + llvm::sys::path::append(bitcodePath, "amdgcn", "bitcode"); + mbModules = loadLibraries(bitcodePath, libraries, llvmContext); + + if (!mbModules) { + getOperation() + .emitWarning("Could not load required device labraries") + .attachNote() + << "This will probably cause link-time or run-time failures"; + return ret; // We can still abort here + } + + llvm::Linker linker(*ret); + for (std::unique_ptr &libModule : mbModules.getValue()) { + // This bitcode linking code is substantially similar to what is used in + // hip-clang It imports the library functions into the module, allowing LLVM + // optimization passes (which must run after linking) to optimize across the + // libraries and the module's code. We also only import symbols if they are + // referenced by the module or a previous library since there will be no + // other source of references to those symbols in this compilation and since + // we don't want to bloat the resulting code object. + bool err = linker.linkInModule( + std::move(libModule), llvm::Linker::Flags::LinkOnlyNeeded, + [](llvm::Module &m, const StringSet<> &gvs) { + llvm::internalizeModule(m, [&gvs](const llvm::GlobalValue &gv) { + return !gv.hasName() || (gvs.count(gv.getName()) == 0); + }); + }); + // True is linker failure + if (err) { + getOperation().emitError( + "Unrecoverable failure during device library linking."); + // We have no guaranties about the state of `ret`, so bail + return nullptr; + } + } + return ret; +} + LogicalResult SerializeToHsacoPass::optimizeLlvm(llvm::Module &llvmModule, llvm::TargetMachine &targetMachine) {