[mlir] Translate to LLVM IR in a caller-provided llvm::LLVMContext

translateModuleToLLVMIR, translateModuleToNVVMIR, translateModuleToROCDLIR
and ModuleTranslation::translateModule now take an `llvm::LLVMContext &`
and an optional module name (default "LLVMDialectModule").
ModuleTranslation::prepareLLVMModule creates a fresh llvm::Module in that
context and copies the data layout from the LLVM dialect, instead of cloning
the dialect's LLVM module. As a consequence, the LLVM-context mutex locking in
ModuleTranslation and ConvertKernelFuncToBlob is removed, as is
LLVM::cloneModuleIntoNewContext. The GPU LoweringCallback now receives
(Operation *, llvm::LLVMContext &, StringRef), and each run of the
kernel-to-blob pass lowers into its own local LLVMContext. ExecutionEngine
translates directly into the context it hands to ThreadSafeModule. The Toy
Ch-6 tutorial is updated accordingly and now names the correct source file
(examples/toy/Ch6/toyc.cpp).

Review notes (not addressed by this patch):
* The header comments kept as context in Target/LLVMIR.h, Target/NVVMIR.h and
  Target/ROCDLIR.h still say the LLVM context comes from the registered LLVM
  IR dialect; with this patch it is the `llvmContext` argument.
* Target/ROCDLIR.h both includes "llvm/IR/LLVMContext.h" and forward-declares
  llvm::LLVMContext, while Target/NVVMIR.h only forward-declares it.
* prepareLLVMModule no longer clones dialect->getLLVMModule(); only the data
  layout is copied. NOTE(review): confirm nothing else carried on that module
  (e.g. the target triple) was relied upon by other callers.
* In the visible hunk, translateModuleToROCDLIR does not null-check the result
  of translateModule before the kernel post-processing, unlike
  translateModuleToNVVMIR. TODO confirm against the full file.
* ExecutionEngine still creates its context with `new llvm::LLVMContext`;
  std::make_unique<llvm::LLVMContext>() would be the idiomatic form.

diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -238,14 +238,17 @@ } ``` -The full code listing for dumping LLVM IR can be found in -`examples/toy/Ch6/toy.cpp` in the +The full code listing for dumping LLVM IR can be found in +`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function: ```c++ int dumpLLVMIR(mlir::ModuleOp module) { - // Translate the module, that contains the LLVM dialect, to LLVM IR. - auto llvmModule = mlir::translateModuleToLLVMIR(module); + // Translate the module, which contains the LLVM dialect, to LLVM IR. Use a + // fresh LLVM IR context. (Note that an LLVMContext is not thread-safe, so any + // concurrent use of a context requires external locking.) + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); if (!llvmModule) { llvm::errs() << "Failed to emit LLVM IR\n"; return -1; diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -189,7 +189,9 @@ } int dumpLLVMIR(mlir::ModuleOp module) { - auto llvmModule = mlir::translateModuleToLLVMIR(module); + // Convert the module to LLVM IR in a new LLVM IR context. + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); if (!llvmModule) { llvm::errs() << "Failed to emit LLVM IR\n"; return -1; diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -190,7 +190,9 @@ } int dumpLLVMIR(mlir::ModuleOp module) { - auto llvmModule = mlir::translateModuleToLLVMIR(module); + // Convert the module to LLVM IR in a new LLVM IR context. 
+ llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); if (!llvmModule) { llvm::errs() << "Failed to emit LLVM IR\n"; return -1; diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h --- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h @@ -35,8 +35,8 @@ using OwnedBlob = std::unique_ptr>; using BlobGenerator = std::function; -using LoweringCallback = - std::function(Operation *)>; +using LoweringCallback = std::function( + Operation *, llvm::LLVMContext &, StringRef)>; /// Creates a pass to convert a gpu.launch_func operation into a sequence of /// GPU runtime calls. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -65,12 +65,6 @@ /// function confirms that the Operation has the desired properties. bool satisfiesLLVMModule(Operation *op); -/// Clones the given module into the provided context. This is implemented by -/// transforming the module into bitcode and then reparsing the bitcode in the -/// provided context. -std::unique_ptr -cloneModuleIntoNewContext(llvm::LLVMContext *context, llvm::Module *module); - } // end namespace LLVM } // end namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -13,6 +13,8 @@ #ifndef MLIR_TARGET_LLVMIR_H #define MLIR_TARGET_LLVMIR_H +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" #include // Forward-declare LLVM classes. @@ -31,7 +33,9 @@ /// from the registered LLVM IR dialect. In case of error, report it /// to the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. 
-std::unique_ptr translateModuleToLLVMIR(ModuleOp m); +std::unique_ptr +translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext, + StringRef name = "LLVMDialectModule"); /// Convert the given LLVM module into MLIR's LLVM dialect. The LLVM context is /// extracted from the registered LLVM IR dialect. In case of error, report it diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -50,14 +50,15 @@ class ModuleTranslation { public: template - static std::unique_ptr translateModule(Operation *m) { + static std::unique_ptr + translateModule(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name = "LLVMDialectModule") { if (!satisfiesLLVMModule(m)) return nullptr; if (failed(checkSupportedModuleOps(m))) return nullptr; - auto llvmModule = prepareLLVMModule(m); - if (!llvmModule) - return nullptr; + std::unique_ptr llvmModule = + prepareLLVMModule(m, llvmContext, name); LLVM::ensureDistinctSuccessors(m); @@ -94,7 +95,9 @@ /// Converts the type from MLIR LLVM dialect to LLVM. llvm::Type *convertType(LLVMType type); - static std::unique_ptr prepareLLVMModule(Operation *m); + static std::unique_ptr + prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name); /// A helper to look up remapped operands in the value remapping table. SmallVector lookupValues(ValueRange values); @@ -122,8 +125,6 @@ std::unique_ptr ompBuilder; /// Precomputed pointer to OpenMP dialect. const Dialect *ompDialect; - /// Pointer to the llvmDialect; - LLVMDialect *llvmDialect; /// Mappings between llvm.mlir.global definitions and corresponding globals. 
DenseMap globalsMapping; diff --git a/mlir/include/mlir/Target/NVVMIR.h b/mlir/include/mlir/Target/NVVMIR.h --- a/mlir/include/mlir/Target/NVVMIR.h +++ b/mlir/include/mlir/Target/NVVMIR.h @@ -13,10 +13,12 @@ #ifndef MLIR_TARGET_NVVMIR_H #define MLIR_TARGET_NVVMIR_H +#include "llvm/ADT/StringRef.h" #include // Forward-declare LLVM classes. namespace llvm { +class LLVMContext; class Module; } // namespace llvm @@ -28,7 +30,9 @@ /// context from the registered LLVM IR dialect. In case of error, report it to /// the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr translateModuleToNVVMIR(Operation *m); +std::unique_ptr +translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext, + llvm::StringRef name = "LLVMDialectModule"); } // namespace mlir diff --git a/mlir/include/mlir/Target/ROCDLIR.h b/mlir/include/mlir/Target/ROCDLIR.h --- a/mlir/include/mlir/Target/ROCDLIR.h +++ b/mlir/include/mlir/Target/ROCDLIR.h @@ -14,10 +14,13 @@ #ifndef MLIR_TARGET_ROCDLIR_H #define MLIR_TARGET_ROCDLIR_H +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/LLVMContext.h" #include // Forward-declare LLVM classes. namespace llvm { +class LLVMContext; class Module; } // namespace llvm @@ -29,7 +32,9 @@ /// context from the registered LLVM IR dialect. In case of error, report it to /// the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. 
-std::unique_ptr translateModuleToROCDLIR(Operation *m); +std::unique_ptr +translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext, + llvm::StringRef name = "LLVMDialectModule"); } // namespace mlir diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp @@ -59,14 +59,11 @@ void runOnOperation() override { gpu::GPUModuleOp module = getOperation(); - // Lock access to the llvm context. - llvm::sys::SmartScopedLock scopedLock( - module.getContext() - ->getRegisteredDialect() - ->getLLVMContextMutex()); - - // Lower the module to a llvm module. - std::unique_ptr llvmModule = loweringCallback(module); + // Lower the module to an LLVM IR module using a separate context to enable + // multi-threaded processing. + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + loweringCallback(module, llvmContext, "LLVMDialectModule"); if (!llvmModule) return signalPassFailure(); @@ -109,17 +106,12 @@ llvm::TargetMachine &targetMachine) { std::string targetISA; { - // Clone the llvm module into a new context to enable concurrent compilation - // with multiple threads. 
- llvm::LLVMContext llvmContext; - auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module); - llvm::raw_string_ostream stream(targetISA); llvm::buffer_ostream pstream(stream); llvm::legacy::PassManager codegenPasses; targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr, llvm::CGFT_AssemblyFile); - codegenPasses.run(*clone); + codegenPasses.run(module); } return targetISA; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1794,16 +1794,3 @@ return op->hasTrait() && op->hasTrait(); } - -std::unique_ptr -mlir::LLVM::cloneModuleIntoNewContext(llvm::LLVMContext *context, - llvm::Module *module) { - SmallVector buffer; - { - llvm::raw_svector_ostream os(buffer); - WriteBitcodeToFile(*module, os); - } - llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()), - "cloned module buffer"); - return cantFail(parseBitcodeFile(bufferRef, *context)); -} diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -223,7 +223,7 @@ enablePerfNotificationListener); std::unique_ptr ctx(new llvm::LLVMContext); - auto llvmModule = translateModuleToLLVMIR(m); + auto llvmModule = translateModuleToLLVMIR(m, *ctx); if (!llvmModule) return make_string_error("could not convert to LLVM IR"); // FIXME: the triple should be passed to the translation or dialect conversion @@ -232,12 +232,7 @@ setupTargetTriple(llvmModule.get()); packFunctionArguments(llvmModule.get()); - // Clone module in a new LLVMContext since translateModuleToLLVMIR buries - // ownership too deeply. - // TODO: Reevaluate model of ownership of LLVMContext in LLVMDialect. 
- std::unique_ptr deserModule = - LLVM::cloneModuleIntoNewContext(ctx.get(), llvmModule.get()); - auto dataLayout = deserModule->getDataLayout(); + auto dataLayout = llvmModule->getDataLayout(); // Callback to create the object layer with symbol resolution to current // process and dynamically linked libraries. @@ -295,7 +290,7 @@ .create()); // Add a ThreadSafemodule to the engine and return. - ThreadSafeModule tsm(std::move(deserModule), std::move(ctx)); + ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx)); if (transformer) cantFail(tsm.withModuleDo( [&](llvm::Module &module) { return transformer(&module); })); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -21,15 +21,19 @@ using namespace mlir; -std::unique_ptr mlir::translateModuleToLLVMIR(ModuleOp m) { - return LLVM::ModuleTranslation::translateModule<>(m); +std::unique_ptr +mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule<>(m, llvmContext, name); } namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( "mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { - auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module); + llvm::LLVMContext llvmContext; + auto llvmModule = LLVM::ModuleTranslation::translateModule<>( + module, llvmContext, "LLVMDialectModule"); if (!llvmModule) return failure(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -65,9 +65,11 @@ }; } // namespace -std::unique_ptr mlir::translateModuleToNVVMIR(Operation *m) { - auto llvmModule = - LLVM::ModuleTranslation::translateModule(m); +std::unique_ptr 
+mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + auto llvmModule = LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); if (!llvmModule) return llvmModule; @@ -98,7 +100,8 @@ void registerToNVVMIRTranslation() { TranslateFromMLIRRegistration registration( "mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) { - auto llvmModule = mlir::translateModuleToNVVMIR(module); + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext); if (!llvmModule) return failure(); diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -75,10 +75,12 @@ }; } // namespace -std::unique_ptr mlir::translateModuleToROCDLIR(Operation *m) { +std::unique_ptr +mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) - auto llvmModule = - LLVM::ModuleTranslation::translateModule(m); + auto llvmModule = LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); // foreach GPU kernel // 1. Insert AMDGPU_KERNEL calling convention. 
@@ -102,7 +104,8 @@ void registerToROCDLIRTranslation() { TranslateFromMLIRRegistration registration( "mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) { - auto llvmModule = mlir::translateModuleToROCDLIR(module); + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext); if (!llvmModule) return failure(); diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -34,9 +34,11 @@ } }; -std::unique_ptr translateLLVMAVX512ModuleToLLVMIR(Operation *m) { +std::unique_ptr +translateLLVMAVX512ModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { return LLVM::ModuleTranslation::translateModule( - m); + m, llvmContext, name); } } // end namespace @@ -44,7 +46,9 @@ void registerAVX512ToLLVMIRTranslation() { TranslateFromMLIRRegistration reg( "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { - auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(module); + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMAVX512ModuleToLLVMIR( + module, llvmContext, "LLVMDialectModule"); if (!llvmModule) return failure(); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -304,7 +304,6 @@ std::make_unique(module, *this->llvmModule)), ompDialect( module->getContext()->getRegisteredDialect()), - llvmDialect(module->getContext()->getRegisteredDialect()), typeTranslator(this->llvmModule->getContext()) { assert(satisfiesLLVMModule(mlirModule) && "mlirModule should honor LLVM's module semantics."); @@ -688,9 +687,6 @@ /// Create named global variables that correspond to llvm.mlir.global /// definitions. LogicalResult ModuleTranslation::convertGlobals() { - // Lock access to the llvm context. 
- llvm::sys::SmartScopedLock scopedLock( - llvmDialect->getLLVMContextMutex()); for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = convertType(op.getType()); llvm::Constant *cst = llvm::UndefValue::get(type); @@ -892,10 +888,6 @@ } LogicalResult ModuleTranslation::convertFunctionSignatures() { - // Lock access to the llvm context. - llvm::sys::SmartScopedLock scopedLock( - llvmDialect->getLLVMContextMutex()); - // Declare all functions first because there may be function calls that form a // call graph with cycles, or global initializers that reference functions. for (auto function : getModuleBody(mlirModule).getOps()) { @@ -916,10 +908,6 @@ } LogicalResult ModuleTranslation::convertFunctions() { - // Lock access to the llvm context. - llvm::sys::SmartScopedLock scopedLock( - llvmDialect->getLLVMContextMutex()); - // Convert functions. for (auto function : getModuleBody(mlirModule).getOps()) { // Ignore external functions. @@ -934,8 +922,6 @@ } llvm::Type *ModuleTranslation::convertType(LLVMType type) { - // Lock the LLVM context as we create types in it. - llvm::sys::SmartScopedLock lock(llvmDialect->getLLVMContextMutex()); return typeTranslator.translateType(type); } @@ -951,22 +937,17 @@ return remapped; } -std::unique_ptr -ModuleTranslation::prepareLLVMModule(Operation *m) { +std::unique_ptr ModuleTranslation::prepareLLVMModule( + Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { auto *dialect = m->getContext()->getRegisteredDialect(); assert(dialect && "LLVM dialect must be registered"); - // Lock the LLVM context as we might create new types here. 
- llvm::sys::SmartScopedLock scopedLock(dialect->getLLVMContextMutex()); - - auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); - if (!llvmModule) - return nullptr; - llvm::LLVMContext &llvmContext = llvmModule->getContext(); - llvm::IRBuilder<> builder(llvmContext); + auto llvmModule = std::make_unique(name, llvmContext); + llvmModule->setDataLayout(dialect->getDataLayout()); // Inject declarations for `malloc` and `free` functions that can be used in // memref allocation/deallocation coming from standard ops lowering. + llvm::IRBuilder<> builder(llvmContext); llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(), builder.getInt64Ty()); llvmModule->getOrInsertFunction("free", builder.getVoidTy(), diff --git a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp --- a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp +++ b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp @@ -196,8 +196,10 @@ return success(); } -static std::unique_ptr compileModuleToROCDLIR(Operation *m) { - auto llvmModule = translateModuleToROCDLIR(m); +static std::unique_ptr +compileModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + auto llvmModule = translateModuleToROCDLIR(m, llvmContext, name); // TODO: Link with ROCm-Device-Libs in case needed (ex: the Module // depends on math functions). return llvmModule;