diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -25,26 +25,28 @@ struct LogicalResult; class MLIRContext; class ModuleOp; -class OwningModuleRef; +class Operation; +template +class OwningOpRef; /// Interface of the function that translates the sources managed by `sourceMgr` /// to MLIR. The source manager has at least one buffer. The implementation -/// should create a new MLIR ModuleOp in the given context and return a pointer -/// to it, or a nullptr in case of any error. -using TranslateSourceMgrToMLIRFunction = - std::function; +/// should create a new MLIR root operation in the given context and return a +/// pointer to it, or a nullptr in case of any error. +using TranslateSourceMgrToMLIRFunction = std::function( + llvm::SourceMgr &sourceMgr, MLIRContext *)>; /// Interface of the function that translates the given string to MLIR. The -/// implementation should create a new MLIR ModuleOp in the given context. If -/// source-related error reporting is required from within the function, use +/// implementation should create a new MLIR root operation in the given context. +/// If source-related error reporting is required from within the function, use /// TranslateSourceMgrToMLIRFunction instead. using TranslateStringRefToMLIRFunction = - std::function; + std::function(llvm::StringRef, MLIRContext *)>; /// Interface of the function that translates MLIR to a different format and /// outputs the result to a stream. It is allowed to modify the module. using TranslateFromMLIRFunction = - std::function; + std::function; /// Interface of the function that performs file-to-file translation involving /// MLIR. The input file is held in the given MemoryBuffer; the output file @@ -67,6 +69,12 @@ /// } /// } // namespace mlir /// +/// Note that the TranslateFROMMLIRRegistration usage has additional optional +/// arguments that are expected to be useful in special situations: +/// - dialectRegistration: A callback function to register dialects. +/// - parseFunction: A callback to perform parsing (to be used if the top +/// level module is not the builtin ModuleOp or special parser setup is +/// required). /// \{ struct TranslateToMLIRRegistration { TranslateToMLIRRegistration(llvm::StringRef name, @@ -76,10 +84,17 @@ }; struct TranslateFromMLIRRegistration { + using ParseFunction = OwningOpRef(llvm::SourceMgr &sourceMgr, + MLIRContext *context); + static OwningOpRef parseBuiltinModule(llvm::SourceMgr &sourceMgr, + MLIRContext *context); + TranslateFromMLIRRegistration( llvm::StringRef name, const TranslateFromMLIRFunction &function, std::function dialectRegistration = - [](DialectRegistry &) {}); + [](DialectRegistry &) {}, + ParseFunction parseFunction = + TranslateFromMLIRRegistration::parseBuiltinModule); }; struct TranslateRegistration { TranslateRegistration(llvm::StringRef name, diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -870,8 +870,8 @@ // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the // LLVM dialect. -OwningModuleRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, - MLIRContext *context) { +OwningOpRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, + MLIRContext *context) { llvm::SMDiagnostic err; llvm::LLVMContext llvmContext; std::unique_ptr llvmModule = llvm::parseIR( @@ -883,7 +883,8 @@ emitError(UnknownLoc::get(context)) << errStream.str(); return {}; } - return translateLLVMIRToModule(std::move(llvmModule), context); + return OwningOpRef( + translateLLVMIRToModule(std::move(llvmModule), context).release()); } namespace mlir { diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -23,7 +23,11 @@ void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( "mlir-to-llvmir", - [](ModuleOp module, raw_ostream &output) { + [](Operation *root, raw_ostream &output) -> LogicalResult { + auto module = dyn_cast(root); + if (!module) { + return emitError(root->getLoc()) << "expected module op"; + } llvm::LLVMContext llvmContext; auto llvmModule = translateModuleToLLVMIR(module, llvmContext); if (!llvmModule) diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -35,8 +35,8 @@ // Deserializes the SPIR-V binary module stored in the file named as // `inputFilename` and returns a module containing the SPIR-V module. -static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input, - MLIRContext *context) { +static OwningOpRef +deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { context->loadDialect(); // Make sure the input stream can be treated as a stream of SPIR-V words @@ -53,14 +53,7 @@ OwningOpRef spirvModule = spirv::deserialize(binary, context); - if (!spirvModule) - return {}; - - OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( - context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0))); - module->getBody()->push_front(spirvModule.release()); - - return module; + return OwningOpRef(spirvModule.release()); } namespace mlir { @@ -79,23 +72,15 @@ // Serialization registration //===----------------------------------------------------------------------===// -static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { - if (!module) +static LogicalResult serializeModule(Operation *root, raw_ostream &output) { + if (!root) return failure(); + auto spirvModule = dyn_cast(root); + if (!spirvModule) + return root->emitError("expected a spirv.module"); SmallVector binary; - - SmallVector spirvModules; - module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); - - if (spirvModules.empty()) - return module.emitError("found no 'spv.module' op"); - - if (spirvModules.size() != 1) - return module.emitError("found more than one 'spv.module' op"); - - if (failed( - spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false))) + if (failed(spirv::serialize(spirvModule, binary, /*emitDebuginfo=*/false))) return failure(); output.write(reinterpret_cast(binary.data()), @@ -107,11 +92,7 @@ namespace mlir { void registerToSPIRVTranslation() { TranslateFromMLIRRegistration toBinary( - "serialize-spirv", - [](ModuleOp module, raw_ostream &output) { - return serializeModule(module, output); - }, - [](DialectRegistry ®istry) { + "serialize-spirv", serializeModule, [](DialectRegistry ®istry) { registry.insert(); }); } @@ -121,60 +102,54 @@ // Round-trip registration //===----------------------------------------------------------------------===// -static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, +static LogicalResult roundTripModule(Operation *srcRoot, bool emitDebugInfo, raw_ostream &output) { SmallVector binary; - MLIRContext *context = srcModule.getContext(); - auto spirvModules = srcModule.getOps(); + MLIRContext *context = srcRoot->getContext(); + auto srcSpirvModule = dyn_cast(srcRoot); + if (!srcSpirvModule) + return srcRoot->emitError("expected spirv.module op"); - if (spirvModules.begin() == spirvModules.end()) - return srcModule.emitError("found no 'spv.module' op"); - - if (std::next(spirvModules.begin()) != spirvModules.end()) - return srcModule.emitError("found more than one 'spv.module' op"); - - if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) + if (failed(spirv::serialize(srcSpirvModule, binary, emitDebugInfo))) return failure(); MLIRContext deserializationContext(context->getDialectRegistry()); // TODO: we should only load the required dialects instead of all dialects. deserializationContext.loadAllAvailableDialects(); // Then deserialize to get back a SPIR-V module. - OwningOpRef spirvModule = + OwningOpRef destSpirvModule = spirv::deserialize(binary, &deserializationContext); - if (!spirvModule) + if (!destSpirvModule) return failure(); - // Wrap around in a new MLIR module. - OwningModuleRef dstModule(ModuleOp::create( - FileLineColLoc::get(&deserializationContext, - /*filename=*/"", /*line=*/0, /*column=*/0))); - dstModule->getBody()->push_front(spirvModule.release()); - dstModule->print(output); + destSpirvModule->print(output); + return success(); +} - return mlir::success(); +static OwningOpRef parseSPIRVModule(llvm::SourceMgr &sourceMgr, + MLIRContext *context) { + return OwningOpRef( + parseSourceFile(sourceMgr, context).release()); } namespace mlir { void registerTestRoundtripSPIRV() { TranslateFromMLIRRegistration roundtrip( "test-spirv-roundtrip", - [](ModuleOp module, raw_ostream &output) { - return roundTripModule(module, /*emitDebugInfo=*/false, output); + [](Operation *root, raw_ostream &output) { + return roundTripModule(root, /*emitDebugInfo=*/false, output); }, - [](DialectRegistry ®istry) { - registry.insert(); - }); + [](DialectRegistry ®istry) { registry.insert(); }, + parseSPIRVModule); } void registerTestRoundtripDebugSPIRV() { TranslateFromMLIRRegistration roundtrip( "test-spirv-roundtrip-debug", - [](ModuleOp module, raw_ostream &output) { - return roundTripModule(module, /*emitDebugInfo=*/true, output); + [](Operation *root, raw_ostream &output) { + return roundTripModule(root, /*emitDebugInfo=*/true, output); }, - [](DialectRegistry ®istry) { - registry.insert(); - }); + [](DialectRegistry ®istry) { registry.insert(); }, + parseSPIRVModule); } } // namespace mlir diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -62,10 +62,10 @@ StringRef name, const TranslateSourceMgrToMLIRFunction &function) { auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { - OwningModuleRef module = function(sourceMgr, context); - if (!module || failed(verify(*module))) + OwningOpRef root = function(sourceMgr, context); + if (!root || failed(verify(*root))) return failure(); - module->print(output); + (*root)->print(output); return success(); }; registerTranslation(name, wrappedFn); @@ -94,20 +94,28 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( StringRef name, const TranslateFromMLIRFunction &function, - std::function dialectRegistration) { - registerTranslation(name, [function, dialectRegistration]( + std::function dialectRegistration, + ParseFunction parseFunction) { + registerTranslation(name, [function, dialectRegistration, parseFunction]( llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { DialectRegistry registry; dialectRegistration(registry); context->appendDialectRegistry(registry); - auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); - if (!module) + OwningOpRef root = parseFunction(sourceMgr, context); + if (!root) return failure(); - return function(module.get(), output); + return function(root.get(), output); }); } +OwningOpRef +TranslateFromMLIRRegistration::parseBuiltinModule(llvm::SourceMgr &sourceMgr, + MLIRContext *context) { + return OwningOpRef( + parseSourceFile(sourceMgr, context).release()); +} + //===----------------------------------------------------------------------===// // Translation Parser //===----------------------------------------------------------------------===//