diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -34,7 +34,7 @@ namespace mlir { -class ModuleOp; +class Operation; /// A simple object cache following Lang's LLJITWithObjectCache example. class SimpleObjectCache : public llvm::ObjectCache { @@ -51,10 +51,10 @@ }; struct ExecutionEngineOptions { - /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module - /// from the given MLIR module. Otherwise, a default `translateModuleToLLVMIR` - /// function will be used to translate MLIR module to LLVM IR. - llvm::function_ref(ModuleOp, + /// If `llvmModuleBuilder` is provided, it will be used to create an LLVM + /// module from the given MLIR IR. Otherwise, a default + /// `translateModuleToLLVMIR` function will be used to translate to LLVM IR. + llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder = nullptr; @@ -89,9 +89,9 @@ bool enablePerfNotificationListener = true; }; -/// JIT-backed execution engine for MLIR modules. Assumes the module can be -/// converted to LLVM IR. For each function, creates a wrapper function with -/// the fixed interface +/// JIT-backed execution engine for MLIR. Assumes the IR can be converted to +/// LLVM IR. For each function, creates a wrapper function with the fixed +/// interface /// /// void _mlir_funcName(void **) /// @@ -104,9 +104,9 @@ ExecutionEngine(bool enableObjectCache, bool enableGDBNotificationListener, bool enablePerfNotificationListener); - /// Creates an execution engine for the given module. + /// Creates an execution engine for the given MLIR IR. 
static llvm::Expected> - create(ModuleOp m, const ExecutionEngineOptions &options = {}); + create(Operation *op, const ExecutionEngineOptions &options = {}); /// Looks up a packed-argument function wrapping the function with the given /// name and returns a pointer to it. Propagates errors in case of failure. diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h --- a/mlir/include/mlir/ExecutionEngine/JitRunner.h +++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h @@ -33,17 +33,18 @@ namespace mlir { class DialectRegistry; -class ModuleOp; +class Operation; struct LogicalResult; struct JitRunnerConfig { /// MLIR transformer applied after parsing the input into MLIR IR and before - /// passing the MLIR module to the ExecutionEngine. - llvm::function_ref mlirTransformer = nullptr; + /// passing the MLIR IR to the ExecutionEngine. + llvm::function_ref mlirTransformer = + nullptr; - /// A custom function that is passed to ExecutionEngine. It processes MLIR - /// module and creates LLVM IR module. - llvm::function_ref(ModuleOp, + /// A custom function that is passed to ExecutionEngine. It processes MLIR and + /// creates an LLVM IR module. 
+ llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder = nullptr; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -232,7 +232,7 @@ } Expected> -ExecutionEngine::create(ModuleOp m, const ExecutionEngineOptions &options) { +ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options) { auto engine = std::make_unique( options.enableObjectCache, options.enableGDBNotificationListener, options.enablePerfNotificationListener); diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Tools/ParseUtilties.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" @@ -91,6 +92,12 @@ llvm::cl::opt hostSupportsJit{"host-supports-jit", llvm::cl::desc("Report host JIT support"), llvm::cl::Hidden}; + + llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; }; struct CompileAndExecuteConfig { @@ -99,7 +106,7 @@ - /// A custom function that is passed to ExecutionEngine. It processes MLIR - /// module and creates LLVM IR module. + /// A custom function that is passed to ExecutionEngine. It processes MLIR and + /// creates an LLVM IR module. 
- llvm::function_ref(ModuleOp, + llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder; @@ -111,8 +118,9 @@ } // namespace -static OwningOpRef parseMLIRInput(StringRef inputFilename, - MLIRContext *context) { +static OwningOpRef parseMLIRInput(StringRef inputFilename, + bool insertImplicitModule, + MLIRContext *context) { // Set up the input file. 
std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); @@ -123,7 +131,15 @@ llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); - return parseSourceFile(sourceMgr, context); + OwningOpRef module = + parseSourceFileForTool(sourceMgr, context, insertImplicitModule); + if (!module) + return nullptr; + if (!module.get()->hasTrait()) { + llvm::errs() << "Error: top-level op must be a symbol table.\n"; + return nullptr; + } + return module; } static inline Error makeStringError(const Twine &message) { @@ -148,7 +164,7 @@ } // JIT-compile the given module and run "entryPoint" with "args" as arguments. -static Error compileAndExecute(Options &options, ModuleOp module, +static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args) { Optional jitCodeGenOptLevel; @@ -240,10 +256,11 @@ return Error::success(); } -static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, +static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config) { - auto mainFunction = module.lookupSymbol(entryPoint); + auto mainFunction = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, entryPoint)); if (!mainFunction || mainFunction.empty()) return makeStringError("entry point not found"); void *empty = nullptr; @@ -283,10 +300,11 @@ return Error::success(); } template -Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, +Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config) { - auto mainFunction = module.lookupSymbol(entryPoint); + auto mainFunction = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, entryPoint)); if (!mainFunction || mainFunction.isExternal()) return makeStringError("entry point not found"); @@ -339,7 +357,8 @@ MLIRContext context(registry); - auto 
m = parseMLIRInput(options.inputFilename, &context); + auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule, + &context); if (!m) { llvm::errs() << "could not parse the input IR\n"; return 1; @@ -370,7 +389,7 @@ // Get the function used to compile and execute the module. using CompileAndExecuteFnT = - Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); + Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig); auto compileAndExecuteFn = StringSwitch(options.mainFuncType.getValue()) .Case("i32", compileAndExecuteSingleReturnFunction) diff --git a/mlir/test/mlir-cpu-runner/invalid.mlir b/mlir/test/mlir-cpu-runner/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/invalid.mlir @@ -0,0 +1,4 @@ +// RUN: not mlir-cpu-runner --no-implicit-module %s |& FileCheck %s + +// CHECK: Error: top-level op must be a symbol table. +llvm.func @main() diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp --- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp +++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp @@ -51,7 +51,10 @@ /// Each of these two modules is translated to LLVM IR module, then they are /// linked together and returned. static std::unique_ptr -convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) { +convertMLIRModule(Operation *op, llvm::LLVMContext &context) { + auto module = dyn_cast(op); + if (!module) + return op->emitError("op must be a 'builtin.module'"), nullptr; // Verify that there is only one nested module. 
auto modules = module.getOps(); if (!llvm::hasSingleElement(modules)) { @@ -71,8 +74,9 @@ return mainModule; } -static LogicalResult runMLIRPasses(ModuleOp module) { - PassManager passManager(module.getContext()); +static LogicalResult runMLIRPasses(Operation *module) { + PassManager passManager(module->getContext(), + module->getName().getStringRef()); applyPassManagerCLOptions(passManager); passManager.addPass(createGpuKernelOutliningPass()); passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));