diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -517,6 +517,21 @@ are used inside optional groups are allowed only if all captured parameters are also optional. +#### Default-Valued Parameters + +Optional parameters can be given default values by setting `defaultValue`, a +string of the C++ default value, or by using `DefaultValuedParameter`. If a +value for the parameter was not encountered during parsing, it is set to this +default value. If a parameter is equal to its default value, it is not printed, +so the equality operator must be implemented for the type. + +For example: + +``` +let parameters = (ins DefaultValuedParameter<"Optional<int>", "5">:$a); +let assemblyFormat = "`<` $a `>`"; +``` + ### Assembly Format Directives Attribute and type assembly formats have the following directives: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -3150,6 +3150,12 @@ // must be default constructible and be contextually convertible to `bool`. // Any `Optional` and any attribute type satisfies these requirements. bit isOptional = 0; + // Provide a default value for the parameter. Parameters with default values + // are considered optional. If a value was not parsed for the parameter, it + // will be set to the default value. Parameters equal to their default values + // are elided when printing (the equality comparator has to be implemented for + // the underlying C++ type). + string defaultValue = ?; } class AttrParameter : AttrOrTypeParameter; @@ -3200,6 +3206,13 @@ let isOptional = 1; } +// A parameter with a default value.
+class DefaultValuedParameter<string type, string value, string desc = ""> :
+    AttrOrTypeParameter<type, desc> {
+  let isOptional = 1;
+  let defaultValue = value;
+}
+
 // This is a special parameter used for AttrDefs that represents a `mlir::Type`
 // that is also used as the value `Type` of the attribute. Only one parameter
 // of the attribute may be of this type.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -52,6 +52,9 @@
   explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
       : def(def), index(index) {}
 
+  /// Returns true if the parameter is anonymous (has no name).
+  bool isAnonymous() const;
+
   /// Get the parameter name.
   StringRef getName() const;
 
@@ -85,6 +88,9 @@
   /// Returns true if the parameter is optional.
   bool isOptional() const;
 
+  /// Get the default value of the parameter if it has one.
+  Optional<StringRef> getDefaultValue() const;
+
   /// Return the underlying def of this parameter.
   llvm::Init *getDef() const;
 
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -187,6 +187,10 @@
   return result;
 }
 
+bool AttrOrTypeParameter::isAnonymous() const {
+  return !def->getArgName(index);
+}
+
 StringRef AttrOrTypeParameter::getName() const {
   return def->getArgName(index)->getValue();
 }
@@ -239,7 +243,13 @@
 }
 
 bool AttrOrTypeParameter::isOptional() const {
-  return getDefValue<llvm::BitInit>("isOptional").getValueOr(false);
+  // Parameters with default values are automatically optional.
+  return getDefValue<llvm::BitInit>("isOptional").getValueOr(false) ||
+         getDefaultValue().hasValue();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
+  return getDefValue<llvm::StringInit>("defaultValue");
 }
 
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -311,7 +311,7 @@
 }
 
 def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> {
-  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+  let parameters = (ins DefaultValuedParameter<"mlir::Optional<int>", "10">:$a,
                         OptionalParameter<"mlir::Optional<int>">:$b);
   let mnemonic = "optional_group_params";
   let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`";
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -388,6 +388,9 @@
   let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
 }
 
+// TYPE: ::mlir::Type TestFType::parse(::mlir::AsmParser &parser) {
+// TYPE: _result_a.getValueOr(int())
+
 // TYPE: void TestFType::print(::mlir::AsmPrinter &printer) const {
 // TYPE if (getA()) {
 // TYPE printer << ' ';
@@ -405,7 +408,7 @@
 // TYPE: if (parser.parseComma())
 // TYPE: return {};
 
-// TYPE: if (getA())
+// TYPE: if ((getA()))
 // TYPE: printer.printStrippedAttrOrType(getA());
 // TYPE: printer << ", ";
 // TYPE: printer.printStrippedAttrOrType(getB());
@@ -425,7 +428,7 @@
 // TYPE: return {};
 
 // TYPE: void TestHType::print(::mlir::AsmPrinter &printer) const {
-// TYPE: if (getA()) {
+// TYPE: if ((getA())) {
 // TYPE: printer << "a = ";
 // TYPE: printer.printStrippedAttrOrType(getA());
 // TYPE: printer << ", ";
@@ -468,9 +471,9 @@
 // TYPE: return {};
 
 // TYPE: void TestJType::print(::mlir::AsmPrinter &printer) const {
-// TYPE: if (getB()) {
+// TYPE: if ((getB())) {
 // TYPE: printer << "(";
-// TYPE: if (getB())
+// TYPE: if ((getB()))
 // TYPE: printer.printStrippedAttrOrType(getB());
 // TYPE: printer << ")";
 // TYPE: } else {
@@ -483,3 +486,15 @@
   let mnemonic = "type_j";
   let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a";
 }
+
+// TYPE: ::mlir::Type TestKType::parse(::mlir::AsmParser &parser) {
+// TYPE: _result_a.getValueOr(10)
+
+// TYPE: void TestKType::print(::mlir::AsmPrinter &printer) const {
+// TYPE: if ((getA() && !(getA() == 10)))
+
+def TypeI : TestType<"TestK"> {
+  let parameters = (ins DefaultValuedParameter<"int", "10">:$a);
+  let mnemonic = "type_k";
+  let assemblyFormat = "$a";
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -179,6 +179,11 @@
     : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
       valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
       defType(isa<AttrDef>(def) ? "Attr" : "Type") {
+  // Check that all parameters have names.
+  for (const AttrOrTypeParameter &param : def.getParameters())
+    if (param.isAnonymous())
+      llvm::PrintFatalError(def.getLoc(), "all parameters must have a name");
+
   // If a storage class is needed, create one.
   if (def.getNumParameters() > 0)
     storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -56,6 +56,20 @@
   /// Returns the name of the parameter.
   StringRef getName() const { return param.getName(); }
 
+  /// Generate the code to check whether the parameter should be printed: it
+  /// must convert to `true` and, if it has a default value, must not compare
+  /// equal to it. The returned functor captures `this` and `ctx` by reference
+  /// and rebinds `$_self` in `ctx` when streamed, so stream it immediately.
+  auto genPrintGuard(FmtContext &ctx) const {
+    return [this, &ctx](raw_ostream &os) -> raw_ostream & {
+      ctx.withSelf(getParameterAccessorName(getName()) + "()");
+      os << tgfmt("($_self", &ctx);
+      if (mlir::Optional<StringRef> defaultValue = getParam().getDefaultValue())
+        os << tgfmt(" && !($_self == $0)", &ctx, *defaultValue);
+      return os << ")";
+    };
+  }
+
 private:
   bool shouldBeQualifiedFlag = false;
   AttrOrTypeParameter param;
@@ -65,6 +79,12 @@
 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
 
+/// raw_ostream doesn't have an overload for stream functors. Declare one here.
+template <typename StreamFunctor>
+raw_ostream &operator<<(raw_ostream &os, StreamFunctor &&fcn) {
+  return fcn(os);
+}
+
 /// Base class for a directive that contains references to multiple variables.
 template <DirectiveElement::Kind DirectiveKind>
 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
@@ -276,11 +296,14 @@
                 def.getCppClassName());
   }
   for (const AttrOrTypeParameter &param : params) {
-    if (param.isOptional())
-      os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(),
-                    param.getCppStorageType());
-    else
+    if (param.isOptional()) {
+      // Use an owning string here: a Twine must not outlive its operands.
+      std::string defaultCtor = param.getCppStorageType().str() + "()";
+      os << formatv(",\n _result_{0}.getValueOr({1})", param.getName(),
+                    param.getDefaultValue().getValueOr(defaultCtor));
+    } else {
       os << formatv(",\n *_result_{0}", param.getName());
+    }
   }
   os << ");";
 }
@@ -634,9 +657,10 @@
   const AttrOrTypeParameter &param = el->getParam();
   ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
 
-  // Guard the printer on the presence of optional parameters.
+  // Guard the printer on the presence of optional parameters and that they
+  // aren't equal to their default values (if they have one).
   if (el->isOptional() && !skipGuard) {
-    os << tgfmt("if ($_self) {\n", &ctx);
+    os << "if (" << el->genPrintGuard(ctx) << ") {\n";
     os.indent();
   }
 
@@ -657,23 +681,29 @@
   os.unindent() << "}\n";
 }
 
+/// Generate code to guard printing on the presence of any optional parameters.
+/// This opens an `if (...) {` and indents `os`; the caller must unindent and
+/// close the brace.
+template <typename ParameterRange>
+static void guardOnAny(FmtContext &ctx, MethodBody &os,
+                       ParameterRange &&params) {
+  os << "if (";
+  llvm::interleave(
+      params, os,
+      [&](ParameterElement *param) { os << param->genPrintGuard(ctx); },
+      " || ");
+  os << ") {\n";
+  os.indent();
+}
+
 void DefFormat::genCommaSeparatedPrinter(
     ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
     function_ref<void(ParameterElement *)> extra) {
   // Emit a space if necessary, but only if the struct is present.
   if (shouldEmitSpace || !lastWasPunctuation) {
     bool allOptional = llvm::all_of(params, paramIsOptional);
-    if (allOptional) {
-      os << "if (";
-      llvm::interleave(
-          params, os,
-          [&](ParameterElement *param) {
-            os << getParameterAccessorName(param->getName()) << "()";
-          },
-          " || ");
-      os << ") {\n";
-      os.indent();
-    }
+    if (allOptional)
+      guardOnAny(ctx, os, params);
     os << tgfmt("$_printer << ' ';\n", &ctx);
     if (allOptional)
       os.unindent() << "}\n";
@@ -684,8 +714,7 @@
   os.indent() << "bool _firstPrinted = true;\n";
   for (ParameterElement *param : params) {
     if (param->isOptional()) {
-      os << tgfmt("if ($_self()) {\n",
-                  &ctx.withSelf(getParameterAccessorName(param->getName())));
+      os << "if (" << param->genPrintGuard(ctx) << ") {\n";
       os.indent();
     }
     os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
@@ -716,26 +745,15 @@
 
 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                         MethodBody &os) {
   // Emit the check on whether the group should be printed.
-  const auto guardOn = [&](auto params) {
-    os << "if (";
-    llvm::interleave(
-        params, os,
-        [&](ParameterElement *el) {
-          os << getParameterAccessorName(el->getName()) << "()";
-        },
-        " || ");
-    os << ") {\n";
-    os.indent();
-  };
   FormatElement *anchor = el->getAnchor();
   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
-    guardOn(llvm::makeArrayRef(param));
+    guardOnAny(ctx, os, llvm::makeArrayRef(param));
   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
-    guardOn(params->getParams());
+    guardOnAny(ctx, os, params->getParams());
   } else {
-    auto *strct = dyn_cast<StructDirective>(anchor);
-    guardOn(strct->getParams());
+    auto *strct = cast<StructDirective>(anchor);
+    guardOnAny(ctx, os, strct->getParams());
   }
   // Generate the printer for the contained elements.
   {