diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -949,6 +949,10 @@ // The full description of this attribute. string description = ""; + + // If set, the attribute is elided from the printed attribute dictionary + // when its value equals the default value; no effect without a default. + bit elidePrintingDefaultValue = 0; } // An attribute of a specific dialect. @@ -962,7 +966,8 @@ // Attribute modifier definition // Decorates an attribute to have an (unvalidated) default value if not present. -class DefaultValuedAttr : +class DefaultValuedAttr : Attr { // Construct this attribute with the input attribute and change only // the default value. @@ -975,11 +980,13 @@ let valueType = attr.valueType; let baseAttr = attr; + let elidePrintingDefaultValue = elideDefaultPrint; } // Decorates an optional attribute to have an (unvalidated) default value // return by ODS generated accessors if not present. -class DefaultValuedOptionalAttr : +class DefaultValuedOptionalAttr : Attr { // Construct this attribute with the input attribute and change only // the default value. @@ -993,6 +1000,7 @@ let isOptional = 1; let baseAttr = attr; + let elidePrintingDefaultValue = elideDefaultPrint; } // Decorates an attribute as optional. The return type of the generated diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -116,6 +116,10 @@ // Returns the TableGen definition this Attribute was constructed from. const llvm::Record &getDef() const; + + // Returns the `elidePrintingDefaultValue` bit: true if this attribute is + // omitted from the printed attr-dict when its value equals the default.
+ bool elidePrintingDefaultValue() const; }; // Wrapper class providing helper methods for accessing MLIR constant attribute diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -134,6 +134,10 @@ const llvm::Record &Attribute::getDef() const { return *def; } +bool Attribute::elidePrintingDefaultValue() const { + return def->getValueAsBit("elidePrintingDefaultValue"); +} + ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && "must be subclass of TableGen 'ConstantAttr' class"); diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -715,3 +715,23 @@ } : () -> () return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test DefaultValuedAttr Printing +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @default_value_printing +func.func @default_value_printing(%arg0 : i32) { + // The attribute SHOULD be printed regardless of equality with the default (0) + // CHECK: test.default_value_print {value_with_default = 0 : i32} %arg0 + "test.default_value_print"(%arg0) {"value_with_default" = 0 : i32} : (i32) -> () + // The attribute SHOULD NOT be printed because elidePrintingDefaultValue=1 + // CHECK: test.default_value_no_print %arg0 + "test.default_value_no_print"(%arg0) {"value_with_default" = 0 : i32} : (i32) -> () + // The attribute SHOULD be printed because it is not equal to the default + // CHECK: test.default_value_no_print {value_with_default = 1 : i32} %arg0 + "test.default_value_no_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++
b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2891,6 +2891,22 @@ def : Pat<(TestDefaultStrAttrNoValueOp $value), (TestDefaultStrAttrHasValueOp ConstantStrAttr)>; +//===----------------------------------------------------------------------===// +// Test Ops with Default-Valued Attributes and Differing Print Settings +//===----------------------------------------------------------------------===// + +def TestDefaultAttrPrintOp : TEST_Op<"default_value_print"> { + let arguments = (ins DefaultValuedAttr:$value_with_default, + I32:$operand); + let assemblyFormat = "attr-dict $operand"; +} + +def TestDefaultAttrNoPrintOp : TEST_Op<"default_value_no_print"> { + let arguments = (ins DefaultValuedAttr:$value_with_default, + I32:$operand); + let assemblyFormat = "attr-dict $operand"; +} + //===----------------------------------------------------------------------===// // Test Ops with effects //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -369,7 +369,7 @@ StrAttr:$str_attr, BoolAttr:$bool_attr, SomeI32Enum:$enum_attr, - DefaultValuedAttr:$dv_i32_attr, + DefaultValuedAttr:$dv_i32_attr, DefaultValuedAttr:$dv_f64_attr, DefaultValuedStrAttr:$dv_str_attr, DefaultValuedAttr:$dv_bool_attr, diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -484,3 +484,7 @@ test.else_anchor(%b : !test.else_anchor<5>) {a = !test.else_anchor_struct} return } + +//===----------------------------------------------------------------------===// +// Check DefaultValuedAttr Printing (NOTE(review): no checks yet; see test/IR/attribute.mlir) +//===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp ---
a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1636,24 +1636,61 @@
 /// Generate the printer for the 'attr-dict' directive.
 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
                                MethodBody &body, bool withKeyword) {
-  body << "  _odsPrinter.printOptionalAttrDict"
-       << (withKeyword ? "WithKeyword" : "")
-       << "((*this)->getAttrs(), /*elidedAttrs=*/{";
+  // The elided names are collected in a runtime vector because some of them
+  // are only known after inspecting the attribute values (see below).
+  body << "  ::llvm::SmallVector<::llvm::StringRef> elidedAttrs;\n";
   // Elide the variadic segment size attributes if necessary.
   if (!fmt.allOperands &&
       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
-    body << "\"operand_segment_sizes\", ";
+    body << "  elidedAttrs.push_back(\"operand_segment_sizes\");\n";
   if (!fmt.allResultTypes &&
       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
-    body << "\"result_segment_sizes\", ";
-  if (!fmt.inferredAttributes.empty()) {
-    for (const auto &attr : fmt.inferredAttributes)
-      body << "\"" << attr.getKey() << "\", ";
-  }
-  llvm::interleaveComma(
-      fmt.usedAttributes, body,
-      [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
-  body << "});\n";
+    body << "  elidedAttrs.push_back(\"result_segment_sizes\");\n";
+  // Inferred attributes and attributes printed explicitly by the format are
+  // always elided from the attribute dictionary.
+  for (const auto &attr : fmt.inferredAttributes)
+    body << "  elidedAttrs.push_back(\"" << attr.getKey() << "\");\n";
+  for (const NamedAttribute *attr : fmt.usedAttributes)
+    body << "  elidedAttrs.push_back(\"" << attr->name << "\");\n";
+
+  // Attributes with the elidePrintingDefaultValue bit set are additionally
+  // elided at runtime whenever their value equals the default value.
+  for (const auto &namedAttr : op.getAttributes()) {
+    const auto &attr = namedAttr.attr;
+    if (attr.isDerivedAttr() || !attr.hasDefaultValue() ||
+        !attr.elidePrintingDefaultValue())
+      continue;
+    StringRef name = namedAttr.name;
+    // Already elided unconditionally above; a runtime check would only push a
+    // duplicate entry into `elidedAttrs`.
+    if (fmt.inferredAttributes.count(name) ||
+        llvm::any_of(fmt.usedAttributes, [&](const NamedAttribute *used) {
+          return used->name == name;
+        }))
+      continue;
+    // Without a constant builder the default value cannot be materialized as
+    // an attribute to compare against. Keep printing the attribute (which
+    // always round-trips) rather than emitting uncompilable C++.
+    StringRef constBuilder = attr.getConstBuilderTemplate();
+    if (constBuilder.empty())
+      continue;
+    FmtContext fctx;
+    fctx.withBuilder("odsBuilder");
+    std::string defaultValue =
+        std::string(tgfmt(constBuilder, &fctx, attr.getDefaultValue()));
+    // Compare through the generic ::mlir::Attribute handle so the generated
+    // check does not depend on the concrete class returned by the getter.
+    body << "  {\n";
+    body << "    ::mlir::Builder odsBuilder(getContext());\n";
+    body << "    ::mlir::Attribute attr = " << op.getGetterName(name)
+         << "Attr();\n";
+    body << "    if (attr && attr == (" << defaultValue << "))\n";
+    body << "      elidedAttrs.push_back(\"" << name << "\");\n";
+    body << "  }\n";
+  }
+  body << "  _odsPrinter.printOptionalAttrDict"
+       << (withKeyword ? "WithKeyword" : "")
+       << "((*this)->getAttrs(), elidedAttrs);\n";
 }
 
 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a