diff --git a/mlir/include/mlir/IR/OwningOpRef.h b/mlir/include/mlir/IR/OwningOpRef.h --- a/mlir/include/mlir/IR/OwningOpRef.h +++ b/mlir/include/mlir/IR/OwningOpRef.h @@ -29,7 +29,7 @@ /// The underlying operation type stored in this reference. using OperationT = OpTy; - OwningOpRef(std::nullptr_t = nullptr) {} + OwningOpRef(std::nullptr_t = nullptr) : op(nullptr) {} OwningOpRef(OpTy op) : op(op) {} OwningOpRef(OwningOpRef &&other) : op(other.release()) {} ~OwningOpRef() { @@ -53,7 +53,7 @@ /// Release the referenced op. OpTy release() { - OpTy released; + OpTy released(nullptr); std::swap(released, op); return released; } diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -37,38 +37,48 @@ template <typename ContainerOpT> inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary( Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) { - static_assert( - ContainerOpT::template hasTrait<OpTrait::OneRegion>() && - (ContainerOpT::template hasTrait<OpTrait::NoTerminator>() || - OpTrait::template hasSingleBlockImplicitTerminator< - ContainerOpT>::value), - "Expected `ContainerOpT` to have a single region with a single " - "block that has an implicit terminator or does not require one"); // Check to see if we parsed a single instance of this operation. if (llvm::hasSingleElement(*parsedBlock)) { - if (ContainerOpT op = dyn_cast<ContainerOpT>(parsedBlock->front())) { + if (ContainerOpT op = dyn_cast<ContainerOpT>(&parsedBlock->front())) { op->remove(); return op; } } - // If not, then build a new one to contain the parsed operations.
- OpBuilder builder(context); - ContainerOpT op = builder.create<ContainerOpT>(sourceFileLoc); - OwningOpRef<ContainerOpT> opRef(op); - assert(op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)) && - "expected generated operation to have a single region with a single " - "block"); - Block *opBlock = &op->getRegion(0).front(); - opBlock->getOperations().splice(opBlock->begin(), - parsedBlock->getOperations()); - - // After splicing, verify just this operation to ensure it can properly - // contain the operations inside of it. - if (failed(op.verifyInvariants())) - return OwningOpRef<ContainerOpT>(); - return opRef; + // If not, then build a new top-level op if a concrete operation type was + // specified. + if constexpr (std::is_same_v<ContainerOpT, Operation *>) { + return emitError(sourceFileLoc) + << "source must contain a single top-level operation, found: " + << parsedBlock->getOperations().size(), + nullptr; + } else { + static_assert( + ContainerOpT::template hasTrait<OpTrait::OneRegion>() && + (ContainerOpT::template hasTrait<OpTrait::NoTerminator>() || + OpTrait::template hasSingleBlockImplicitTerminator< + ContainerOpT>::value), + "Expected `ContainerOpT` to have a single region with a single " + "block that has an implicit terminator or does not require one"); + + OpBuilder builder(context); + ContainerOpT op = builder.create<ContainerOpT>(sourceFileLoc); + OwningOpRef<ContainerOpT> opRef(op); + assert(op->getNumRegions() == 1 && + llvm::hasSingleElement(op->getRegion(0)) && + "expected generated operation to have a single region with a single " + "block"); + Block *opBlock = &op->getRegion(0).front(); + opBlock->getOperations().splice(opBlock->begin(), + parsedBlock->getOperations()); + + // After splicing, verify just this operation to ensure it can properly + // contain the operations inside of it. + if (failed(op.verifyInvariants())) + return OwningOpRef<ContainerOpT>(); + return opRef; + } } } // namespace detail @@ -195,6 +205,61 @@ &block, config.getContext(), sourceFileLoc); } +/// This parses the file specified by the indicated SourceMgr.
If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef<Operation *> +parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) { + return detail::parseSourceFile<Operation *>(config, sourceMgr); +} + +/// This parses the file specified by the indicated filename. If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef<Operation *> parseSourceFile(StringRef filename, + const ParserConfig &config) { + return detail::parseSourceFile<Operation *>(config, filename); +} + +/// This parses the file specified by the indicated filename using the provided +/// SourceMgr. If the source IR contained a single operation, it is returned. If +/// parsing was not successful, null is returned and an error message is emitted +/// through the error handler registered in the context. +inline OwningOpRef<Operation *> parseSourceFile(StringRef filename, + llvm::SourceMgr &sourceMgr, + const ParserConfig &config) { + return detail::parseSourceFile<Operation *>(config, filename, sourceMgr); +} + +/// This parses the provided string containing MLIR. If the source IR contained +/// a single operation, it is returned. If parsing was not successful, null is +/// returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef<Operation *> parseSourceString(StringRef sourceStr, + const ParserConfig &config) { + LocationAttr sourceFileLoc; + Block block; + if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc))) + return nullptr; + return detail::constructContainerOpForParserIfNecessary<Operation *>( + &block, config.getContext(), sourceFileLoc); +} + +/// Register command-line options used to configure the behaviour of +/// `parseSourceFileForTool`.
+void registerToolParserCLOptions(); + +/// This parses the file specified by the indicated SourceMgr. Unless the +/// `--no-implicit-module` option is set, the parsed operations are wrapped in a +/// top-level `ModuleOp` if necessary; otherwise the source must contain exactly +/// one top-level operation. `registerToolParserCLOptions` must be called first. +/// If parsing was not successful, null is returned and an error message is +/// emitted through the error handler registered in the context. +OwningOpRef<Operation *> parseSourceFileForTool(llvm::SourceMgr &sourceMgr, + const ParserConfig &config); + } // namespace mlir #endif // MLIR_PARSER_PARSER_H diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -13,6 +13,9 @@ #include "mlir/Parser/Parser.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/IR/BuiltinOps.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -68,3 +71,29 @@ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); return parseSourceFile(sourceMgr, block, config, sourceFileLoc); } + +namespace { +struct ToolParserOptions { + llvm::cl::opt<bool> noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; +}; +} // namespace + +static llvm::ManagedStatic<ToolParserOptions> clOptions; + +void mlir::registerToolParserCLOptions() { + // Make sure that the options struct has been initialized. + *clOptions; +} + +OwningOpRef<Operation *> +mlir::parseSourceFileForTool(llvm::SourceMgr &sourceMgr, + const ParserConfig &config) { + assert(clOptions.isConstructed() && "options not registered"); + if (clOptions->noImplicitModule) + return parseSourceFile(sourceMgr, config); + return parseSourceFile<ModuleOp>(sourceMgr, config).release().getOperation(); +} diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -71,15 +71,15 @@ // Parse the input file and reset the context threading state.
TimingScope parserTiming = timing.nest("Parser"); - OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config)); + OwningOpRef<Operation *> op = parseSourceFileForTool(sourceMgr, config); context->enableMultithreading(wasThreadingEnabled); - if (!module) + if (!op) return failure(); parserTiming.stop(); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(context, OpPassManager::Nesting::Implicit, - module->getOperationName()); + op.get()->getName().getStringRef()); pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); pm.enableTiming(timing); @@ -89,18 +89,18 @@ return failure(); // Run the pipeline. - if (failed(pm.run(*module))) + if (failed(pm.run(*op))) return failure(); // Print the output. TimingScope outputTiming = timing.nest("Output"); if (emitBytecode) { BytecodeWriterConfig writerConfig(fallbackResourceMap); - writeBytecodeToFile(module->getOperation(), os, writerConfig); + writeBytecodeToFile(op.get(), os, writerConfig); } else { - AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr, + AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); - module->print(os, asmState); + op.get()->print(os, asmState); os << '\n'; } return success(); @@ -253,6 +253,7 @@ registerMLIRContextCLOptions(); registerPassManagerCLOptions(); registerDefaultTimingManagerCLOptions(); + registerToolParserCLOptions(); DebugCounter::registerCLOptions(); PassPipelineCLParser passPipeline("", "Compiler passes to run"); diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/top-level.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt --no-implicit-module --verify-diagnostics --split-input-file %s | FileCheck %s + +// CHECK-NOT: module +// CHECK: func.func +func.func private @foo() + +// ----- + +// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +func.func private @bar() +func.func private @baz() + +// ----- +
+// expected-error@-3 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/pipeline-invalid.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt --no-implicit-module --canonicalize --verify-diagnostics --split-input-file %s + +// Verify that pipelines which cannot be scheduled on the parsed top-level op are diagnosed. + +// expected-error@below {{trying to schedule a pass on an operation not marked as 'IsolatedFromAbove'}} +arith.constant 0 + +// ----- + +// expected-error@below {{trying to schedule a pass on an unregistered operation}} +"test.op"() : () -> ()