diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -709,6 +709,12 @@ /// elements. OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16); + /// Enables the out of line printing of large elements attributes. + /// The `largeElementLimit` is used to configure what is considered to be a + /// "large" ElementsAttr by providing an upper limit to the number of + /// elements. + OpPrintingFlags &outlineLargeElementsAttrs(int64_t largeElementLimit = 16); + /// Enable printing of debug information. If 'prettyForm' is set to true, /// debug information is printed in a more readable 'pretty' form. Note: The /// IR generated with 'prettyForm' is not parsable. @@ -729,6 +735,12 @@ /// Return the size limit for printing large ElementsAttr. Optional<int64_t> getLargeElementsAttrLimit() const; + /// Return if the given ElementsAttr should be printed out of line. + bool shouldOutlineElementsAttr(ElementsAttr attr) const; + + /// Return the size limit for printing large ElementsAttr out of line. + Optional<int64_t> getLargeElementsAttrOutlineLimit() const; + /// Return if debug information should be printed. bool shouldPrintDebugInfo() const; @@ -745,6 +757,9 @@ /// Elide large elements attributes if the number of elements is larger than /// the upper limit. Optional<int64_t> elementsAttrElementLimit; + /// Outline large elements attributes if the number of elements is larger than + /// the upper limit. + Optional<int64_t> elementsAttrOutlineLimit; /// Print debug information. 
bool printDebugInfoFlag : 1; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -127,6 +127,11 @@ llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " "more elements than the given upper limit")}; + llvm::cl::opt<unsigned> outlineElementsAttrIfLarger{ + "mlir-outline-elementsattrs-if-larger", + llvm::cl::desc("Print ElementsAttrs out of line that have " + "more elements than the given upper limit")}; + llvm::cl::opt<bool> printDebugInfoOpt{ "mlir-print-debuginfo", llvm::cl::init(false), llvm::cl::desc("Print debug info in MLIR output")}; @@ -166,6 +171,8 @@ return; if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; + if (clOptions->outlineElementsAttrIfLarger.getNumOccurrences()) + elementsAttrOutlineLimit = clOptions->outlineElementsAttrIfLarger; printDebugInfoFlag = clOptions->printDebugInfoOpt; printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; printGenericOpFormFlag = clOptions->printGenericOpFormOpt; @@ -216,6 +223,19 @@ return elementsAttrElementLimit; } +/// Return if the given ElementsAttr should be outlined. +bool OpPrintingFlags::shouldOutlineElementsAttr(ElementsAttr attr) const { + return elementsAttrOutlineLimit.hasValue() && + (*elementsAttrOutlineLimit == 0 || + (*elementsAttrOutlineLimit < int64_t(attr.getNumElements()) && + !attr.isa<SplatElementsAttr>())); +} + +/// Return the size limit for printing large ElementsAttr out of line. +Optional<int64_t> OpPrintingFlags::getLargeElementsAttrOutlineLimit() const { + return elementsAttrOutlineLimit; +} + /// Return if debug information should be printed. bool OpPrintingFlags::shouldPrintDebugInfo() const { return printDebugInfoFlag; @@ -324,10 +344,11 @@ /// Visit the given attribute to see if it has an alias. `canBeDeferred` is /// set to true if the originator of this attribute can resolve the alias /// after parsing has completed (e.g. 
in the case of operation locations). - void visit(Attribute attr, bool canBeDeferred = false); + void visit(Attribute attr, const OpPrintingFlags &printerFlags, + bool canBeDeferred = false); /// Visit the given type to see if it has an alias. - void visit(Type type); + void visit(Type type, const OpPrintingFlags &printerFlags); private: /// Try to generate an alias for the provided symbol. If an alias is @@ -377,7 +398,7 @@ void print(Operation *op) { // Visit the operation location. if (printerFlags.shouldPrintDebugInfo()) - initializer.visit(op->getLoc(), /*canBeDeferred=*/true); + initializer.visit(op->getLoc(), printerFlags, /*canBeDeferred=*/true); // If requested, always print the generic form. if (!printerFlags.shouldPrintGenericOpForm()) { @@ -428,7 +449,8 @@ // Visit the argument location. if (printerFlags.shouldPrintDebugInfo()) // TODO: Allow deferring argument locations. - initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); + initializer.visit(arg.getLoc(), printerFlags, + /*canBeDeferred=*/false); } } @@ -463,23 +485,25 @@ // Visit the argument location. if (printerFlags.shouldPrintDebugInfo()) // TODO: Allow deferring argument locations. - initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); + initializer.visit(arg.getLoc(), printerFlags, /*canBeDeferred=*/false); } /// Consider the given type to be printed for an alias. - void printType(Type type) override { initializer.visit(type); } + void printType(Type type) override { initializer.visit(type, printerFlags); } /// Consider the given attribute to be printed for an alias. 
- void printAttribute(Attribute attr) override { initializer.visit(attr); } + void printAttribute(Attribute attr) override { + initializer.visit(attr, printerFlags); + } void printAttributeWithoutType(Attribute attr) override { printAttribute(attr); } LogicalResult printAlias(Attribute attr) override { - initializer.visit(attr); + initializer.visit(attr, printerFlags); return success(); } LogicalResult printAlias(Type type) override { - initializer.visit(type); + initializer.visit(type, printerFlags); return success(); } @@ -634,7 +658,9 @@ initializeAliases(aliasToType, typeToAlias); } -void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { +void AliasInitializer::visit(Attribute attr, + const OpPrintingFlags &printerFlags, + bool canBeDeferred) { if (!visitedAttributes.insert(attr).second) { // If this attribute already has an alias and this instance can't be // deferred, make sure that the alias isn't deferred. @@ -649,15 +675,25 @@ deferrableAttributes.insert(attr); return; } + // Add an alias for large ElementsAttrs so they are printed out of line. + if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) { + if (printerFlags.shouldOutlineElementsAttr(elementsAttr)) { + aliasToAttr["cst"].push_back(attr); + if (canBeDeferred) + deferrableAttributes.insert(attr); + return; + } + } // Check for any sub elements. if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) { - subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, - [&](Type type) { visit(type); }); + subElementInterface.walkSubElements( + [&](Attribute attr) { visit(attr, printerFlags); }, + [&](Type type) { visit(type, printerFlags); }); } } -void AliasInitializer::visit(Type type) { +void AliasInitializer::visit(Type type, const OpPrintingFlags &printerFlags) { if (!visitedTypes.insert(type).second) return; @@ -667,8 +703,9 @@ // Check for any sub elements. 
if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) { - subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, - [&](Type type) { visit(type); }); + subElementInterface.walkSubElements( + [&](Attribute attr) { visit(attr, printerFlags); }, + [&](Type type) { visit(type, printerFlags); }); } } diff --git a/mlir/test/IR/pretty-attributes.mlir b/mlir/test/IR/pretty-attributes.mlir --- a/mlir/test/IR/pretty-attributes.mlir +++ b/mlir/test/IR/pretty-attributes.mlir @@ -1,11 +1,18 @@ // RUN: mlir-opt %s -mlir-elide-elementsattrs-if-larger=2 | FileCheck %s +// RUN: mlir-opt %s -mlir-outline-elementsattrs-if-larger=2 | FileCheck %s --check-prefix=OUTLINE + // Ensure that the elided version is still parseable, although depending on // what has been elided, it may not be semantically meaningful. // In the typical case where what is being elided is a very large constant // tensor which passes don't look at directly, this isn't an issue. // RUN: mlir-opt %s -mlir-elide-elementsattrs-if-larger=2 | mlir-opt +// OUTLINE-DAG: #[[DCST:cst[0-9]+]] = dense<[1, 2, 3]> : tensor<3xi32> +// OUTLINE-DAG: #cst{{[0-9]+}} = sparse<{{\[}}[0, 0, 5]], -2.000000e+00> : vector<1x1x10xf16> +// OUTLINE-DAG: #cst{{[0-9]+}} = opaque<"_", "0xEBFE"> : tensor<100xf32> + // CHECK: opaque<"_", "0xDEADBEEF"> : tensor<3xi32> +// OUTLINE: "test.dense_attr"() {foo.dense_attr = #[[DCST]]} : () -> () "test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> () // CHECK: dense<[1, 2]> : tensor<2xi32>