diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h
--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h
+++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h
@@ -25,28 +25,28 @@
 class DialectRegistry;
 struct LogicalResult;
 class MLIRContext;
-class ModuleOp;
+class Operation;
 template <typename OpTy>
 class OwningOpRef;
 
 /// Interface of the function that translates the sources managed by `sourceMgr`
 /// to MLIR. The source manager has at least one buffer. The implementation
-/// should create a new MLIR ModuleOp in the given context and return a pointer
-/// to it, or a nullptr in case of any error.
-using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<ModuleOp>(
+/// should create a new MLIR Operation in the given context and return a
+/// pointer to it, or a nullptr in case of any error.
+using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<Operation *>(
     llvm::SourceMgr &sourceMgr, MLIRContext *)>;
 
 /// Interface of the function that translates the given string to MLIR. The
-/// implementation should create a new MLIR ModuleOp in the given context. If
+/// implementation should create a new MLIR Operation in the given context. If
 /// source-related error reporting is required from within the function, use
 /// TranslateSourceMgrToMLIRFunction instead.
 using TranslateStringRefToMLIRFunction =
-    std::function<OwningOpRef<ModuleOp>(llvm::StringRef, MLIRContext *)>;
+    std::function<OwningOpRef<Operation *>(llvm::StringRef, MLIRContext *)>;
 
 /// Interface of the function that translates MLIR to a different format and
-/// outputs the result to a stream. It is allowed to modify the module.
+/// outputs the result to a stream. It is allowed to modify the operation.
 using TranslateFromMLIRFunction =
-    std::function<LogicalResult(ModuleOp, llvm::raw_ostream &output)>;
+    std::function<LogicalResult(Operation *, llvm::raw_ostream &output)>;
 
 /// Interface of the function that performs file-to-file translation involving
 /// MLIR. The input file is held in the given MemoryBuffer; the output file
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -34,9 +34,9 @@
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp",
-      [](ModuleOp module, raw_ostream &output) {
+      [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
-            module, output,
+            op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop);
       },
       [](DialectRegistry &registry) {
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1389,8 +1389,8 @@
 
 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
 // LLVM dialect.
-OwningOpRef<ModuleOp> translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
-                                              MLIRContext *context) {
+static OwningOpRef<Operation *>
+translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, MLIRContext *context) {
   llvm::SMDiagnostic err;
   llvm::LLVMContext llvmContext;
   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
@@ -1402,7 +1402,9 @@
     emitError(UnknownLoc::get(context)) << errStream.str();
     return {};
   }
-  return translateLLVMIRToModule(std::move(llvmModule), context);
+  OwningOpRef<ModuleOp> module =
+      translateLLVMIRToModule(std::move(llvmModule), context);
+  return module.release().getOperation();
 }
 
 namespace mlir {
diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -25,9 +25,9 @@
 void registerToLLVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
       "mlir-to-llvmir",
-      [](ModuleOp module, raw_ostream &output) {
+      [](Operation *op, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
-        auto llvmModule = translateModuleToLLVMIR(module, llvmContext);
+        auto llvmModule = translateModuleToLLVMIR(op, llvmContext);
         if (!llvmModule)
           return failure();
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1183,7 +1183,8 @@
 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
                               StringRef name) {
   if (!satisfiesLLVMModule(module))
-    return nullptr;
+    return module->emitOpError("can not be translated to an LLVMIR module"),
+           nullptr;
 
   std::unique_ptr<llvm::Module> llvmModule =
       prepareLLVMModule(module, llvmContext, name);
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -36,8 +36,8 @@
 
 // Deserializes the SPIR-V binary module stored in the file named as
 // `inputFilename` and returns a module containing the SPIR-V module.
-static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
-                                               MLIRContext *context) {
+static OwningOpRef<Operation *>
+deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
   context->loadDialect<spirv::SPIRVDialect>();
 
   // Make sure the input stream can be treated as a stream of SPIR-V words
@@ -57,11 +57,11 @@
   if (!spirvModule)
     return {};
 
-  OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
-      context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
-  module->getBody()->push_front(spirvModule.release());
+  auto module = ModuleOp::create(FileLineColLoc::get(
+      context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0));
+  module.getBody()->push_front(spirvModule.release());
 
-  return module;
+  return module.getOperation();
 }
 
 namespace mlir {
@@ -80,20 +80,20 @@
 // Serialization registration
 //===----------------------------------------------------------------------===//
 
-static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
-  if (!module)
+static LogicalResult serializeModule(Operation *op, raw_ostream &output) {
+  if (!op)
     return failure();
 
   SmallVector<uint32_t, 0> binary;
 
   SmallVector<spirv::ModuleOp, 1> spirvModules;
-  module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
+  op->walk([&](spirv::ModuleOp module) { spirvModules.push_back(module); });
 
   if (spirvModules.empty())
-    return module.emitError("found no 'spirv.module' op");
+    return op->emitError("found no 'spirv.module' op");
 
   if (spirvModules.size() != 1)
-    return module.emitError("found more than one 'spirv.module' op");
+    return op->emitError("found more than one 'spirv.module' op");
 
   if (failed(spirv::serialize(spirvModules[0], binary)))
     return failure();
@@ -108,8 +108,8 @@
 void registerToSPIRVTranslation() {
   TranslateFromMLIRRegistration toBinary(
       "serialize-spirv",
-      [](ModuleOp module, raw_ostream &output) {
-        return serializeModule(module, output);
+      [](Operation *op, raw_ostream &output) {
+        return serializeModule(op, output);
       },
       [](DialectRegistry &registry) {
         registry.insert<spirv::SPIRVDialect>();
@@ -121,17 +121,18 @@
 // Round-trip registration
 //===----------------------------------------------------------------------===//
 
-static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
+static LogicalResult roundTripModule(Operation *src, bool emitDebugInfo,
                                      raw_ostream &output) {
   SmallVector<uint32_t, 0> binary;
-  MLIRContext *context = srcModule.getContext();
-  auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
+  MLIRContext *context = src->getContext();
+  SmallVector<spirv::ModuleOp, 1> spirvModules;
+  src->walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
 
-  if (spirvModules.begin() == spirvModules.end())
-    return srcModule.emitError("found no 'spirv.module' op");
+  if (spirvModules.empty())
+    return src->emitError("found no 'spirv.module' op");
 
-  if (std::next(spirvModules.begin()) != spirvModules.end())
-    return srcModule.emitError("found more than one 'spirv.module' op");
+  if (spirvModules.size() > 1)
+    return src->emitError("found more than one 'spirv.module' op");
 
   spirv::SerializationOptions options;
   options.emitDebugInfo = emitDebugInfo;
@@ -163,8 +164,8 @@
 void registerTestRoundtripSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
       "test-spirv-roundtrip",
-      [](ModuleOp module, raw_ostream &output) {
-        return roundTripModule(module, /*emitDebugInfo=*/false, output);
+      [](Operation *op, raw_ostream &output) {
+        return roundTripModule(op, /*emitDebugInfo=*/false, output);
       },
       [](DialectRegistry &registry) {
         registry.insert<spirv::SPIRVDialect>();
@@ -174,8 +175,8 @@
 void registerTestRoundtripDebugSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
       "test-spirv-roundtrip-debug",
-      [](ModuleOp module, raw_ostream &output) {
-        return roundTripModule(module, /*emitDebugInfo=*/true, output);
+      [](Operation *op, raw_ostream &output) {
+        return roundTripModule(op, /*emitDebugInfo=*/true, output);
       },
       [](DialectRegistry &registry) {
         registry.insert<spirv::SPIRVDialect>();
diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp
--- a/mlir/lib/Tools/mlir-translate/Translation.cpp
+++ b/mlir/lib/Tools/mlir-translate/Translation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/Parser.h"
+#include "mlir/Tools/ParseUtilties.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
@@ -58,10 +59,10 @@
     StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
                               MLIRContext *context) {
-    OwningOpRef<ModuleOp> module = function(sourceMgr, context);
-    if (!module || failed(verify(*module)))
+    OwningOpRef<Operation *> op = function(sourceMgr, context);
+    if (!op || failed(verify(*op)))
       return failure();
-    module->print(output);
+    op.get()->print(output);
     return success();
   };
   registerTranslation(name, wrappedFn);
@@ -91,16 +92,23 @@
 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, const TranslateFromMLIRFunction &function,
     const std::function<void(DialectRegistry &)> &dialectRegistration) {
+  // Function-local static: the flag is registered once for all translations.
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
+      llvm::cl::init(false)};
+
   registerTranslation(name, [function, dialectRegistration](
                                 llvm::SourceMgr &sourceMgr, raw_ostream &output,
                                 MLIRContext *context) {
     DialectRegistry registry;
     dialectRegistration(registry);
     context->appendDialectRegistry(registry);
-    auto module = parseSourceFile<ModuleOp>(sourceMgr, context);
-    if (!module || failed(verify(*module)))
+    OwningOpRef<Operation *> op =
+        parseSourceFileForTool(sourceMgr, context, !noImplicitModule);
+    if (!op || failed(verify(*op)))
       return failure();
-    return function(module.get(), output);
+    return function(op.get(), output);
   });
 }
 
diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/invalid-module.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s
+
+// expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}}
+llvm.func @foo() {
+  llvm.return
+}
diff --git a/mlir/test/Target/SPIRV/module.mlir b/mlir/test/Target/SPIRV/module.mlir
--- a/mlir/test/Target/SPIRV/module.mlir
+++ b/mlir/test/Target/SPIRV/module.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip --no-implicit-module -split-input-file %s | FileCheck %s
 
 // CHECK:      spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 // CHECK-NEXT:   spirv.func @foo() "Inline" {