diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -640,23 +640,21 @@ parameter, parameters can override `convertFromStorage`, which by default is `"$_self"` (i.e., it attempts an implicit conversion from `cppStorageType`). -###### Optional Parameters +###### Optional and Default-Valued Parameters +An optional parameter can be omitted from the assembly format of an attribute or +a type. An optional parameter is omitted when it is equal to its default value. Optional parameters in the assembly format can be indicated by setting -`isOptional`. The C++ type of an optional parameter is required to satisfy the -following requirements: +`defaultValue`, a string of the C++ default value. If a value for the parameter +was not encountered during parsing, it is set to this default value. If a +parameter is equal to its default value, it is not printed. The `comparator` +field of the parameter is used, but if one is not specified, the equality +operator is used. -- is default-constructible -- is contextually convertible to `bool` -- only the default-constructed value is `false` - -The parameter parser should return the default-constructed value to indicate "no -value present". The printer will guard on the presence of a value to print the -parameter. - -If a value was not parsed for an optional parameter, then the parameter will be -set to its default-constructed C++ value. For example, `Optional` will be -set to `llvm::None` and `Attribute` will be set to `nullptr`. +When using `OptionalParameter`, the default value is set to the C++ +default-constructed value for the C++ storage type. For example, `Optional` +will be set to `llvm::None` and `Attribute` will be set to `nullptr`. The +presence of these parameters is tested by comparing them to their "null" values. Only optional parameters or directives that only capture optional parameters can be used in optional groups. 
An optional group is a set of elements optionally @@ -673,16 +671,9 @@ are used inside optional groups are allowed only if all captured parameters are also optional. -###### Default-Valued Parameters - -Optional parameters can be given default values by setting `defaultValue`, a -string of the C++ default value, or by using `DefaultValuedParameter`. If a -value for the parameter was not encountered during parsing, it is set to this -default value. If a parameter is equal to its default value, it is not printed. -The `comparator` field of the parameter is used, but if one is not specified, -the equality operator is used. - -For example: +An optional parameter can also be specified with `DefaultValuedParameter`, which +specifies that a parameter should be omitted when it is equal to some given +value. ```tablegen let parameters = (ins DefaultValuedParameter<"Optional", "5">:$a) diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -299,7 +299,7 @@ string cppAccessorType = !if(!empty(accessorType), type, accessorType); // The C++ storage type of of this parameter if it is a reference, e.g. // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`. - string cppStorageType = ?; + string cppStorageType = cppType; // The C++ code to convert from the storage type to the parameter type. string convertFromStorage = "$_self"; // One-line human-readable description of the argument. @@ -315,10 +315,6 @@ // operator of `AsmPrinter` as necessary to print your type. Or you can // provide a custom printer. string printer = ?; - // Mark a parameter as optional. The C++ type of parameters marked as optional - // must be default constructible and be contextually convertible to `bool`. - // Any `Optional` and any attribute type satisfies these requirements. - bit isOptional = 0; // Provide a default value for the parameter. 
Parameters with default values // are considered optional. If a value was not parsed for the parameter, it // will be set to the default value. Parameters equal to their default values @@ -374,13 +370,12 @@ // An optional parameter. class OptionalParameter : AttrOrTypeParameter { - let isOptional = 1; + let defaultValue = cppStorageType # "()"; } // A parameter with a default value. class DefaultValuedParameter : AttrOrTypeParameter { - let isOptional = 1; let defaultValue = value; } diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -277,9 +277,7 @@ } bool AttrOrTypeParameter::isOptional() const { - // Parameters with default values are automatically optional. - return getDefValue("isOptional").value_or(false) || - getDefaultValue(); + return getDefaultValue().has_value(); } Optional AttrOrTypeParameter::getDefaultValue() const { diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -286,19 +286,17 @@ } class DefaultValuedAPFloat - : DefaultValuedParameter<"llvm::Optional", - "llvm::Optional(" # value # ")"> { - let comparator = "$_lhs->bitwiseIsEqual(*$_rhs)"; - let parser = [{ [&]() -> mlir::FailureOr> { + : DefaultValuedParameter<"llvm::APFloat", "llvm::APFloat(" # value # ")"> { + let comparator = "$_lhs.bitwiseIsEqual($_rhs)"; + let parser = [{ [&]() -> mlir::FailureOr { mlir::FloatAttr attr; auto result = $_parser.parseOptionalAttribute(attr); if (result.has_value() && mlir::succeeded(*result)) - return {attr.getValue()}; + return attr.getValue(); if (!result.has_value()) - return llvm::Optional(); + return llvm::APFloat(}] # value # [{); return mlir::failure(); }() }]; - let printer = "$_printer << *$_self"; } def TestTypeAPFloat : Test_Type<"TestTypeAPFloat"> { diff --git 
a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -421,11 +421,11 @@ // TYPE: ::mlir::Type TestGType::parse(::mlir::AsmParser &odsParser) { // TYPE: if (::mlir::failed(_result_a)) // TYPE: return {}; -// TYPE: if (::mlir::succeeded(_result_a) && *_result_a) +// TYPE: if (::mlir::succeeded(_result_a) && !((*_result_a) == int())) // TYPE: if (odsParser.parseComma()) // TYPE: return {}; -// TYPE: if ((getA())) +// TYPE: if (!(getA() == int())) // TYPE: odsPrinter.printStrippedAttrOrType(getA()); // TYPE: odsPrinter << ", "; // TYPE: odsPrinter.printStrippedAttrOrType(getB()); @@ -445,7 +445,7 @@ // TYPE: return {}; // TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const { -// TYPE: if ((getA())) { +// TYPE: if (!(getA() == int())) { // TYPE: odsPrinter << "a = "; // TYPE: odsPrinter.printStrippedAttrOrType(getA()); // TYPE: odsPrinter << ", "; @@ -488,9 +488,9 @@ // TYPE: return {}; // TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const { -// TYPE: if ((getB())) { +// TYPE: if (!(getB() == int())) { // TYPE: odsPrinter << "("; -// TYPE: if ((getB())) +// TYPE: if (!(getB() == int())) // TYPE: odsPrinter.printStrippedAttrOrType(getB()); // TYPE: odsPrinter << ")"; // TYPE: } else { @@ -508,7 +508,7 @@ // TYPE: _result_a.value_or(10) // TYPE: void TestKType::print(::mlir::AsmPrinter &odsPrinter) const { -// TYPE: if ((getA() && !(getA() == 10))) +// TYPE: if (!(getA() == 10)) def TypeI : TestType<"TestK"> { let parameters = (ins DefaultValuedParameter<"int", "10">:$a); @@ -578,7 +578,7 @@ // TYPE: else // TYPE-LABEL: void TestOType::print -// TYPE: if (!((getA()))) +// TYPE: if (!(!(getA() == int()))) // TYPE: odsPrinter << ' ' << "?" 
// TYPE: else // TYPE: odsPrinter.printStrippedAttrOrType(getA()) @@ -598,7 +598,7 @@ // TYPE-NEXT: } // TYPE-LABEL: void TestPType::print -// TYPE: if (!((getA()) || (getB()))) +// TYPE: if (!(!(getA() == int()) || !(getB() == int()))) // TYPE-NEXT: odsPrinter << "?" def TypeN : TestType<"TestP"> { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -56,18 +56,19 @@ /// Returns the name of the parameter. StringRef getName() const { return param.getName(); } + /// Return code that is true when the parameter equals its default (absent). + auto genIsPresent(FmtContext &ctx, const Twine &self) const { + assert(isOptional() && "cannot guard on a mandatory parameter"); + std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str(); + ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr); + return tgfmt(getParam().getComparator(), &ctx); + } + /// Generate the code to check whether the parameter should be printed. MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const { + assert(isOptional() && "cannot guard on a mandatory parameter"); std::string self = param.getAccessorName() + "()"; - ctx.withSelf(self); - os << tgfmt("($_self", &ctx); - if (llvm::Optional defaultValue = getParam().getDefaultValue()) { - // Use the `comparator` field if it exists, else the equality operator. 
- std::string valueStr = tgfmt(*defaultValue, &ctx).str(); - ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr); - os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")"; - } - return os << ")"; + return os << "!(" << genIsPresent(ctx, self) << ")"; } private: @@ -332,13 +333,9 @@ os << ",\n "; std::string paramSelfStr; llvm::raw_string_ostream selfOs(paramSelfStr); - if (param.isOptional()) { - selfOs << formatv("(_result_{0}.value_or(", param.getName()); - if (Optional defaultValue = param.getDefaultValue()) - selfOs << tgfmt(*defaultValue, &ctx); - else - selfOs << param.getCppStorageType() << "()"; - selfOs << "))"; + if (Optional defaultValue = param.getDefaultValue()) { + selfOs << formatv("(_result_{0}.value_or(", param.getName()) + << tgfmt(*defaultValue, &ctx) << "))"; } else { selfOs << formatv("(*_result_{0})", param.getName()); } @@ -447,8 +444,9 @@ ParameterElement *el = *std::prev(it); // Parse a comma if the last optional parameter had a value. if (el->isOptional()) { - os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n", - el->getName()); + os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n", + el->getName(), + el->genIsPresent(ctx, "(*_result_" + el->getName() + ")")); os.indent(); } if (it <= lastReqIt) { @@ -522,18 +520,6 @@ } )"; - // Optional parameters in a struct must be parsed successfully if the - // keyword is present. - // - // {0}: Name of the parameter. - // {1}: Emit error string - const char *const checkOptionalParam = R"( - if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{ - {1}"expected a value for parameter '{0}'"); - return {{}; - } -)"; - // First iteration of the loop parsing an optional struct. 
const char *const optionalStructFirst = R"( ::llvm::StringRef _paramKey; @@ -558,11 +544,6 @@ " _seen_{0} = true;\n", param->getName()); genVariableParser(param, ctx, os.indent()); - if (param->isOptional()) { - os.getStream().printReindented(strfmt(checkOptionalParam, - param->getName(), - tgfmt(parserErrorStr, &ctx).str())); - } os.unindent() << "} else "; // Print the check for duplicate or unknown parameter. }