diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -774,7 +774,8 @@
   OpPrintingFlags &enableDebugInfo(bool prettyForm = false);
 
-  /// Always print operations in the generic form.
-  OpPrintingFlags &printGenericOpForm();
+  /// Always print operations in the generic form if `generic` is true (the
+  /// default). Passing `false` clears the flag again.
+  OpPrintingFlags &printGenericOpForm(bool generic = true);
 
   /// Do not verify the operation when using custom operation printers.
   OpPrintingFlags &assumeVerified();
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -51,26 +51,24 @@
 /// dialects from the global registry in the MLIRContext. This option is
 /// deprecated and will be removed soon.
 /// - emitBytecode will generate bytecode output instead of text.
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
-                          std::unique_ptr<llvm::MemoryBuffer> buffer,
-                          const PassPipelineCLParser &passPipeline,
-                          DialectRegistry &registry, bool splitInputFile,
-                          bool verifyDiagnostics, bool verifyPasses,
-                          bool allowUnregisteredDialects,
-                          bool preloadDialectsInContext = false,
-                          bool emitBytecode = false);
+/// - verifyRoundTrip will check, before the pass pipeline runs, that the input
+/// IR round-trips through its custom assembly format and through bytecode.
+LogicalResult MlirOptMain(
+    llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
+    const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
+    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
+    bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
+    bool emitBytecode = false, bool verifyRoundTrip = false);
 
 /// Support a callback to setup the pass manager.
 /// - passManagerSetupFn is the callback invoked to setup the pass manager to
 /// apply on the loaded IR.
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
-                          std::unique_ptr<llvm::MemoryBuffer> buffer,
-                          PassPipelineFn passManagerSetupFn,
-                          DialectRegistry &registry, bool splitInputFile,
-                          bool verifyDiagnostics, bool verifyPasses,
-                          bool allowUnregisteredDialects,
-                          bool preloadDialectsInContext = false,
-                          bool emitBytecode = false);
+LogicalResult MlirOptMain(
+    llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
+    PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
+    bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
+    bool emitBytecode = false, bool verifyRoundTrip = false);
 
 /// Implementation for tools like `mlir-opt`.
 /// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -217,8 +217,9 @@
 }
 
-/// Always print operations in the generic form.
-OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
-  printGenericOpFormFlag = true;
+/// Always print operations in the generic form if `generic` is true; passing
+/// `false` clears the flag again.
+OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool generic) {
+  printGenericOpFormFlag = generic;
   return *this;
 }
 
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -39,6 +39,74 @@
 using namespace mlir;
 using namespace llvm;
 
+// Return success if the module can correctly round-trip. This is intended to
+// test that the custom printers/parsers and the bytecode writer/reader are
+// complete.
+//
+// The module is printed in its custom assembly format (or written as bytecode
+// when `useBytecode` is set) and parsed back with `config`. The original and
+// the re-parsed modules are then both printed in the generic form with debug
+// info, and the two outputs must be identical.
+//
+// NOTE(review): the module is re-parsed into the same context with the
+// caller's `config`; confirm this cannot perturb state the caller relies on
+// afterwards (e.g. resource handles or the fallback resource map).
+static LogicalResult doVerifyRoundTrip(ModuleOp moduleOp, ParserConfig &config,
+                                       bool useBytecode) {
+  OwningOpRef<ModuleOp> roundtripModule;
+  // Print a first time with custom format (or bytecode) and parse it back to
+  // the roundtripModule.
+  {
+    std::string buffer;
+    llvm::raw_string_ostream ostream(buffer);
+    if (useBytecode) {
+      writeBytecodeToFile(moduleOp.getOperation(), ostream);
+    } else {
+      moduleOp.print(
+          ostream,
+          OpPrintingFlags().printGenericOpForm(false).enableDebugInfo());
+    }
+    roundtripModule = parseSourceString<ModuleOp>(ostream.str(), config);
+    // Add context so the failure is attributed to round-trip verification.
+    if (!roundtripModule)
+      return moduleOp.emitOpError()
+             << "failed to parse back the "
+             << (useBytecode ? "bytecode" : "custom assembly format")
+             << " output, cannot verify round-trip";
+  }
+
+  // Print in the generic form for the reference module and the round-tripped
+  // one and compare the outputs.
+  std::string reference, roundtrip;
+  {
+    llvm::raw_string_ostream ostreamref(reference);
+    moduleOp.print(ostreamref,
+                   OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+    llvm::raw_string_ostream ostreamrndtrip(roundtrip);
+    roundtripModule->print(
+        ostreamrndtrip,
+        OpPrintingFlags().printGenericOpForm().enableDebugInfo());
+  }
+  if (reference != roundtrip) {
+    // TODO implement a diff.
+    return moduleOp.emitOpError() << "roundTrip testing failed!\n<<<<<<\n"
+                                  << reference << "\n>>>>>\n"
+                                  << roundtrip;
+  }
+
+  return success();
+}
+
+// Verify that `moduleOp` round-trips both through bytecode and through its
+// custom assembly format. The textual check only runs if the bytecode check
+// succeeds.
+static LogicalResult doVerifyRoundTrip(ModuleOp moduleOp,
+                                       ParserConfig &config) {
+  return success(
+      succeeded(doVerifyRoundTrip(moduleOp, config, /*useBytecode=*/true)) &&
+      succeeded(doVerifyRoundTrip(moduleOp, config, /*useBytecode=*/false)));
+}
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -46,8 +114,8 @@
 /// passes, then prints the output.
 ///
 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
-                                    bool verifyPasses, SourceMgr &sourceMgr,
-                                    MLIRContext *context,
+                                    bool verifyPasses, bool verifyRoundTrip,
+                                    SourceMgr &sourceMgr, MLIRContext *context,
                                     PassPipelineFn passManagerSetupFn,
                                     bool emitBytecode) {
   DefaultTimingManager tm;
@@ -73,6 +141,16 @@
   // Parse the input file and reset the context threading state.
   TimingScope parserTiming = timing.nest("Parser");
   OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
+
+  // Perform round-trip verification if requested. This is skipped when
+  // parsing failed: there is no module to print, and that failure is
+  // handled just below.
+  if (verifyRoundTrip && module && failed(doVerifyRoundTrip(*module, config))) {
+    // Restore the threading state before bailing out, as the path below does.
+    context->enableMultithreading(wasThreadingEnabled);
+    return failure();
+  }
+
   context->enableMultithreading(wasThreadingEnabled);
   if (!module)
     return failure();
@@ -103,8 +181,9 @@
 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
               bool verifyDiagnostics, bool verifyPasses,
               bool allowUnregisteredDialects, bool preloadDialectsInContext,
-              bool emitBytecode, PassPipelineFn passManagerSetupFn,
-              DialectRegistry &registry, llvm::ThreadPool *threadPool) {
+              bool emitBytecode, bool verifyRoundTrip,
+              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+              llvm::ThreadPool *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -127,8 +206,9 @@
   // otherwise just perform the actions without worrying about it.
   if (!verifyDiagnostics) {
     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
-    return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passManagerSetupFn, emitBytecode);
+    return performActions(os, verifyDiagnostics, verifyPasses, verifyRoundTrip,
+                          sourceMgr, &context, passManagerSetupFn,
+                          emitBytecode);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -136,8 +216,8 @@
   // Do any processing requested by command line flags. We don't care whether
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
-  (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                       passManagerSetupFn, emitBytecode);
+  (void)performActions(os, verifyDiagnostics, verifyPasses, verifyRoundTrip,
+                       sourceMgr, &context, passManagerSetupFn, emitBytecode);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
@@ -151,7 +204,7 @@
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
                                 bool preloadDialectsInContext,
-                                bool emitBytecode) {
+                                bool emitBytecode, bool verifyRoundTrip) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
@@ -170,7 +223,8 @@
                      raw_ostream &os) {
     return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                          verifyPasses, allowUnregisteredDialects,
                          preloadDialectsInContext, emitBytecode,
-                         passManagerSetupFn, registry, threadPool);
+                         verifyRoundTrip, passManagerSetupFn, registry,
+                         threadPool);
   };
   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
@@ -184,7 +238,7 @@
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
                                 bool preloadDialectsInContext,
-                                bool emitBytecode) {
+                                bool emitBytecode, bool verifyRoundTrip) {
   auto passManagerSetupFn = [&](PassManager &pm) {
     auto errorHandler = [&](const Twine &msg) {
       emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -195,7 +249,7 @@
   return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
                      registry, splitInputFile, verifyDiagnostics, verifyPasses,
                      allowUnregisteredDialects, preloadDialectsInContext,
-                     emitBytecode);
+                     emitBytecode, verifyRoundTrip);
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -220,6 +274,12 @@
                "expected-* lines on the corresponding line"),
       cl::init(false));
 
+  static cl::opt<bool> verifyRoundTrip(
+      "verify-roundtrip",
+      cl::desc("Check that the input IR can round-trip before executing "
+               "the pipeline"),
+      cl::init(false));
+
   static cl::opt<bool> verifyPasses(
       "verify-each",
       cl::desc("Run the verifier after each transformation pass"),
@@ -282,7 +342,7 @@
   if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
                          splitInputFile, verifyDiagnostics, verifyPasses,
                          allowUnregisteredDialects, preloadDialectsInContext,
-                         emitBytecode)))
+                         emitBytecode, verifyRoundTrip)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-roundtrip -allow-unregistered-dialect -split-input-file %s | FileCheck %s
 // RUN: mlir-opt -allow-unregistered-dialect %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
 
 // Check that the attributes for the affine operations are round-tripped.
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -1,3 +1,3 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -verify-roundtrip | FileCheck %s
 
 // CHECK-LABEL: func @ops
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-roundtrip %s | FileCheck %s
 // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test.
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics %s -verify-roundtrip | mlir-opt | FileCheck %s
 
 // CHECK-LABEL: func @avx512_mask_rndscale
 func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -verify-roundtrip | FileCheck %s
 // Verify the generic form can be parsed.
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,2 +1,2 @@
-// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
+// RUN: mlir-opt %s -test-recursive-types -verify-roundtrip | FileCheck %s
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -808,7 +808,6 @@
 }
 
 void AffineScopeOp::print(OpAsmPrinter &p) {
-  p << "test.affine_scope ";
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
 }
 