diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1103,7 +1103,7 @@ OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false); /// Always print operations in the generic form. - OpPrintingFlags &printGenericOpForm(); + OpPrintingFlags &printGenericOpForm(bool enable = true); /// Skip printing regions. OpPrintingFlags &skipRegions(bool skip = true); diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -163,6 +163,13 @@ } bool shouldVerifyPasses() const { return verifyPassesFlag; } + /// Set whether to verify that the input IR round-trips after parsing. + MlirOptMainConfig &verifyRoundtrip(bool verify) { + verifyRoundtripFlag = verify; + return *this; + } + bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; } + protected: /// Allow operation with no registered dialects. /// This option is for convenience during testing only and discouraged in @@ -212,6 +219,9 @@ /// Run the verifier after each transformation pass. bool verifyPassesFlag = true; + + /// Verify that the input IR round-trips perfectly. + bool verifyRoundtripFlag = false; }; /// This defines the function type used to setup the pass manager. This can be diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -219,8 +219,8 @@ } /// Always print operations in the generic form. 
-OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { - printGenericOpFormFlag = true; +OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { + printGenericOpFormFlag = enable; return *this; } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -139,6 +139,11 @@ cl::desc("Run the verifier after each transformation pass"), cl::location(verifyPassesFlag), cl::init(true)); + static cl::opt verifyRoundtrip( + "verify-roundtrip", + cl::desc("Round-trip the IR after parsing and ensure it succeeds"), + cl::location(verifyRoundtripFlag), cl::init(false)); + static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); /// Set the callback to load a pass plugin. @@ -213,6 +218,104 @@ }); } +LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) { + DialectRegistry registry; + registry.insert(); + ctx.appendDialectRegistry(registry); + + // Set up the input file. + std::string errorMessage; + std::unique_ptr file = openInputFile(irdlFile, &errorMessage); + if (!file) { + emitError(UnknownLoc::get(&ctx)) << errorMessage; + return failure(); + } + + // Give the buffer to the source manager. + // This will be picked up by the parser. + SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); + + SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx); + + // Parse the input file. + OwningOpRef module(parseSourceFile(sourceMgr, &ctx)); + + // Load IRDL dialects. + return irdl::loadDialects(module.get()); +} + +// Return success if the module can correctly round-trip. This is intended to +// test that the custom printers/parsers are complete. +static LogicalResult doVerifyRoundTrip(Operation *op, + const MlirOptMainConfig &config, + bool useBytecode) { + // We use a new context to avoid resource handle renaming issue in the diff. 
+ MLIRContext roundtripContext; + OwningOpRef roundtripModule; + roundtripContext.appendDialectRegistry( + op->getContext()->getDialectRegistry()); + if (op->getContext()->allowsUnregisteredDialects()) + roundtripContext.allowUnregisteredDialects(); + StringRef irdlFile = config.getIrdlFile(); + if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, roundtripContext))) + return failure(); + + // Print a first time with custom format (or bytecode) and parse it back to + // the roundtripModule. + { + std::string buffer; + llvm::raw_string_ostream ostream(buffer); + if (useBytecode) { + if (failed(writeBytecodeToFile(op, ostream))) { + op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n"; + return failure(); + } + } else { + op->print(ostream, + OpPrintingFlags().printGenericOpForm(false).enableDebugInfo()); + } + FallbackAsmResourceMap fallbackResourceMap; + ParserConfig parseConfig(&roundtripContext, /*verifyAfterParse=*/true, + &fallbackResourceMap); + roundtripModule = + parseSourceString(ostream.str(), parseConfig); + if (!roundtripModule) { + op->emitOpError() << "failed to parse bytecode back, cannot verify round-trip.\n"; + return failure(); + } + } + + // Print in the generic form for the reference module and the round-tripped + // one and compare the outputs. + std::string reference, roundtrip; + { + llvm::raw_string_ostream ostreamref(reference); + op->print(ostreamref, + OpPrintingFlags().printGenericOpForm().enableDebugInfo()); + llvm::raw_string_ostream ostreamrndtrip(roundtrip); + roundtripModule.get()->print( + ostreamrndtrip, + OpPrintingFlags().printGenericOpForm().enableDebugInfo()); + } + if (reference != roundtrip) { + // TODO implement a diff. 
+ return op->emitOpError() << "roundTrip testing roundtripped module differs from reference:\n<<<<<>>>>roundtripped\n"; + } + + return success(); +} + +static LogicalResult doVerifyRoundTrip(Operation *op, + const MlirOptMainConfig &config) { + // Textual round-trip isn't fully robust yet (e.g. implicit terminators lose + // their location information), so only the bytecode round-trip is verified. + + return doVerifyRoundTrip(op, config, /*useBytecode=*/true); +} + /// Perform the actions on the input file indicated by the command line flags /// within the specified context. /// @@ -247,10 +350,16 @@ TimingScope parserTiming = timing.nest("Parser"); OwningOpRef op = parseSourceFileForTool( sourceMgr, parseConfig, !config.shouldUseExplicitModule()); - context->enableMultithreading(wasThreadingEnabled); + parserTiming.stop(); if (!op) return failure(); - parserTiming.stop(); + + // Perform round-trip verification if requested + if (config.shouldVerifyRoundtrip() && + failed(doVerifyRoundTrip(op.get(), config))) + return failure(); + + context->enableMultithreading(wasThreadingEnabled); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); @@ -286,33 +395,6 @@ return success(); } -LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) { - DialectRegistry registry; - registry.insert(); - ctx.appendDialectRegistry(registry); - - // Set up the input file. - std::string errorMessage; - std::unique_ptr file = openInputFile(irdlFile, &errorMessage); - if (!file) { - emitError(UnknownLoc::get(&ctx)) << errorMessage; - return failure(); - } - - // Give the buffer to the source manager. - // This will be picked up by the parser. - SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); - - SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx); - - // Parse the input file. - OwningOpRef module(parseSourceFile(sourceMgr, &ctx)); - - // Load IRDL dialects. 
- return irdl::loadDialects(module.get()); -} - /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. static LogicalResult processBuffer(raw_ostream &os, diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -65,7 +65,6 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir] tools = [ - 'mlir-opt', 'mlir-tblgen', 'mlir-translate', 'mlir-lsp-server', @@ -125,6 +124,11 @@ ToolSubst('%PYTHON', python_executable, unresolved='ignore'), ]) +if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ: + tools.extend([ + ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'), + ]) + llvm_config.add_tool_substitutions(tools, tool_dirs)