[mlir][CAPI] Add mlirOpPrintingFlagsAssumeVerified

Exposes the assumeVerified printing flag through the C API and forwards the
Python binding's `assumeVerified` print option to it, replacing the binding's
own pre-print verification step (which wrote a "// Verification failed,
printing generic form" banner and forced the generic form). The operation.py
CHECK lines now expect the generic form of invalid ops without that banner.

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -404,6 +404,10 @@
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 
+/// Do not verify the operation when using custom operation printers.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,17 +1075,6 @@
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
 
-  if (!assumeVerified && !printGenericOpForm &&
-      !mlirOperationVerify(operation)) {
-    std::string message("// Verification failed, printing generic form\n");
-    if (binary) {
-      fileObject.attr("write")(py::bytes(message));
-    } else {
-      fileObject.attr("write")(py::str(message));
-    }
-    printGenericOpForm = true;
-  }
-
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   if (largeElementsLimit)
     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
@@ -1096,6 +1085,8 @@
     mlirOpPrintingFlagsPrintGenericOpForm(flags);
   if (useLocalScope)
     mlirOpPrintingFlagsUseLocalScope(flags);
+  if (assumeVerified)
+    mlirOpPrintingFlagsAssumeVerified(flags);
 
   PyFileAccumulator accum(fileObject, binary);
   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -141,6 +141,10 @@
   unwrap(flags)->useLocalScope();
 }
 
+void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
+  unwrap(flags)->assumeVerified();
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -681,7 +681,6 @@
   with Location.unknown(ctx):
     invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: // Verification failed, printing generic form
     # CHECK: "builtin.module"() ({
     # CHECK: }) : () -> ()
     print(invalid_op)
@@ -698,7 +697,8 @@
     with InsertionPoint(module.body):
       invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: // Verification failed, printing generic form
+    # CHECK: "builtin.module"() ({
+    # CHECK: }) : () -> ()
     print(module)
 
 
@@ -709,7 +709,7 @@
   with Location.unknown(ctx):
     invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: b'// Verification failed, printing generic form\n
+    # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
     print(invalid_op.get_asm(binary=True))
 
 