diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -496,10 +496,17 @@ name, std::forward(parserFn))); } + /// Enable implicit addition of a top-level 'builtin.module' during parsing. + void enableImplicitModule(bool enabled = true) { implicitModule = enabled; } + + /// Return whether the parser should insert a top-level 'builtin.module'. + bool shouldInsertImplictModule() const { return implicitModule; } + private: MLIRContext *context; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; + bool implicitModule = false; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OwningOpRef.h b/mlir/include/mlir/IR/OwningOpRef.h --- a/mlir/include/mlir/IR/OwningOpRef.h +++ b/mlir/include/mlir/IR/OwningOpRef.h @@ -29,7 +29,7 @@ /// The underlying operation type stored in this reference. using OperationT = OpTy; - OwningOpRef(std::nullptr_t = nullptr) {} + OwningOpRef(std::nullptr_t = nullptr) : op(nullptr) {} OwningOpRef(OpTy op) : op(op) {} OwningOpRef(OwningOpRef &&other) : op(other.release()) {} ~OwningOpRef() { @@ -53,7 +53,7 @@ /// Release the referenced op.
OpTy release() { - OpTy released; + OpTy released(nullptr); std::swap(released, op); return released; } diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -37,38 +37,48 @@ template inline OwningOpRef constructContainerOpForParserIfNecessary( Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) { - static_assert( - ContainerOpT::template hasTrait() && - (ContainerOpT::template hasTrait() || - OpTrait::template hasSingleBlockImplicitTerminator< - ContainerOpT>::value), - "Expected `ContainerOpT` to have a single region with a single " - "block that has an implicit terminator or does not require one"); // Check to see if we parsed a single instance of this operation. if (llvm::hasSingleElement(*parsedBlock)) { - if (ContainerOpT op = dyn_cast(parsedBlock->front())) { + if (ContainerOpT op = dyn_cast(&parsedBlock->front())) { op->remove(); return op; } } - // If not, then build a new one to contain the parsed operations. - OpBuilder builder(context); - ContainerOpT op = builder.create(sourceFileLoc); - OwningOpRef opRef(op); - assert(op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)) && - "expected generated operation to have a single region with a single " - "block"); - Block *opBlock = &op->getRegion(0).front(); - opBlock->getOperations().splice(opBlock->begin(), - parsedBlock->getOperations()); - - // After splicing, verify just this operation to ensure it can properly - // contain the operations inside of it. - if (failed(op.verifyInvariants())) - return OwningOpRef(); - return opRef; + // If not, then build a new top-level op if a concrete operation type was + // specified.
+ if constexpr (std::is_same_v) { + return emitError(sourceFileLoc) + << "source must contain a single top-level operation, found: " + << parsedBlock->getOperations().size(), + nullptr; + } else { + static_assert( + ContainerOpT::template hasTrait() && + (ContainerOpT::template hasTrait() || + OpTrait::template hasSingleBlockImplicitTerminator< + ContainerOpT>::value), + "Expected `ContainerOpT` to have a single region with a single " + "block that has an implicit terminator or does not require one"); + + OpBuilder builder(context); + ContainerOpT op = builder.create(sourceFileLoc); + OwningOpRef opRef(op); + assert(op->getNumRegions() == 1 && + llvm::hasSingleElement(op->getRegion(0)) && + "expected generated operation to have a single region with a single " + "block"); + Block *opBlock = &op->getRegion(0).front(); + opBlock->getOperations().splice(opBlock->begin(), + parsedBlock->getOperations()); + + // After splicing, verify just this operation to ensure it can properly + // contain the operations inside of it. + if (failed(op.verifyInvariants())) + return OwningOpRef(); + return opRef; + } } } // namespace detail @@ -195,6 +205,48 @@ &block, config.getContext(), sourceFileLoc); } +/// This parses the file specified by the indicated SourceMgr. If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef +parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) { + return detail::parseSourceFile(config, sourceMgr); +} + +/// This parses the file specified by the indicated filename. If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context.
+inline OwningOpRef parseSourceFile(StringRef filename, + const ParserConfig &config) { + return detail::parseSourceFile(config, filename); +} + +/// This parses the file specified by the indicated filename using the provided +/// SourceMgr. If the source IR contained a single operation, it is returned. If +/// parsing was not successful, null is returned and an error message is emitted +/// through the error handler registered in the context. +inline OwningOpRef parseSourceFile(StringRef filename, + llvm::SourceMgr &sourceMgr, + const ParserConfig &config) { + return detail::parseSourceFile(config, filename, sourceMgr); +} + +/// This parses the provided string containing MLIR. If the source IR contained +/// a single operation, it is returned. If parsing was not successful, null is +/// returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef parseSourceString(StringRef sourceStr, + const ParserConfig &config) { + LocationAttr sourceFileLoc; + Block block; + if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc))) + return nullptr; + return detail::constructContainerOpForParserIfNecessary( + &block, config.getContext(), sourceFileLoc); +} + } // namespace mlir #endif // MLIR_PARSER_PARSER_H diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -51,26 +51,24 @@ /// dialects from the global registry in the MLIRContext. This option is /// deprecated and will be removed soon. /// - emitBytecode will generate bytecode output instead of text.
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, - std::unique_ptr buffer, - const PassPipelineCLParser &passPipeline, - DialectRegistry &registry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext = false, - bool emitBytecode = false); +/// - implicitModule will enable implicit addition of a top-level +/// 'builtin.module' if one doesn't already exist. +LogicalResult MlirOptMain( + llvm::raw_ostream &outputStream, std::unique_ptr buffer, + const PassPipelineCLParser &passPipeline, DialectRegistry &registry, + bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, bool preloadDialectsInContext = false, + bool emitBytecode = false, bool implicitModule = false); /// Support a callback to setup the pass manager. /// - passManagerSetupFn is the callback invoked to setup the pass manager to /// apply on the loaded IR. -LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, - std::unique_ptr buffer, - PassPipelineFn passManagerSetupFn, - DialectRegistry &registry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext = false, - bool emitBytecode = false); +LogicalResult MlirOptMain( + llvm::raw_ostream &outputStream, std::unique_ptr buffer, + PassPipelineFn passManagerSetupFn, DialectRegistry &registry, + bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, bool preloadDialectsInContext = false, + bool emitBytecode = false, bool implicitModule = false); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -2589,12 +2589,20 @@ if (opParser.finalize()) return failure(); - // Splice the blocks of the parsed operation over to the provided - // top-level block. + // Move the parsed operations over to the provided top-level block. If an + // implicit module op was requested we move 'topLevelOp' itself over, + // *unless* we already parsed a top-level module. auto &parsedOps = topLevelOp->getBody()->getOperations(); auto &destOps = topLevelBlock->getOperations(); - destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()), - parsedOps, parsedOps.begin(), parsedOps.end()); + bool parsedTopLevelModule = + llvm::hasSingleElement(parsedOps) && isa(parsedOps.front()); + if (state.config.shouldInsertImplictModule() && !parsedTopLevelModule) { + destOps.push_back(topLevelOp.release()); + } else { + destOps.splice(destOps.empty() ? destOps.end() + : std::prev(destOps.end()), + parsedOps, parsedOps.begin(), parsedOps.end()); + } return success(); } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1414,6 +1414,11 @@ // Splice the parsed operations over to the provided top-level block. auto &parsedOps = moduleOp->getBody()->getOperations(); auto &destOps = block->getOperations(); + bool parsedTopLevelModule = + llvm::hasSingleElement(parsedOps) && isa(parsedOps.front()); + if (config.shouldInsertImplictModule() && !parsedTopLevelModule) + return emitError(fileLoc, "Inserting an implicit module is not supported " + "when reading bytecode"); destOps.splice(destOps.empty() ?
destOps.end() : std::prev(destOps.end()), parsedOps, parsedOps.begin(), parsedOps.end()); return success(); diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -49,7 +49,7 @@ bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, PassPipelineFn passManagerSetupFn, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -67,18 +67,19 @@ FallbackAsmResourceMap fallbackResourceMap; ParserConfig config(context, &fallbackResourceMap); reproOptions.attachResourceParser(config); + config.enableImplicitModule(implicitModule); // Parse the input file and reset the context threading state. TimingScope parserTiming = timing.nest("Parser"); - OwningOpRef module(parseSourceFile(sourceMgr, config)); + OwningOpRef op = parseSourceFile(sourceMgr, config); context->enableMultithreading(wasThreadingEnabled); - if (!module) + if (!op) return failure(); parserTiming.stop(); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(context, OpPassManager::Nesting::Implicit, - module->getOperationName()); + op.get()->getName().getStringRef()); pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); pm.enableTiming(timing); @@ -86,18 +87,18 @@ return failure(); // Run the pipeline. - if (failed(pm.run(*module))) + if (failed(pm.run(*op))) return failure(); // Print the output.
TimingScope outputTiming = timing.nest("Output"); if (emitBytecode) { BytecodeWriterConfig writerConfig(fallbackResourceMap); - writeBytecodeToFile(module->getOperation(), os, writerConfig); + writeBytecodeToFile(op.get(), os, writerConfig); } else { - AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr, + AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); - module->print(os, asmState); + op.get()->print(os, asmState); os << '\n'; } return success(); @@ -109,8 +110,9 @@ processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode, PassPipelineFn passManagerSetupFn, - DialectRegistry &registry, llvm::ThreadPool *threadPool) { + bool emitBytecode, bool implicitModule, + PassPipelineFn passManagerSetupFn, DialectRegistry &registry, + llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -134,7 +136,8 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn, emitBytecode); + &context, passManagerSetupFn, emitBytecode, + implicitModule); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -143,7 +146,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn, emitBytecode); + passManagerSetupFn, emitBytecode, implicitModule); // Verify the diagnostic handler to make sure that each of the diagnostics // matched.
@@ -157,7 +160,7 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying @@ -176,7 +179,7 @@ raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, emitBytecode, + preloadDialectsInContext, emitBytecode, implicitModule, passManagerSetupFn, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, @@ -190,7 +193,7 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { auto passManagerSetupFn = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; @@ -201,7 +204,7 @@ return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode); + emitBytecode, implicitModule); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -243,6 +246,12 @@ "emit-bytecode", cl::desc("Emit bytecode when generating output"), cl::init(false)); + static llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; + InitLLVM y(argc, argv); // Register any command line options.
@@ -288,7 +297,7 @@ if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode))) + emitBytecode, /*implicitModule=*/!noImplicitModule))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/top-level.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt --no-implicit-module --verify-diagnostics --split-input-file %s | FileCheck %s + +// CHECK-NOT: module +// CHECK: func.func +func.func private @foo() + +// ----- + +// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +func.func private @bar() +func.func private @baz() + +// ----- + +// expected-error@-3 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/pipeline-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt --no-implicit-module --canonicalize --verify-diagnostics --split-input-file %s + +// expected-error@below {{trying to schedule a pass on an operation not marked as 'IsolatedFromAbove'}} +arith.constant 0 + +// ----- + +// expected-error@below {{trying to schedule a pass on an unregistered operation}} +"test.op"() : () -> ()