diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -50,6 +50,7 @@ DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); DEFINE_C_API_STRUCT(MlirOperation, void); +DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); @@ -228,6 +229,42 @@ void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MlirNamedAttribute *attributes); +/*============================================================================*/ +/* Op Printing flags API. */ +/* While many of these are simple settings that could be represented in a */ +/* struct, they are wrapped in a heap allocated object and accessed via */ +/* functions to maximize the possibility of compatibility over time. */ +/*============================================================================*/ + +/** Creates new printing flags with defaults, intended for customization. + * Must be freed with a call to mlirOpPrintingFlagsDestroy(). */ +MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void); + +/** Destroys printing flags created with mlirOpPrintingFlagsCreate. */ +void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags); + +/** Enables the elision of large elements attributes by printing a lexically + * valid but otherwise meaningless form instead of the element data. The + * `largeElementLimit` is used to configure what is considered to be a "large" + * ElementsAttr by providing an upper limit to the number of elements. */ +void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, + intptr_t largeElementLimit); + +/** Enables printing of debug information. If 'prettyForm' is nonzero, + * debug information is printed in a more readable 'pretty' form. Note: The + * IR generated with 'prettyForm' is not parsable.
*/ +void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, + int prettyForm); + +/** Always print operations in the generic form. */ +void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); + +/** Use local scope when printing the operation. This allows for using the + * printer in a more localized and thread-safe setting, but may not + * necessarily be identical to what the IR will look like when dumping + * the full module. */ +void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); + /*============================================================================*/ /* Operation API. */ /*============================================================================*/ @@ -298,6 +335,11 @@ void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, void *userData); +/** Same as mlirOperationPrint but accepts flags controlling the printing + * behavior. Does not take ownership of `flags`. */ +void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, + MlirStringCallback callback, void *userData); + /** Prints an operation to stderr. */ void mlirOperationDump(MlirOperation op); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -24,6 +24,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) +DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -562,10 +562,10 @@ OpPrintingFlags(); OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {} - /// Enable the elision of large elements attributes, by printing a '...'
- /// instead of the element data. Note: The IR generated with this option is - /// not parsable. `largeElementLimit` is used to configure what is considered - /// to be a "large" ElementsAttr by providing an upper limit to the number of + /// Enables the elision of large elements attributes by printing a lexically + /// valid but otherwise meaningless form instead of the element data. The + /// `largeElementLimit` is used to configure what is considered to be a + /// "large" ElementsAttr by providing an upper limit to the number of /// elements. OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16); diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -277,6 +277,17 @@ } void checkValid(); + /// Implements the bound 'print' method and helps with others. + void print(pybind11::object fileObject, bool binary, + llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + /// Implements the bound 'get_asm' method: prints into an io.StringIO or + /// io.BytesIO buffer (selected by `binary`) and returns its contents. + pybind11::object getAsm(bool binary, + llvm::Optional<int64_t> largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -64,12 +64,44 @@ static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; +static const char kOperationPrintDocstring[] = + R"(Prints the assembly form of the operation to a file like object. + +Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False.
+ large_elements_limit: If set, elides elements attributes with more than this + number of elements. Defaults to None (no limit). + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). Defaults to False. + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. Defaults to False. +)"; + +static const char kOperationGetAsmDocstring[] = + R"(Gets the assembly form of the operation with all options available. + +Args: + binary: Whether to return a bytes (True) or str (False) object. Defaults to + False. + ... others ...: See the print() method for common keyword arguments for + configuring the printout. +Returns: + Either a bytes or str object, depending on the setting of the 'binary' + argument. +)"; + static const char kOperationStrDunderDocstring[] = - R"(Prints the assembly form of the operation with default options. + R"(Gets the assembly form of the operation with default options. If more advanced control over the assembly formatting or I/O options is needed, -use the dedicated print method, which supports keyword arguments to customize -behavior. +use the dedicated print or get_asm methods, which support keyword arguments to +customize behavior. )"; static const char kDumpDocstring[] = @@ -118,6 +150,35 @@ } }; +/// Accumulates into a python file-like object, either writing text (default) +/// or binary.
+class PyFileAccumulator { +public: + PyFileAccumulator(py::object fileObject, bool binary) + : pyWriteFunction(fileObject.attr("write")), binary(binary) {} + + void *getUserData() { return this; } + + MlirStringCallback getCallback() { + return [](const char *part, intptr_t size, void *userData) { + py::gil_scoped_acquire acquire; // Named: must outlive the write below. + PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); + if (accum->binary) { + // Note: Still has to copy and not avoidable with this API. + py::bytes pyBytes(part, size); + accum->pyWriteFunction(pyBytes); + } else { + py::str pyStr(part, size); // Decodes as UTF-8 by default. + accum->pyWriteFunction(pyStr); + } + }; + } + +private: + py::object pyWriteFunction; + bool binary; +}; + /// Accumulates into a python string from a method that is expected to make /// one (no more, no less) call to the callback (asserts internally on /// violation). @@ -712,6 +773,54 @@ } } +void PyOperation::print(py::object fileObject, bool binary, + llvm::Optional<int64_t> largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope) { + checkValid(); + if (fileObject.is_none()) + fileObject = py::module::import("sys").attr("stdout"); + // Resolve write() before allocating the flags so that an invalid file + // object raises without leaking them. + PyFileAccumulator accum(fileObject, binary); + + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + + // The GIL stays held for the whole call: the callback re-enters Python for + // every chunk it receives. + mlirOperationPrintWithFlags(get(), flags, accum.getCallback(), + accum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); +} + +py::object PyOperation::getAsm(bool binary, + llvm::Optional<int64_t> largeElementsLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope) { + py::object fileObject; + if (binary) { + fileObject =
py::module::import("io").attr("BytesIO")(); + } else { + fileObject = py::module::import("io").attr("StringIO")(); + } + print(fileObject, /*binary=*/binary, + /*largeElementsLimit=*/largeElementsLimit, + /*enableDebugInfo=*/enableDebugInfo, + /*prettyDebugInfo=*/prettyDebugInfo, + /*printGenericOpForm=*/printGenericOpForm, + /*useLocalScope=*/useLocalScope); + + return fileObject.attr("getvalue")(); +} + //------------------------------------------------------------------------------ // PyAttribute. //------------------------------------------------------------------------------ @@ -745,7 +854,8 @@ /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed /// to accommodate other levels unless core MLIR changes. -template <typename DerivedTy> class PyConcreteValue : public PyValue { +template <typename DerivedTy> +class PyConcreteValue : public PyValue { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction @@ -1969,13 +2079,30 @@ .def( "__str__", [](PyOperation &self) { - self.checkValid(); - PyPrintAccumulator printAccum; - mlirOperationPrint(self.get(), printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); + return self.getAsm(/*binary=*/false, + /*largeElementsLimit=*/llvm::None, + /*enableDebugInfo=*/false, + /*prettyDebugInfo=*/false, + /*printGenericOpForm=*/false, + /*useLocalScope=*/false); }, - "Returns the assembly form of the operation."); + "Returns the assembly form of the operation.") + .def("print", &PyOperation::print, + // Careful: Lots of arguments must match up with print method.
+ py::arg("file") = py::none(), py::arg("binary") = false, + py::arg("large_elements_limit") = py::none(), + py::arg("enable_debug_info") = false, + py::arg("pretty_debug_info") = false, + py::arg("print_generic_op_form") = false, + py::arg("use_local_scope") = false, kOperationPrintDocstring) + .def("get_asm", &PyOperation::getAsm, + // Careful: Lots of arguments must match up with get_asm method. + py::arg("binary") = false, + py::arg("large_elements_limit") = py::none(), + py::arg("enable_debug_info") = false, + py::arg("pretty_debug_info") = false, + py::arg("print_generic_op_form") = false, + py::arg("use_local_scope") = false, kOperationGetAsmDocstring); // Mapping of PyRegion. py::class_<PyRegion>(m, "Region") diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -74,6 +74,36 @@ return wrap(unwrap(dialect)->getNamespace()); } +/* ========================================================================== */ +/* Printing flags API. */ +/* ========================================================================== */ + +MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { + return wrap(new OpPrintingFlags()); +} + +void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { + delete unwrap(flags); +} + +void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, + intptr_t largeElementLimit) { + unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); +} + +void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, + int prettyForm) { + unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); +} + +void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { + unwrap(flags)->printGenericOpForm(); +} + +void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { + unwrap(flags)->useLocalScope(); +} + /* ========================================================================== */ /* Location API.
*/ /* ========================================================================== */ @@ -282,6 +312,13 @@ stream.flush(); } +void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + unwrap(op)->print(stream, *unwrap(flags)); + stream.flush(); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } /* ========================================================================== */ diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -1,6 +1,7 @@ # RUN: %PYTHON %s | FileCheck %s import gc +import io import itertools import mlir @@ -248,3 +249,44 @@ run(testOperationResultList) + + +# CHECK-LABEL: TEST: testOperationPrint +def testOperationPrint(): + ctx = mlir.ir.Context() + module = ctx.parse_module(r""" + func @f1(%arg0: i32) -> i32 { + %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + return %arg0 : i32 + } + """) + + # Test print to stdout. + # CHECK: return %arg0 : i32 + module.operation.print() + + # Test print to text file. + f = io.StringIO() + # CHECK: <class 'str'> + # CHECK: return %arg0 : i32 + module.operation.print(file=f) + str_value = f.getvalue() + print(str_value.__class__) + print(str_value) + + # Test print to binary file. + f = io.BytesIO() + # CHECK: <class 'bytes'> + # CHECK: return %arg0 : i32 + module.operation.print(file=f, binary=True) + bytes_value = f.getvalue() + print(bytes_value.__class__) + print(bytes_value) + + # Test get_asm with options.
+ # CHECK: value = opaque<"", "0xDEADBEEF"> : tensor<4xi32> + # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7 + print(module.operation.get_asm(large_elements_limit=2, enable_debug_info=True, + pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)) + +run(testOperationPrint) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -10,9 +10,9 @@ /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s */ +#include "mlir-c/IR.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" -#include "mlir-c/IR.h" #include "mlir-c/Registration.h" #include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardDialect.h" @@ -319,6 +319,25 @@ fprintf(stderr, "Removed attr is null: %d\n", mlirAttributeIsNull( mlirOperationGetAttributeByName(operation, "custom_attr"))); + + // Add a large attribute to verify printing flags. + int64_t eltsShape[] = {4}; + int32_t eltsData[] = {1, 2, 3, 4}; + mlirOperationSetAttributeByName( + operation, "elts", + mlirDenseElementsAttrInt32Get( + mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4, + eltsData)); + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2); + mlirOpPrintingFlagsPrintGenericOpForm(flags); + mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0); + mlirOpPrintingFlagsUseLocalScope(flags); + fprintf(stderr, "Op print with all flags: "); + mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL); + fprintf(stderr, "\n"); + + mlirOpPrintingFlagsDestroy(flags); } /// Creates an operation with a region containing multiple blocks with @@ -991,6 +1010,7 @@ // CHECK: Remove attr: 1 // CHECK: Remove attr again: 0 // CHECK: Removed attr is null: 1 + // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown) // clang-format on mlirModuleDestroy(moduleOp);