diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -76,8 +76,10 @@ }; struct TranslateFromMLIRRegistration { - TranslateFromMLIRRegistration(llvm::StringRef name, - const TranslateFromMLIRFunction &function); + TranslateFromMLIRRegistration( + llvm::StringRef name, const TranslateFromMLIRFunction &function, + std::function<void(DialectRegistry &)> dialectRegistration = + [](DialectRegistry &) {}); }; struct TranslateRegistration { TranslateRegistration(llvm::StringRef name, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -11,10 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" @@ -105,8 +107,12 @@ namespace mlir { void registerToSPIRVTranslation() { TranslateFromMLIRRegistration toBinary( - "serialize-spirv", [](ModuleOp module, raw_ostream &output) { + "serialize-spirv", + [](ModuleOp module, raw_ostream &output) { return serializeModule(module, output); + }, + [](DialectRegistry &registry) { + registry.insert<spirv::SPIRVDialect>(); }); } } // namespace mlir @@ -147,15 +153,23 @@ namespace mlir { void registerTestRoundtripSPIRV() { TranslateFromMLIRRegistration roundtrip( - "test-spirv-roundtrip", [](ModuleOp module, raw_ostream &output) { + "test-spirv-roundtrip", + [](ModuleOp module, raw_ostream &output) { return roundTripModule(module, /*emitDebugInfo=*/false, output); + }, + [](DialectRegistry &registry) { +
registry.insert<spirv::SPIRVDialect>(); }); } void registerTestRoundtripDebugSPIRV() { TranslateFromMLIRRegistration roundtrip( - "test-spirv-roundtrip-debug", [](ModuleOp module, raw_ostream &output) { + "test-spirv-roundtrip-debug", + [](ModuleOp module, raw_ostream &output) { return roundTripModule(module, /*emitDebugInfo=*/true, output); + }, + [](DialectRegistry &registry) { + registry.insert<spirv::SPIRVDialect>(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -30,7 +30,8 @@ namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = LLVM::ModuleTranslation::translateModule<>( module, llvmContext, "LLVMDialectModule"); @@ -39,6 +40,7 @@ llvmModule->print(output, nullptr); return success(); - }); + }, + [](DialectRegistry &registry) { registry.insert<LLVM::LLVMDialect>(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -99,7 +99,8 @@ namespace mlir { void registerToNVVMIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-nvvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext); if (!llvmModule) @@ -107,6 +108,9 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry &registry) { + registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp ---
a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -103,7 +103,8 @@ namespace mlir { void registerToROCDLIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) { + "mlir-to-rocdlir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext); if (!llvmModule) @@ -111,6 +112,9 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry &registry) { + registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>(); }); } } // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp --- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -45,7 +45,8 @@ namespace mlir { void registerAVX512ToLLVMIRTranslation() { TranslateFromMLIRRegistration reg( - "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + "avx512-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = translateLLVMAVX512ModuleToLLVMIR( module, llvmContext, "LLVMDialectModule"); @@ -54,6 +55,9 @@ llvmModule->print(output, nullptr); return success(); + }, + [](DialectRegistry &registry) { + registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>(); }); } } // namespace mlir diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -92,10 +92,12 @@ //===----------------------------------------------------------------------===// TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( - StringRef name, const TranslateFromMLIRFunction &function) { - registerTranslation(name, [function](llvm::SourceMgr &sourceMgr, - raw_ostream &output, - MLIRContext *context) { + StringRef name, const TranslateFromMLIRFunction &function, + std::function<void(DialectRegistry &)> dialectRegistration) { + registerTranslation(name,
[function, dialectRegistration]( + llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + dialectRegistration(context->getDialectRegistry()); auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!module) return failure(); @@ -173,7 +175,7 @@ // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, raw_ostream &os) { - MLIRContext context; + MLIRContext context(false); context.printOpOnDiagnostic(!verifyDiagnostics); llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -32,7 +32,5 @@ int main(int argc, char **argv) { registerAllTranslations(); registerTestTranslations(); - // TODO: remove the global dialect registry - registerAllDialects(); return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool")); }