diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1103,7 +1103,7 @@ OpPrintingFlags &enableDebugInfo(bool enable = true, bool prettyForm = false); /// Always print operations in the generic form. - OpPrintingFlags &printGenericOpForm(); + OpPrintingFlags &printGenericOpForm(bool enable = true); /// Skip printing regions. OpPrintingFlags &skipRegions(bool skip = true); diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -163,6 +163,13 @@ } bool shouldVerifyPasses() const { return verifyPassesFlag; } + /// Set whether to verify that the input IR round-trips perfectly. + MlirOptMainConfig &verifyRoundtrip(bool verify) { + verifyRoundtripFlag = verify; + return *this; + } + bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; } + protected: /// Allow operation with no registered dialects. /// This option is for convenience during testing only and discouraged in @@ -212,6 +219,9 @@ /// Run the verifier after each transformation pass. bool verifyPassesFlag = true; + + /// Verify that the input IR round-trips perfectly. + bool verifyRoundtripFlag = false; }; /// This defines the function type used to setup the pass manager. This can be diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -219,8 +219,8 @@ } /// Always print operations in the generic form. 
-OpPrintingFlags &OpPrintingFlags::printGenericOpForm() { - printGenericOpFormFlag = true; +OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { + printGenericOpFormFlag = enable; return *this; } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -139,6 +139,11 @@ cl::desc("Run the verifier after each transformation pass"), cl::location(verifyPassesFlag), cl::init(true)); + static cl::opt verifyRoundtrip( + "verify-roundtrip", + cl::desc("Round-trip the IR after parsing and ensure it succeeds"), + cl::location(verifyRoundtripFlag), cl::init(false)); + static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); /// Set the callback to load a pass plugin. @@ -213,6 +218,67 @@ }); } +// Return success if the module can correctly round-trip. This is intended to +// test that the custom printers/parsers are complete. +static LogicalResult doVerifyRoundTrip(Operation *op, + const MlirOptMainConfig &config, + bool useBytecode) { + OwningOpRef roundtripModule; + // Print a first time with custom format (or bytecode) and parse it back to + // the roundtripModule. 
+ { + std::string buffer; + llvm::raw_string_ostream ostream(buffer); + if (useBytecode) { + if (failed(writeBytecodeToFile(op, ostream))) { + op->emitOpError() << "failed to write bytecode, cannot verify round-trip.\n"; + return failure(); + } + } else { + op->print(ostream, + OpPrintingFlags().printGenericOpForm(false).enableDebugInfo()); + } + FallbackAsmResourceMap fallbackResourceMap; + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true, + &fallbackResourceMap); + roundtripModule = + parseSourceString(ostream.str(), parseConfig); + if (!roundtripModule) { + op->emitOpError() << "failed to parse bytecode back, cannot verify round-trip.\n"; + return failure(); + } + } + + // Print in the generic form for the reference module and the round-tripped + // one and compare the outputs. + std::string reference, roundtrip; + { + llvm::raw_string_ostream ostreamref(reference); + op->print(ostreamref, + OpPrintingFlags().printGenericOpForm().enableDebugInfo()); + llvm::raw_string_ostream ostreamrndtrip(roundtrip); + roundtripModule.get()->print( + ostreamrndtrip, + OpPrintingFlags().printGenericOpForm().enableDebugInfo()); + } + if (reference != roundtrip) { + // TODO: implement a diff. + return op->emitOpError() << "roundTrip testing failed!\n<<<<<<\n" + << reference << "\n>>>>>\n" + << roundtrip; + } + + return success(); +} + +static LogicalResult doVerifyRoundTrip(Operation *op, + const MlirOptMainConfig &config) { + // Textual round-trip isn't fully robust at the moment (for example, implicit + // terminators lose their location information); only bytecode is checked. + + return doVerifyRoundTrip(op, config, /*useBytecode=*/true); +} + /// Perform the actions on the input file indicated by the command line flags /// within the specified context. 
/// @@ -247,10 +313,16 @@ TimingScope parserTiming = timing.nest("Parser"); OwningOpRef op = parseSourceFileForTool( sourceMgr, parseConfig, !config.shouldUseExplicitModule()); - context->enableMultithreading(wasThreadingEnabled); + parserTiming.stop(); if (!op) return failure(); - parserTiming.stop(); + + // If requested, verify the round-trip before restoring multithreading. + if (config.shouldVerifyRoundtrip() && + failed(doVerifyRoundTrip(op.get(), config))) + return failure(); + + context->enableMultithreading(wasThreadingEnabled); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit); diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -65,7 +65,6 @@ tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir] tools = [ - 'mlir-opt', 'mlir-tblgen', 'mlir-translate', 'mlir-lsp-server', @@ -125,6 +124,15 @@ ToolSubst('%PYTHON', python_executable, unresolved='ignore'), ]) +if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ: + tools.extend([ + ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'), + ]) +else: + tools.extend([ + 'mlir-opt', + ]) + llvm_config.add_tool_substitutions(tools, tool_dirs)