diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -46,7 +46,8 @@
                           DialectRegistry &registry, bool splitInputFile,
                           bool verifyDiagnostics, bool verifyPasses,
                           bool allowUnregisteredDialects,
-                          bool preloadDialectsInContext = true);
+                          bool preloadDialectsInContext = true,
+                          bool verifyRoundTrip = false);
 
 /// Implementation for tools like `mlir-opt`.
 /// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -36,6 +36,37 @@
 using namespace llvm;
 using llvm::SMLoc;
 
+// Return success if the module can correctly round-trip. This is intended to
+// test that the custom printers/parsers are complete.
+static LogicalResult doVerifyRoundTrip(ModuleOp moduleOp) {
+  OwningModuleRef roundtripModule;
+  // Print a first time and parse it back to the roundtripModule.
+  {
+    std::string buffer;
+    llvm::raw_string_ostream ostream(buffer);
+    moduleOp.print(ostream);
+    roundtripModule = parseSourceString(ostream.str(), moduleOp.getContext());
+    if (!roundtripModule)
+      return failure();
+  }
+
+  // Print in the generic form for the reference module and the round-tripped
+  // one and compare the outputs.
+  std::string reference, roundtrip;
+  {
+    llvm::raw_string_ostream ostreamref(reference);
+    moduleOp.print(ostreamref, OpPrintingFlags().printGenericOpForm());
+    llvm::raw_string_ostream ostreamrndtrip(roundtrip);
+    roundtripModule->print(ostreamrndtrip,
+                           OpPrintingFlags().printGenericOpForm());
+  }
+  if (reference != roundtrip)
+    // TODO implement a diff.
+    return moduleOp.emitOpError() << "roundTrip testing failed!";
+
+  return success();
+}
+
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -43,8 +74,8 @@
 /// passes, then prints the output.
 ///
 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
-                                    bool verifyPasses, SourceMgr &sourceMgr,
-                                    MLIRContext *context,
+                                    bool verifyPasses, bool verifyRoundTrip,
+                                    SourceMgr &sourceMgr, MLIRContext *context,
                                     const PassPipelineCLParser &passPipeline) {
   // Disable multi-threading when parsing the input file. This removes the
   // unnecessary/costly context synchronization when parsing.
@@ -53,10 +84,14 @@
 
   // Parse the input file and reset the context threading state.
   OwningModuleRef module(parseSourceFile(sourceMgr, context));
+
   context->enableMultithreading(wasThreadingEnabled);
   if (!module)
     return failure();
 
+  if (verifyRoundTrip && failed(doVerifyRoundTrip(*module)))
+    return failure();
+
   // Apply any pass manager command line options.
   PassManager pm(context, verifyPasses);
   applyPassManagerCLOptions(pm);
@@ -77,13 +112,12 @@
 
 /// Parses the memory buffer. If successfully, run a series of passes against
 /// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
-                                   std::unique_ptr<MemoryBuffer> ownedBuffer,
-                                   bool verifyDiagnostics, bool verifyPasses,
-                                   bool allowUnregisteredDialects,
-                                   bool preloadDialectsInContext,
-                                   const PassPipelineCLParser &passPipeline,
-                                   DialectRegistry &registry) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+              bool verifyDiagnostics, bool verifyPasses,
+              bool allowUnregisteredDialects, bool preloadDialectsInContext,
+              bool verifyRoundTrip, const PassPipelineCLParser &passPipeline,
+              DialectRegistry &registry) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -100,8 +134,8 @@
   // otherwise just perform the actions without worrying about it.
   if (!verifyDiagnostics) {
     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
-    return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passPipeline);
+    return performActions(os, verifyDiagnostics, verifyPasses, verifyRoundTrip,
+                          sourceMgr, &context, passPipeline);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -109,8 +143,8 @@
   // Do any processing requested by command line flags. We don't care whether
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
-  performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                 passPipeline);
+  performActions(os, verifyDiagnostics, verifyPasses, verifyRoundTrip,
+                 sourceMgr, &context, passPipeline);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
@@ -123,7 +157,8 @@
                                 DialectRegistry &registry, bool splitInputFile,
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
-                                bool preloadDialectsInContext) {
+                                bool preloadDialectsInContext,
+                                bool verifyRoundTrip) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   if (splitInputFile)
@@ -132,14 +167,15 @@
         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                                verifyPasses, allowUnregisteredDialects,
-                               preloadDialectsInContext, passPipeline,
-                               registry);
+                               preloadDialectsInContext, verifyRoundTrip,
+                               passPipeline, registry);
         },
         outputStream);
 
   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
                        verifyPasses, allowUnregisteredDialects,
-                       preloadDialectsInContext, passPipeline, registry);
+                       preloadDialectsInContext, verifyRoundTrip, passPipeline,
+                       registry);
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -164,6 +200,12 @@
                "expected-* lines on the corresponding line"),
       cl::init(false));
 
+  static cl::opt<bool> verifyRoundTrip(
+      "verify-roundtrip",
+      cl::desc(
+          "Check that the input IR can rountrip before executing the pipeline"),
+      cl::init(false));
+
   static cl::opt<bool> verifyPasses(
       "verify-each",
       cl::desc("Run the verifier after each transformation pass"),
@@ -222,7 +264,8 @@
 
   if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
                          splitInputFile, verifyDiagnostics, verifyPasses,
-                         allowUnregisteredDialects, preloadDialectsInContext)))
+                         allowUnregisteredDialects, preloadDialectsInContext,
+                         verifyRoundTrip)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.
diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
--- a/mlir/test/Dialect/AVX512/roundtrip.mlir
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-diagnostics %s -verify-roundtrip | FileCheck %s
 
 func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
   -> (vector<16xf32>, vector<8xf64>)
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -verify-roundtrip -allow-unregistered-dialect -split-input-file %s | FileCheck %s
 // RUN: mlir-opt -allow-unregistered-dialect %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
 
 // Check that the attributes for the affine operations are round-tripped.
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -verify-roundtrip | FileCheck %s
 
 // CHECK-LABEL: func @ops
 // CHECK-SAME: (%[[I32:.*]]: !llvm.i32, %[[FLOAT:.*]]: !llvm.float, %[[I8PTR1:.*]]: !llvm.ptr<i8>, %[[I8PTR2:.*]]: !llvm.ptr<i8>, %[[BOOL:.*]]: !llvm.i1)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s
-// | mlir-opt | FileCheck %s
+// RUN: mlir-opt -verify-roundtrip -split-input-file %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
-// Verify the printed output can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -verify-roundtrip | FileCheck %s
 // Verify the generic form can be parsed.
 // RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
+// RUN: mlir-opt %s -test-recursive-types -verify-roundtrip | FileCheck %s
 
 // CHECK-LABEL: @roundtrip
 func @roundtrip() {