Make MLIR-to-LLVM translation and CUDA kernel compilation usable from
multiple threads, and give kernel-name globals module-qualified names.

What this patch changes (per the hunks below):
  * LLVMDialect.h: forward-declares llvm::sys::SmartMutex and declares
    LLVM::cloneModuleIntoNewContext(context, module), which copies a module
    into another LLVMContext by writing it to bitcode and parsing that
    bitcode back in the target context.
  * LLVMOpBase.td / LLVMDialect.cpp: LLVMDialect::getLLVMContextMutex()
    returns the mutex stored in the dialect's impl (impl->mutex).
  * ModuleTranslation: caches the registered LLVMDialect in the new
    `llvmDialect` member and holds a SmartScopedLock on the context mutex
    for the duration of prepareLLVMModule(), convertGlobals() and
    convertFunctions().
  * ConvertKernelFuncToCubin: clones the kernel module into a
    function-local llvm::LLVMContext and runs the PTX codegen passes on
    the clone instead of on the original module.
  * ConvertLaunchFuncToCudaCalls: generateKernelNameConstant() also takes
    the kernel module name (launchOp.getKernelModuleName()) and formats
    the global name with "{0}_{1}_kernel_name" (module, kernel) instead of
    "{0}_kernel_name" (kernel only), presumably so that same-named kernels
    from different GPU modules get distinct globals.
  * ExecutionEngine: replaces its inline bitcode round-trip with the new
    helper; LLVMBitReader/LLVMBitWriter move from the ExecutionEngine link
    list to MLIRLLVMIR's, matching the BitcodeReader/BitcodeWriter
    includes added to LLVMDialect.cpp.
  * New test mlir-cuda-runner/two-modules.mlir: two gpu.launch ops that
    each store the thread index into the same 13-element buffer, checked
    against [0, 1, ..., 12].

Review notes (to be confirmed against the full sources):
  * cloneModuleIntoNewContext() wraps parseBitcodeFile() in cantFail(),
    while the ExecutionEngine code it replaces returned
    expectedModule.takeError(); a failed reparse is no longer reported to
    the ExecutionEngine caller as an error.
  * ConvertKernelFuncToCubin serializes `module` without taking
    getLLVMContextMutex(); if that module lives in the dialect's shared
    context, confirm that bitcode writing does not touch context state
    other threads may be mutating concurrently.
  * The ModuleTranslation constructor stores getRegisteredDialect() into
    `llvmDialect` without a null check, and convertGlobals() /
    convertFunctions() dereference it; presumably safe because
    prepareLLVMModule() asserts the dialect is registered -- confirm.
  * The two-modules.mlir RUN line passes --print-ir-after-all and
    FileCheck --dump-input=always, which look like debugging leftovers.
  * Typo: the comment "/// Pointer to the llvmDialect;" ends in ';'.
  * This copy of the patch has lost its line breaks and its C++ template
    argument lists (e.g. "template" directly followed by
    "class SmartMutex;", bare "std::unique_ptr", "SmallVector buffer;",
    "SmartScopedLock scopedLock("), so it neither applies nor compiles as
    shown; restore it from the original patch before use.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -32,6 +32,10 @@ namespace llvm { class Type; class LLVMContext; +namespace sys { +template +class SmartMutex; +} // end namespace sys } // end namespace llvm namespace mlir { @@ -216,6 +220,12 @@ /// function confirms that the Operation has the desired properties. bool satisfiesLLVMModule(Operation *op); +/// Clones the given module into the provided context. This is implemented by +/// transforming the module into bitcode and then reparsing the bitcode in the +/// provided context. +std::unique_ptr +cloneModuleIntoNewContext(llvm::LLVMContext *context, llvm::Module *module); + } // end namespace LLVM } // end namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -24,6 +24,7 @@ ~LLVMDialect(); llvm::LLVMContext &getLLVMContext(); llvm::Module &getLLVMModule(); + llvm::sys::SmartMutex &getLLVMContextMutex(); private: friend LLVMType; diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -106,7 +106,6 @@ /// Original and translated module. Operation *mlirModule; std::unique_ptr llvmModule; - /// A converter for translating debug information. std::unique_ptr debugTranslation; @@ -114,6 +113,8 @@ std::unique_ptr ompBuilder; /// Precomputed pointer to OpenMP dialect. const Dialect *ompDialect; + /// Pointer to the llvmDialect; + LLVMDialect *llvmDialect; /// Mappings between llvm.mlir.global definitions and corresponding globals. 
DenseMap globalsMapping; diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" @@ -98,12 +99,19 @@ llvm::Module &module, llvm::TargetMachine &target_machine) { std::string ptx; { + // Clone the llvm module into a new context to enable concurrent compilation + // with multiple threads. + // TODO(zinenko): Reevaluate model of ownership of LLVMContext in + // LLVMDialect. + llvm::LLVMContext llvmContext; + auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module); + llvm::raw_string_ostream stream(ptx); llvm::buffer_ostream pstream(stream); llvm::legacy::PassManager codegen_passes; target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::CGFT_AssemblyFile); - codegen_passes.run(module); + codegen_passes.run(*clone); } return ptx; diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -116,8 +116,8 @@ void addParamToList(OpBuilder &builder, Location loc, Value param, Value list, unsigned pos, Value one); Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value generateKernelNameConstant(StringRef name, Location loc, - OpBuilder &builder); + Value generateKernelNameConstant(StringRef moduleName, StringRef name, + Location loc, OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); public: @@ -345,12 +345,13 @@ // %2 = 
llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> // } Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( - StringRef name, Location loc, OpBuilder &builder) { + StringRef moduleName, StringRef name, Location loc, OpBuilder &builder) { // Make sure the trailing zero is included in the constant. std::vector kernelName(name.begin(), name.end()); kernelName.push_back('\0'); - std::string globalName = std::string(llvm::formatv("{0}_kernel_name", name)); + std::string globalName = + std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); return LLVM::createGlobalString( loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), LLVM::Linkage::Internal, llvmDialect); @@ -415,7 +416,8 @@ // the kernel function. auto cuOwningModuleRef = builder.create(loc, getPointerType(), cuModule); - auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder); + auto kernelName = generateKernelNameConstant(launchOp.getKernelModuleName(), + launchOp.kernel(), loc, builder); auto cuFunction = allocatePointer(builder, loc); auto cuModuleGetFunction = getOperation().lookupSymbol(cuModuleGetFunctionName); diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -13,6 +13,8 @@ target_link_libraries(MLIRLLVMIR PUBLIC LLVMAsmParser + LLVMBitReader + LLVMBitWriter LLVMCore LLVMSupport LLVMFrontendOpenMP diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -20,6 +20,8 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" @@ -1682,6 +1684,9 @@ llvm::LLVMContext 
&LLVMDialect::getLLVMContext() { return impl->llvmContext; } llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } +llvm::sys::SmartMutex &LLVMDialect::getLLVMContextMutex() { + return impl->mutex; +} /// Parse a type registered to this dialect. Type LLVMDialect::parseType(DialectAsmParser &parser) const { @@ -1971,3 +1976,16 @@ return op->hasTrait() && op->hasTrait(); } + +std::unique_ptr +mlir::LLVM::cloneModuleIntoNewContext(llvm::LLVMContext *context, + llvm::Module *module) { + SmallVector buffer; + { + llvm::raw_svector_ostream os(buffer); + WriteBitcodeToFile(*module, os); + } + llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()), + "cloned module buffer"); + return cantFail(parseBitcodeFile(bufferRef, *context)); +} diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt --- a/mlir/lib/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/CMakeLists.txt @@ -17,8 +17,6 @@ PUBLIC MLIRLLVMIR MLIRTargetLLVMIR - LLVMBitReader - LLVMBitWriter LLVMExecutionEngine LLVMObject LLVMOrcJIT diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -11,13 +11,12 @@ // //===----------------------------------------------------------------------===// #include "mlir/ExecutionEngine/ExecutionEngine.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/LLVMIR.h" -#include "llvm/Bitcode/BitcodeReader.h" -#include "llvm/Bitcode/BitcodeWriter.h" #include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" @@ -211,17 +210,8 @@ // Clone module in a new LLVMContext since translateModuleToLLVMIR buries ownership too deeply. 
// TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect. - SmallVector buffer; - { - llvm::raw_svector_ostream os(buffer); - WriteBitcodeToFile(*llvmModule, os); - } - llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()), - "cloned module buffer"); - auto expectedModule = parseBitcodeFile(bufferRef, *ctx); - if (!expectedModule) - return expectedModule.takeError(); - std::unique_ptr deserModule = std::move(*expectedModule); + std::unique_ptr deserModule = + LLVM::cloneModuleIntoNewContext(ctx.get(), llvmModule.get()); auto dataLayout = deserModule->getDataLayout(); // Callback to create the object layer with symbol resolution to current diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -301,7 +301,8 @@ debugTranslation( std::make_unique(module, *this->llvmModule)), ompDialect( - module->getContext()->getRegisteredDialect()) { + module->getContext()->getRegisteredDialect()), + llvmDialect(module->getContext()->getRegisteredDialect()) { assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); } @@ -495,6 +496,9 @@ /// Create named global variables that correspond to llvm.mlir.global /// definitions. LogicalResult ModuleTranslation::convertGlobals() { + // Lock access to the llvm context. + llvm::sys::SmartScopedLock scopedLock( + llvmDialect->getLLVMContextMutex()); for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = op.getType().getUnderlyingType(); llvm::Constant *cst = llvm::UndefValue::get(type); @@ -754,6 +758,9 @@ } LogicalResult ModuleTranslation::convertFunctions() { + // Lock access to the llvm context. + llvm::sys::SmartScopedLock scopedLock( + llvmDialect->getLLVMContextMutex()); // Declare all functions first because there may be function calls that form a // call graph with cycles. 
for (auto function : getModuleBody(mlirModule).getOps()) { @@ -798,6 +805,8 @@ ModuleTranslation::prepareLLVMModule(Operation *m) { auto *dialect = m->getContext()->getRegisteredDialect(); assert(dialect && "LLVM dialect must be registered"); + // Lock the LLVM context as we might create new types here. + llvm::sys::SmartScopedLock scopedLock(dialect->getLLVMContextMutex()); auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); if (!llvmModule) diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cuda-runner/two-modules.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-cuda-runner %s --print-ir-after-all --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --dump-input=always + +// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] +func @main() { + %arg = alloc() : memref<13xi32> + %dst = memref_cast %arg : memref<13xi32> to memref + %one = constant 1 : index + %sx = dim %dst, 0 : memref + call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref) -> () + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) + threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) { + %t0 = index_cast %tx : index to i32 + store %t0, %dst[%tx] : memref + gpu.terminator + } + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) + threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) { + %t0 = index_cast %tx : index to i32 + store %t0, %dst[%tx] : memref + gpu.terminator + } + %U = memref_cast %dst : memref to memref<*xi32> + call @print_memref_i32(%U) : (memref<*xi32>) -> () + return +} + +func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) +func @print_memref_i32(%ptr : memref<*xi32>)