diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -76,8 +76,10 @@ }; struct TranslateFromMLIRRegistration { - TranslateFromMLIRRegistration(llvm::StringRef name, - const TranslateFromMLIRFunction &function); + TranslateFromMLIRRegistration( + llvm::StringRef name, const TranslateFromMLIRFunction &function, + std::function dialectRegistration = + [](DialectRegistry &) {}); }; struct TranslateRegistration { TranslateRegistration(llvm::StringRef name, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -11,10 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" @@ -105,8 +107,12 @@ namespace mlir { void registerToSPIRVTranslation() { TranslateFromMLIRRegistration toBinary( - "serialize-spirv", [](ModuleOp module, raw_ostream &output) { + "serialize-spirv", + [](ModuleOp module, raw_ostream &output) { return serializeModule(module, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); }); } } // namespace mlir @@ -118,6 +124,7 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, bool emitDebugInfo, raw_ostream &output, MLIRContext *context) { + context->getDialectRegistry().insert(); // Parse an MLIR module from the source manager. auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!srcModule) diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -30,7 +30,8 @@ namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = LLVM::ModuleTranslation::translateModule<>( module, llvmContext, "LLVMDialectModule"); @@ -39,6 +40,7 @@ llvmModule->print(output, nullptr); return success(); - }); + }, + [](DialectRegistry ®istry) { registry.insert(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -99,7 +99,8 @@ namespace mlir { void registerToNVVMIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-nvvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext); if (!llvmModule) @@ -107,6 +108,10 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registry.insert(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -103,7 +103,8 @@ namespace mlir { void registerToROCDLIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-rocdlir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext); if (!llvmModule) @@ -111,6 +112,10 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registry.insert(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -45,7 +45,8 @@ namespace mlir { void registerAVX512ToLLVMIRTranslation() { TranslateFromMLIRRegistration reg( - "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + "avx512-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = translateLLVMAVX512ModuleToLLVMIR( module, llvmContext, "LLVMDialectModule"); @@ -54,6 +55,10 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registry.insert(); }); } } // namespace mlir diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -99,10 +99,12 @@ //===----------------------------------------------------------------------===// TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( - StringRef name, const TranslateFromMLIRFunction &function) { - registerTranslation(name, [function](llvm::SourceMgr &sourceMgr, - raw_ostream &output, - MLIRContext *context) { + StringRef name, const TranslateFromMLIRFunction &function, + std::function dialectRegistration) { + registerTranslation(name, [function, dialectRegistration]( + llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + dialectRegistration(context->getDialectRegistry()); auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!module) return failure();