diff --git a/mlir/include/mlir/IR/OwningOpRef.h b/mlir/include/mlir/IR/OwningOpRef.h --- a/mlir/include/mlir/IR/OwningOpRef.h +++ b/mlir/include/mlir/IR/OwningOpRef.h @@ -29,7 +29,7 @@ /// The underlying operation type stored in this reference. using OperationT = OpTy; - OwningOpRef(std::nullptr_t = nullptr) {} + OwningOpRef(std::nullptr_t = nullptr) : op(nullptr) {} OwningOpRef(OpTy op) : op(op) {} OwningOpRef(OwningOpRef &&other) : op(other.release()) {} ~OwningOpRef() { @@ -53,7 +53,7 @@ /// Release the referenced op. OpTy release() { - OpTy released; + OpTy released(nullptr); std::swap(released, op); return released; } diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -37,38 +37,48 @@ template inline OwningOpRef constructContainerOpForParserIfNecessary( Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) { - static_assert( - ContainerOpT::template hasTrait() && - (ContainerOpT::template hasTrait() || - OpTrait::template hasSingleBlockImplicitTerminator< - ContainerOpT>::value), - "Expected `ContainerOpT` to have a single region with a single " - "block that has an implicit terminator or does not require one"); // Check to see if we parsed a single instance of this operation. if (llvm::hasSingleElement(*parsedBlock)) { - if (ContainerOpT op = dyn_cast(parsedBlock->front())) { + if (ContainerOpT op = dyn_cast(&parsedBlock->front())) { op->remove(); return op; } } - // If not, then build a new one to contain the parsed operations. 
- OpBuilder builder(context); - ContainerOpT op = builder.create(sourceFileLoc); - OwningOpRef opRef(op); - assert(op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)) && - "expected generated operation to have a single region with a single " - "block"); - Block *opBlock = &op->getRegion(0).front(); - opBlock->getOperations().splice(opBlock->begin(), - parsedBlock->getOperations()); - - // After splicing, verify just this operation to ensure it can properly - // contain the operations inside of it. - if (failed(op.verifyInvariants())) - return OwningOpRef(); - return opRef; + // If not, then build a new top-level op if a concrete operation type was + // specified. + if constexpr (std::is_same_v) { + return emitError(sourceFileLoc) + << "source must contain a single top-level operation, found: " + << parsedBlock->getOperations().size(), + nullptr; + } else { + static_assert( + ContainerOpT::template hasTrait() && + (ContainerOpT::template hasTrait() || + OpTrait::template hasSingleBlockImplicitTerminator< + ContainerOpT>::value), + "Expected `ContainerOpT` to have a single region with a single " + "block that has an implicit terminator or does not require one"); + + OpBuilder builder(context); + ContainerOpT op = builder.create(sourceFileLoc); + OwningOpRef opRef(op); + assert(op->getNumRegions() == 1 && + llvm::hasSingleElement(op->getRegion(0)) && + "expected generated operation to have a single region with a single " + "block"); + Block *opBlock = &op->getRegion(0).front(); + opBlock->getOperations().splice(opBlock->begin(), + parsedBlock->getOperations()); + + // After splicing, verify just this operation to ensure it can properly + // contain the operations inside of it. + if (failed(op.verifyInvariants())) + return OwningOpRef(); + return opRef; + } } } // namespace detail @@ -195,6 +205,48 @@ &block, config.getContext(), sourceFileLoc); } +/// This parses the file specified by the indicated SourceMgr. 
If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef +parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) { + return detail::parseSourceFile(config, sourceMgr); +} + +/// This parses the file specified by the indicated filename. If the source IR +/// contained a single operation, it is returned. If parsing was not successful, +/// null is returned and an error message is emitted through the error handler +/// registered in the context. +inline OwningOpRef parseSourceFile(StringRef filename, + const ParserConfig &config) { + return detail::parseSourceFile(config, filename); +} + +/// This parses the file specified by the indicated filename using the provided +/// SourceMgr. If the source IR contained a single operation, it is returned. If +/// parsing was not successful, null is returned and an error message is emitted +/// through the error handler registered in the context. +inline OwningOpRef parseSourceFile(StringRef filename, + llvm::SourceMgr &sourceMgr, + const ParserConfig &config) { + return detail::parseSourceFile(config, filename, sourceMgr); +} + +/// This parses the provided string containing MLIR. If the source IR contained +/// a single operation, it is returned. If parsing was not successful, null is +/// returned and an error message is emitted through the error handler +/// registered in the context. 
+inline OwningOpRef parseSourceString(StringRef sourceStr, + const ParserConfig &config) { + LocationAttr sourceFileLoc; + Block block; + if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc))) + return nullptr; + return detail::constructContainerOpForParserIfNecessary( + &block, config.getContext(), sourceFileLoc); +} + } // namespace mlir #endif // MLIR_PARSER_PARSER_H diff --git a/mlir/include/mlir/Tools/ParseUtilties.h b/mlir/include/mlir/Tools/ParseUtilties.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/ParseUtilties.h @@ -0,0 +1,38 @@ +//===- ParseUtilities.h - MLIR Tool Parse Utilities -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains common utilities for implementing the file-parsing +// behaviour for MLIR tools. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PARSEUTILITIES_H +#define MLIR_TOOLS_PARSEUTILITIES_H + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Parser/Parser.h" + +namespace mlir { +/// This parses the file specified by the indicated SourceMgr. If parsing was +/// not successful, null is returned and an error message is emitted through the +/// error handler registered in the context. +/// If 'insertImplicitModule' is true a top-level 'builtin.module' op will be +/// inserted that contains the parsed IR, unless one exists already. +inline OwningOpRef +parseSourceFileForTool(llvm::SourceMgr &sourceMgr, const ParserConfig &config, + bool insertImplicitModule) { + if (insertImplicitModule) + // TODO: Move implicit module logic out of 'parseSourceFile' and into here. 
+ return parseSourceFile(sourceMgr, config) + .release() + .getOperation(); + return parseSourceFile(sourceMgr, config); +} +} // namespace mlir + +#endif // MLIR_TOOLS_PARSEUTILITIES_H diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -51,26 +51,24 @@ /// dialects from the global registry in the MLIRContext. This option is /// deprecated and will be removed soon. /// - emitBytecode will generate bytecode output instead of text. -LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, - std::unique_ptr buffer, - const PassPipelineCLParser &passPipeline, - DialectRegistry ®istry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext = false, - bool emitBytecode = false); +/// - implicitModule will enable implicit addition of a top-level +/// 'builtin.module' if one doesn't already exist. +LogicalResult MlirOptMain( + llvm::raw_ostream &outputStream, std::unique_ptr buffer, + const PassPipelineCLParser &passPipeline, DialectRegistry ®istry, + bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, bool preloadDialectsInContext = false, + bool emitBytecode = false, bool implicitModule = false); /// Support a callback to setup the pass manager. /// - passManagerSetupFn is the callback invoked to setup the pass manager to /// apply on the loaded IR. 
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, - std::unique_ptr buffer, - PassPipelineFn passManagerSetupFn, - DialectRegistry ®istry, bool splitInputFile, - bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext = false, - bool emitBytecode = false); +LogicalResult MlirOptMain( + llvm::raw_ostream &outputStream, std::unique_ptr buffer, + PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, + bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, bool preloadDialectsInContext = false, + bool emitBytecode = false, bool implicitModule = false); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -27,6 +27,7 @@ #include "mlir/Support/FileUtilities.h" #include "mlir/Support/Timing.h" #include "mlir/Support/ToolUtilities.h" +#include "mlir/Tools/ParseUtilties.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/InitLLVM.h" @@ -49,7 +50,7 @@ bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, PassPipelineFn passManagerSetupFn, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -70,15 +71,16 @@ // Parse the input file and reset the context threading state. TimingScope parserTiming = timing.nest("Parser"); - OwningOpRef module(parseSourceFile(sourceMgr, config)); + OwningOpRef op = + parseSourceFileForTool(sourceMgr, config, implicitModule); context->enableMultithreading(wasThreadingEnabled); - if (!module) + if (!op) return failure(); parserTiming.stop(); // Prepare the pass manager, applying command-line and reproducer options. 
PassManager pm(context, OpPassManager::Nesting::Implicit, - module->getOperationName()); + op.get()->getName().getStringRef()); pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); pm.enableTiming(timing); @@ -86,18 +88,18 @@ return failure(); // Run the pipeline. - if (failed(pm.run(*module))) + if (failed(pm.run(*op))) return failure(); // Print the output. TimingScope outputTiming = timing.nest("Output"); if (emitBytecode) { BytecodeWriterConfig writerConfig(fallbackResourceMap); - writeBytecodeToFile(module->getOperation(), os, writerConfig); + writeBytecodeToFile(op.get(), os, writerConfig); } else { - AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr, + AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); - module->print(os, asmState); + op.get()->print(os, asmState); os << '\n'; } return success(); @@ -109,8 +111,9 @@ processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode, PassPipelineFn passManagerSetupFn, - DialectRegistry ®istry, llvm::ThreadPool *threadPool) { + bool emitBytecode, bool implicitModule, + PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, + llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -134,7 +137,8 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn, emitBytecode); + &context, passManagerSetupFn, emitBytecode, + implicitModule); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -143,7 +147,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. 
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn, emitBytecode); + passManagerSetupFn, emitBytecode, implicitModule); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. @@ -157,7 +161,7 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying @@ -176,7 +180,7 @@ raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, emitBytecode, + preloadDialectsInContext, emitBytecode, implicitModule, passManagerSetupFn, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, @@ -190,7 +194,7 @@ bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - bool emitBytecode) { + bool emitBytecode, bool implicitModule) { auto passManagerSetupFn = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; @@ -201,7 +205,7 @@ return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode); + emitBytecode, implicitModule); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -243,6 +247,12 @@ "emit-bytecode", cl::desc("Emit bytecode when generating output"), cl::init(false)); + static llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; + InitLLVM y(argc, 
argv); // Register any command line options. @@ -288,7 +298,7 @@ if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, - emitBytecode))) + emitBytecode, /*implicitModule=*/!noImplicitModule))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/top-level.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt --no-implicit-module --verify-diagnostics --split-input-file %s | FileCheck %s + +// CHECK-NOT: module +// CHECK: func.func +func.func private @foo() + +// ----- + +// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +func.func private @bar() +func.func private @baz() + +// ----- + +// expected-error@-3 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/pipeline-invalid.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt --no-implicit-module --canonicalize --verify-diagnostics --split-input-file %s + +// expected-error@below {{trying to schedule a pass on an operation not marked as 'IsolatedFromAbove'}} +arith.constant 0 + +// ----- + +// expected-error@below {{trying to schedule a pass on an unregistered operation}} +"test.op"() : () -> ()