[mlir] Expose "assume verified" in the C API and verify in printOperation

Summary of what this patch does (text before the first "diff --git" line is
ignored by patch tools):

  * mlir-c/IR.h, CAPI/IR/IR.cpp: add mlirOpPrintingFlagsAssumeVerified(),
    which forwards to OpPrintingFlags::assumeVerified().
  * Bindings/Python/IRCore.cpp: stop verifying the operation and writing the
    "// Verification failed, printing generic form" banner from the binding;
    instead forward `assumeVerified` to the printing flags.
  * IR/AsmPrinter.cpp: remove verifyOpAndAdjustFlags() (previously invoked
    from the AsmState constructor) and perform the same check at the top of
    OperationPrinter::printOperation(), which now also writes the banner
    comment to the output stream before switching to the generic form.
  * test/IR/print-ir-invalid.mlir: expect the banner line.

NOTE(review): the template argument in `std::make_unique<AsmStateImpl>(...)`
was missing from the collapsed copy of this patch (it looked like stripped
angle-bracket markup) and has been restored on both the removed and the
added line; confirm against the tree before applying.
NOTE(review): indentation inside hunks was reconstructed from a
whitespace-collapsed copy; hunk line counts were checked against the @@
headers, but context whitespace should be confirmed against the tree.

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -393,6 +393,10 @@
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 
+/// Do not verify the operation when using custom operation printers.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -995,17 +995,6 @@
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
 
-  if (!assumeVerified && !printGenericOpForm &&
-      !mlirOperationVerify(operation)) {
-    std::string message("// Verification failed, printing generic form\n");
-    if (binary) {
-      fileObject.attr("write")(py::bytes(message));
-    } else {
-      fileObject.attr("write")(py::str(message));
-    }
-    printGenericOpForm = true;
-  }
-
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   if (largeElementsLimit)
     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
@@ -1015,6 +1004,8 @@
     mlirOpPrintingFlagsPrintGenericOpForm(flags);
   if (useLocalScope)
     mlirOpPrintingFlagsUseLocalScope(flags);
+  if (assumeVerified)
+    mlirOpPrintingFlagsAssumeVerified(flags);
 
   PyFileAccumulator accum(fileObject, binary);
   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -140,6 +140,10 @@
   unwrap(flags)->useLocalScope();
 }
 
+void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
+  unwrap(flags)->assumeVerified();
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1340,45 +1340,9 @@
 } // namespace detail
 } // namespace mlir
 
-/// Verifies the operation and switches to generic op printing if verification
-/// fails. We need to do this because custom print functions may fail for
-/// invalid ops.
-static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
-                                              OpPrintingFlags printerFlags) {
-  if (printerFlags.shouldPrintGenericOpForm() ||
-      printerFlags.shouldAssumeVerified())
-    return printerFlags;
-
-  LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
-                          << op->getName() << "\n");
-
-  // Ignore errors emitted by the verifier. We check the thread id to avoid
-  // consuming other threads' errors.
-  auto parentThreadId = llvm::get_threadid();
-  ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
-    if (parentThreadId == llvm::get_threadid()) {
-      LLVM_DEBUG({
-        diag.print(llvm::dbgs());
-        llvm::dbgs() << "\n";
-      });
-      return success();
-    }
-    return failure();
-  });
-  if (failed(verify(op))) {
-    LLVM_DEBUG(llvm::dbgs()
-               << DEBUG_TYPE << ": '" << op->getName()
-               << "' failed to verify and will be printed in generic form\n");
-    printerFlags.printGenericOpForm();
-  }
-
-  return printerFlags;
-}
-
 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                    LocationMap *locationMap)
-    : impl(std::make_unique<AsmStateImpl>(
-          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
+    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -2816,6 +2780,35 @@
 }
 
 void OperationPrinter::printOperation(Operation *op) {
+  // Switch to the generic printer if verification fails. We need to do this
+  // because custom print functions may fail for invalid ops.
+  if (!printerFlags.shouldPrintGenericOpForm() &&
+      !printerFlags.shouldAssumeVerified()) {
+    LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
+                            << op->getName() << "\n");
+    // Ignore errors emitted by the verifier. We check the thread id to avoid
+    // consuming other threads' errors.
+    auto parentThreadId = llvm::get_threadid();
+    auto handler = [&](Diagnostic &diag) {
+      if (parentThreadId != llvm::get_threadid())
+        return failure();
+      LLVM_DEBUG({
+        diag.print(llvm::dbgs());
+        llvm::dbgs() << "\n";
+      });
+      return success();
+    };
+    ScopedDiagnosticHandler diagHandler(op->getContext(), handler);
+    if (failed(verify(op))) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << DEBUG_TYPE << ": '" << op->getName()
+                 << "' failed to verify and will be printed in generic form\n");
+      os << "// Verification failed, printing generic form";
+      printNewline();
+      printerFlags.printGenericOpForm();
+    }
+  }
+
   if (size_t numResults = op->getNumResults()) {
     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
diff --git a/mlir/test/IR/print-ir-invalid.mlir b/mlir/test/IR/print-ir-invalid.mlir
--- a/mlir/test/IR/print-ir-invalid.mlir
+++ b/mlir/test/IR/print-ir-invalid.mlir
@@ -8,6 +8,7 @@
 // The operation is invalid because the body does not have a terminator, print
 // the generic form.
 // CHECK: Invalid operation:
+// CHECK-NEXT: // Verification failed, printing generic form
 // CHECK-NEXT: "func.func"() ({
 // CHECK-NEXT: ^bb0:
 // CHECK-NEXT: })