diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -518,6 +518,30 @@ are used inside optional groups are allowed only if all captured parameters are also optional. +#### Default-Valued Parameters + +Optional parameters can be given default values by setting `defaultValue`, a +string of the C++ default value, or by using `DefaultValuedParameter`. If a +value for the parameter was not encountered during parsing, it is set to this +default value. If a parameter is equal to its default value, it is not printed. +The `comparator` field of the parameter is used, but if one is not specified, +the equality operator is used. + +For example: + +``` +let parameters = (ins DefaultValuedParameter<"Optional", "5">:$a) +let mnemonic = "default_valued"; +let assemblyFormat = "(`<` $a^ `>`)?"; +``` + +Which will look like: + +``` +!test.default_valued // a = 5 +!test.default_valued<10> // a = 10 +``` + ### Assembly Format Directives Attribute and type assembly formats have the following directives: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -3133,6 +3133,12 @@ // must be default constructible and be contextually convertible to `bool`. // Any `Optional` and any attribute type satisfies these requirements. bit isOptional = 0; + // Provide a default value for the parameter. Parameters with default values + // are considered optional. If a value was not parsed for the parameter, it + // will be set to the default value. Parameters equal to their default values + // are elided when printing. Equality is checked using the `comparator` field + // or otherwise the default C++ equals operator. 
+ string defaultValue = ?; } class AttrParameter : AttrOrTypeParameter; @@ -3183,6 +3189,13 @@ let isOptional = 1; } +// A parameter with a default value. +class DefaultValuedParameter : + AttrOrTypeParameter { + let isOptional = 1; + let defaultValue = value; +} + // This is a special parameter used for AttrDefs that represents a `mlir::Type` // that is also used as the value `Type` of the attribute. Only one parameter // of the attribute may be of this type. diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -52,6 +52,9 @@ explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) : def(def), index(index) {} + /// Returns true if the parameter is anonymous (has no name). + bool isAnonymous() const; + /// Get the parameter name. StringRef getName() const; @@ -85,6 +88,9 @@ /// Returns true if the parameter is optional. bool isOptional() const; + /// Get the default value of the parameter if it has one. + Optional getDefaultValue() const; + /// Return the underlying def of this parameter. llvm::Init *getDef() const; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -187,6 +187,10 @@ return result; } +bool AttrOrTypeParameter::isAnonymous() const { + return !def->getArgName(index); +} + StringRef AttrOrTypeParameter::getName() const { return def->getArgName(index)->getValue(); } @@ -239,7 +243,13 @@ } bool AttrOrTypeParameter::isOptional() const { - return getDefValue("isOptional").getValueOr(false); + // Parameters with default values are automatically optional. 
+ return getDefValue("isOptional").getValueOr(false) || + getDefaultValue().hasValue(); +} + +Optional AttrOrTypeParameter::getDefaultValue() const { + return getDefValue("defaultValue"); } llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -311,7 +311,7 @@ } def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> { - let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + let parameters = (ins DefaultValuedParameter<"mlir::Optional", "10">:$a, OptionalParameter<"mlir::Optional">:$b); let mnemonic = "optional_group_params"; let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`"; @@ -330,4 +330,28 @@ let assemblyFormat = "`<` ` ` $a `\\n` `(` `)` `` `(` `)` $b `>`"; } +class DefaultValuedAPFloat + : DefaultValuedParameter<"llvm::Optional", + "llvm::Optional(" # value # ")"> { + let comparator = "$_lhs->bitwiseIsEqual(*$_rhs)"; + let parser = [{ [&]() -> mlir::FailureOr> { + mlir::FloatAttr attr; + auto result = $_parser.parseOptionalAttribute(attr); + if (result.hasValue() && mlir::succeeded(*result)) + return {attr.getValue()}; + if (!result.hasValue()) + return llvm::Optional(); + return mlir::failure(); + }() }]; + let printer = "$_printer << *$_self"; +} + +def TestTypeAPFloat : Test_Type<"TestTypeAPFloat"> { + let parameters = (ins + DefaultValuedAPFloat<"APFloat::getZero(APFloat::IEEEdouble())">:$a + ); + let mnemonic = "ap_float"; + let assemblyFormat = "`<` $a `>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -44,6 +44,9 @@ // CHECK: !test.optional_group_struct<(a 
= 10, b = 5)> // CHECK: !test.spaces< 5 // CHECK-NEXT: ()() 6> +// CHECK: !test.ap_float<5.000000e+00> +// CHECK: !test.ap_float<> + func private @test_roundtrip_default_parsers_struct( !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> ) -> ( @@ -70,5 +73,7 @@ !test.optional_group_struct, !test.optional_group_struct<(b = 5)>, !test.optional_group_struct<(b = 5, a = 10)>, - !test.spaces<5 ()() 6> + !test.spaces<5 ()() 6>, + !test.ap_float<5.0>, + !test.ap_float<> ) diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -389,10 +389,13 @@ let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`"; } +// TYPE: ::mlir::Type TestFType::parse(::mlir::AsmParser &odsParser) { +// TYPE: _result_a.getValueOr(int()) + // TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const { // TYPE if (getA()) { -// TYPE printer << ' '; -// TYPE printer.printStrippedAttrOrType(getA()); +// TYPE odsPrinter << ' '; +// TYPE odsPrinter.printStrippedAttrOrType(getA()); def TypeD : TestType<"TestF"> { let parameters = (ins OptionalParameter<"int">:$a); let mnemonic = "type_f"; @@ -406,7 +409,7 @@ // TYPE: if (odsParser.parseComma()) // TYPE: return {}; -// TYPE: if (getA()) +// TYPE: if ((getA())) // TYPE: odsPrinter.printStrippedAttrOrType(getA()); // TYPE: odsPrinter << ", "; // TYPE: odsPrinter.printStrippedAttrOrType(getB()); @@ -426,7 +429,7 @@ // TYPE: return {}; // TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const { -// TYPE: if (getA()) { +// TYPE: if ((getA())) { // TYPE: odsPrinter << "a = "; // TYPE: odsPrinter.printStrippedAttrOrType(getA()); // TYPE: odsPrinter << ", "; @@ -469,9 +472,9 @@ // TYPE: return {}; // TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const { -// TYPE: if (getB()) { +// TYPE: if ((getB())) { // TYPE: odsPrinter << "("; -// TYPE: if (getB()) +// TYPE: 
if ((getB())) // TYPE: odsPrinter.printStrippedAttrOrType(getB()); // TYPE: odsPrinter << ")"; // TYPE: } else { @@ -484,3 +487,15 @@ let mnemonic = "type_j"; let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a"; } + +// TYPE: ::mlir::Type TestKType::parse(::mlir::AsmParser &odsParser) { +// TYPE: _result_a.getValueOr(10) + +// TYPE: void TestKType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: if ((getA() && !(getA() == 10))) + +def TypeI : TestType<"TestK"> { + let parameters = (ins DefaultValuedParameter<"int", "10">:$a); + let mnemonic = "type_k"; + let assemblyFormat = "$a"; +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -179,6 +179,11 @@ : def(def), params(def.getParameters()), defCls(def.getCppClassName()), valueType(isa(def) ? "Attribute" : "Type"), defType(isa(def) ? "Attr" : "Type") { + // Check that all parameters have names. + for (const AttrOrTypeParameter &param : def.getParameters()) + if (param.isAnonymous()) + llvm::PrintFatalError("all parameters must have a name"); + // If a storage class is needed, create one. if (def.getNumParameters() > 0) storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -56,6 +56,24 @@ /// Returns the name of the parameter. StringRef getName() const { return param.getName(); } + /// Generate the code to check whether the parameter should be printed.
+ auto genPrintGuard(FmtContext &ctx) const { + return [&](raw_ostream &os) -> raw_ostream & { + std::string self = getParameterAccessorName(getName()) + "()"; + ctx.withSelf(self); + os << tgfmt("($_self", &ctx); + if (llvm::Optional defaultValue = + getParam().getDefaultValue()) { + // Use the `comparator` field if it exists, else the equality operator. + StringRef comparator = + getParam().getComparator().getValueOr("$_lhs == $_rhs"); + ctx.addSubst("_lhs", self).addSubst("_rhs", *defaultValue); + os << " && !(" << tgfmt(comparator, &ctx) << ")"; + } + return os << ")"; + }; + } + private: bool shouldBeQualifiedFlag = false; AttrOrTypeParameter param; @@ -65,6 +83,12 @@ static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } +/// raw_ostream doesn't have an overload for stream functors. Declare one here. +template <typename StreamFunctor> +static raw_ostream &operator<<(raw_ostream &os, StreamFunctor &&fcn) { + return fcn(os); +} + /// Base class for a directive that contains references to multiple variables. template class ParamsDirectiveBase : public DirectiveElementBase { @@ -274,11 +298,13 @@ def.getCppClassName()); } for (const AttrOrTypeParameter &param : params) { - if (param.isOptional()) - os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(), - param.getCppStorageType()); - else + if (param.isOptional()) { + std::string defaultCtor = (param.getCppStorageType() + "()").str(); + os << formatv(",\n _result_{0}.getValueOr({1})", param.getName(), + param.getDefaultValue().getValueOr(defaultCtor)); + } else { os << formatv(",\n *_result_{0}", param.getName()); + } } os << ");"; } @@ -642,9 +668,10 @@ const AttrOrTypeParameter &param = el->getParam(); ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); - // Guard the printer on the presence of optional parameters.
+ // Guard the printer on the presence of optional parameters and that they + // aren't equal to their default values (if they have one). if (el->isOptional() && !skipGuard) { - os << tgfmt("if ($_self) {\n", &ctx); + os << "if (" << el->genPrintGuard(ctx) << ") {\n"; os.indent(); } @@ -665,23 +692,27 @@ os.unindent() << "}\n"; } +/// Generate code to guard printing on the presence of any optional parameters. +template <typename ParameterRange> +static void guardOnAny(FmtContext &ctx, MethodBody &os, + ParameterRange &&params) { + os << "if ("; + llvm::interleave( + params, os, + [&](ParameterElement *param) { os << param->genPrintGuard(ctx); }, + " || "); + os << ") {\n"; + os.indent(); +} + void DefFormat::genCommaSeparatedPrinter( ArrayRef params, FmtContext &ctx, MethodBody &os, function_ref extra) { // Emit a space if necessary, but only if the struct is present. if (shouldEmitSpace || !lastWasPunctuation) { bool allOptional = llvm::all_of(params, paramIsOptional); - if (allOptional) { - os << "if ("; - llvm::interleave( - params, os, - [&](ParameterElement *param) { - os << getParameterAccessorName(param->getName()) << "()"; - }, - " || "); - os << ") {\n"; - os.indent(); - } + if (allOptional) + guardOnAny(ctx, os, params); os << tgfmt("$_printer << ' ';\n", &ctx); if (allOptional) os.unindent() << "}\n"; @@ -692,8 +723,7 @@ os.indent() << "bool _firstPrinted = true;\n"; for (ParameterElement *param : params) { if (param->isOptional()) { - os << tgfmt("if ($_self()) {\n", - &ctx.withSelf(getParameterAccessorName(param->getName()))); + os << "if (" << param->genPrintGuard(ctx) << ") {\n"; os.indent(); } os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx); @@ -724,26 +754,14 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, MethodBody &os) { - // Emit the check on whether the group should be printed.
- const auto guardOn = [&](auto params) { - os << "if ("; - llvm::interleave( - params, os, - [&](ParameterElement *el) { - os << getParameterAccessorName(el->getName()) << "()"; - }, - " || "); - os << ") {\n"; - os.indent(); - }; FormatElement *anchor = el->getAnchor(); if (auto *param = dyn_cast(anchor)) { - guardOn(llvm::makeArrayRef(param)); + guardOnAny(ctx, os, llvm::makeArrayRef(param)); } else if (auto *params = dyn_cast(anchor)) { - guardOn(params->getParams()); + guardOnAny(ctx, os, params->getParams()); } else { - auto *strct = dyn_cast(anchor); - guardOn(strct->getParams()); + auto *strct = cast(anchor); + guardOnAny(ctx, os, strct->getParams()); } // Generate the printer for the contained elements. {