diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -1,9 +1,9 @@ # Defining Dialect Attributes and Types This document is a quickstart to defining dialect specific extensions to the -[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system) systems in -MLIR. The main part of this tutorial focuses on defining types, but the -instructions are nearly identical for defining attributes. +[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system) +systems in MLIR. The main part of this tutorial focuses on defining types, but +the instructions are nearly identical for defining attributes. See [MLIR specification](../LangRef.md) for more information about MLIR, the structure of the IR, operations, etc. @@ -24,18 +24,19 @@ So before defining the derived `Type`, it's important to know which of the two classes of `Type` we are defining: -Some types are _singleton_ in nature, meaning they have no parameters and only -ever have one instance, like the [`index` type](../Dialects/Builtin.md/#indextype). +Some types are *singleton* in nature, meaning they have no parameters and only +ever have one instance, like the +[`index` type](../Dialects/Builtin.md/#indextype). -Other types are _parametric_, and contain additional information that +Other types are *parametric*, and contain additional information that differentiates different instances of the same `Type`. For example the -[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with `i8` and -`i16` representing different instances of -[`integer` type](../Dialects/Builtin.md/#integertype). _Parametric_ may also contain a -mutable component, which can be used, for example, to construct self-referring -recursive types. 
The mutable component _cannot_ be used to differentiate -instances of a type class, so usually such types contain other parametric -components that serve to identify them. +[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with +`i8` and `i16` representing different instances of +[`integer` type](../Dialects/Builtin.md/#integertype). *Parametric* types may +also contain a mutable component, which can be used, for example, to construct +self-referring recursive types. The mutable component *cannot* be used to +differentiate instances of a type class, so usually such types contain other +parametric components that serve to identify them. #### Singleton types @@ -389,12 +390,12 @@ `assemblyFormat` to declaratively describe custom parsers and printers. The assembly format consists of literals, variables, and directives. -* A literal is a keyword or valid punctuation enclosed in backticks, e.g. - `` `keyword` `` or `` `<` ``. -* A variable is a parameter name preceeded by a dollar sign, e.g. `$param0`, - which captures one attribute or type parameter. -* A directive is a keyword followed by an optional argument list that defines - special parser and printer behaviour. +* A literal is a keyword or valid punctuation enclosed in backticks, e.g. `` + `keyword` `` or `` `<` ``. +* A variable is a parameter name preceded by a dollar sign, e.g. `$param0`, + which captures one attribute or type parameter. +* A directive is a keyword followed by an optional argument list that defines + special parser and printer behaviour. ```tablegen // An example type with an assembly format.
@@ -412,8 +413,8 @@ } ``` -The declarative assembly format for `MyType` results in the following format -in the IR: +The declarative assembly format for `MyType` results in the following format in +the IR: ```mlir !my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)> @@ -421,15 +422,15 @@ ### Parameter Parsing and Printing -For many basic parameter types, no additional work is needed to define how -these parameters are parsed or printed. +For many basic parameter types, no additional work is needed to define how these +parameters are parsed or printed. -* The default printer for any parameter is `$_printer << $_self`, - where `$_self` is the C++ value of the parameter and `$_printer` is an - `AsmPrinter`. -* The default parser for a parameter is - `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type - of the parameter and `$_parser` is an `AsmParser`. +* The default printer for any parameter is `$_printer << $_self`, where + `$_self` is the C++ value of the parameter and `$_printer` is an + `AsmPrinter`. +* The default parser for a parameter is + `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type + of the parameter and `$_parser` is an `AsmParser`. Printing and parsing behaviour can be added to additional C++ types by overloading these functions or by defining a `parser` and `printer` in an ODS @@ -470,8 +471,8 @@ } ``` -A type using this parameter with the assembly format `` `<` $myParam `>` `` -will look as follows in the IR: +A type using this parameter with the assembly format `` `<` $myParam `>` `` will +look as follows in the IR: ```mlir !my_dialect.my_type<42 * 24> @@ -480,10 +481,42 @@ #### Non-POD Parameters Parameters that aren't plain-old-data (e.g. references) may need to define a -`cppStorageType` to contain the data until it is copied into the allocator. -For example, `StringRefParameter` uses `std::string` as its storage type, -whereas `ArrayRefParameter` uses `SmallVector` as its storage type. 
The parsers -for these parameters are expected to return `FailureOr<$cppStorageType>`. +`cppStorageType` to contain the data until it is copied into the allocator. For +example, `StringRefParameter` uses `std::string` as its storage type, whereas +`ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for +these parameters are expected to return `FailureOr<$cppStorageType>`. + +#### Optional Parameters + +Optional parameters in the assembly format can be indicated by setting +`isOptional`. The C++ type of an optional parameter is required to satisfy the +following requirements: + +* is default-constructible +* is contextually convertible to `bool` +* only the default-constructed value is `false` + +The parameter parser should return the default-constructed value to indicate "no +value present". The printer will guard on the presence of a value to print the +parameter. + +If a value was not parsed for an optional parameter, then the parameter will be +set to its default-constructed C++ value. For example, `Optional` will be +set to `llvm::None` and `Attribute` will be set to `nullptr`. + +Only optional parameters or directives that only capture optional parameters can +be used in optional groups. An optional group is a set of elements optionally +printed based on the presence of an anchor. Suppose parameter `a` is an +`IntegerAttr`. + +``` +( `(` $a^ `)` ) : (`x`)? +``` + +In the above assembly format, if `a` is present (non-null), then it will be +printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that +are used inside optional groups are allowed only if all captured parameters are +also optional. ### Assembly Format Directives @@ -497,9 +530,9 @@ #### `params` Directive -This directive is used to refer to all parameters of an attribute or type. -When used as a top-level directive, `params` generates a parser and printer for -a comma-separated list of the parameters. 
For example: +This directive is used to refer to all parameters of an attribute or type. When +used as a top-level directive, `params` generates a parser and printer for a +comma-separated list of the parameters. For example: ```tablegen def MyPairType : TypeDef { @@ -547,12 +580,16 @@ !my_dialect.outer_qual> ``` +When `params` is used, optional parameters that have no value are omitted from +the printed parameter list. + #### `struct` Directive The `struct` directive accepts a list of variables to capture and will generate -a parser and printer for a comma-separated list of key-value pairs. The -variables are printed in the order they are specified in the argument list **but -can be parsed in any order**. For example: +a parser and printer for a comma-separated list of key-value pairs. If an +optional parameter is included in the `struct`, it can be elided. The variables +are printed in the order they are specified in the argument list **but can be +parsed in any order**. For example: ```tablegen def MyStructType : TypeDef { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -3120,15 +3120,19 @@ string summary = desc; // The format string for the asm syntax (documentation only). string syntax = ?; - // The default parameter parser is `::mlir::parseField($_parser)`, which - // returns `FailureOr`. Overload `parseField` to support parsing for your - // type. Or you can provide a customer printer. For attributes, "$_type" will - // be replaced with the required attribute type. + // The default parameter parser is `::mlir::FieldParser::parse($_parser)`, + // which returns `FailureOr`. Specialize `FieldParser` to support parsing + // for your type, or provide a custom parser. For attributes, + // "$_type" will be replaced with the required attribute type. string parser = ?; // The default parameter printer is `$_printer << $_self`. Overload the stream
Overload the stream // operator of `AsmPrinter` as necessary to print your type. Or you can // provide a custom printer. string printer = ?; + // Mark a parameter as optional. The C++ type of parameters marked as optional + // must be default constructible and be contextually convertible to `bool`. + // Any `Optional` and any attribute type satisfies these requirements. + bit isOptional = 0; } class AttrParameter : AttrOrTypeParameter; @@ -3173,6 +3177,12 @@ }]; } +// An optional parameter. +class OptionalParameter : + AttrOrTypeParameter { + let isOptional = 1; +} + // This is a special parameter used for AttrDefs that represents a `mlir::Type` // that is also used as the value `Type` of the attribute. Only one parameter // of the attribute may be of this type. diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -45,47 +45,50 @@ // AttrOrTypeParameter //===----------------------------------------------------------------------===// -// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to -// AttrOrTypeDefs to parameterize them. +/// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to +/// AttrOrTypeDefs to parameterize them. class AttrOrTypeParameter { public: explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) : def(def), index(index) {} - // Get the parameter name. + /// Get the parameter name. StringRef getName() const; - // If specified, get the custom allocator code for this parameter. + /// If specified, get the custom allocator code for this parameter. Optional getAllocator() const; - // If specified, get the custom comparator code for this parameter. + /// If specified, get the custom comparator code for this parameter. Optional getComparator() const; - // Get the C++ type of this parameter. + /// Get the C++ type of this parameter. 
StringRef getCppType() const; - // Get the C++ accessor type of this parameter. + /// Get the C++ accessor type of this parameter. StringRef getCppAccessorType() const; - // Get the C++ storage type of this parameter. + /// Get the C++ storage type of this parameter. StringRef getCppStorageType() const; - // Get an optional C++ parameter parser. + /// Get an optional C++ parameter parser. Optional getParser() const; - // Get an optional C++ parameter printer. + /// Get an optional C++ parameter printer. Optional getPrinter() const; - // Get a description of this parameter for documentation purposes. + /// Get a description of this parameter for documentation purposes. Optional getSummary() const; - // Get the assembly syntax documentation. + /// Get the assembly syntax documentation. StringRef getSyntax() const; - // Return the underlying def of this parameter. - const llvm::Init *getDef() const; + /// Returns true if the parameter is optional. + bool isOptional() const; - // The parameter is pointer-comparable. + /// Return the underlying def of this parameter. + llvm::Init *getDef() const; + + /// The parameter is pointer-comparable. bool operator==(const AttrOrTypeParameter &other) const { return def == other.def && index == other.index; } @@ -94,6 +97,11 @@ } private: + /// A parameter can be either a string or a def. Get the value of field `name` + /// from the def, or None if the parameter is a string or the field is unset. + template + auto getDefValue(StringRef name) const; + /// The underlying tablegen parameter list this parameter is a part of. const llvm::DagInit *def; /// The index of the parameter within the parameter list (`def`). @@ -121,113 +129,113 @@ public: explicit AttrOrTypeDef(const llvm::Record *def); - // Get the dialect for which this def belongs. + /// Get the dialect to which this def belongs. Dialect getDialect() const; - // Returns the name of this AttrOrTypeDef record. + /// Returns the name of this AttrOrTypeDef record.
StringRef getName() const; - // Query functions for the documentation of the def. + /// Query functions for the documentation of the def. bool hasDescription() const; StringRef getDescription() const; bool hasSummary() const; StringRef getSummary() const; - // Returns the name of the C++ class to generate. + /// Returns the name of the C++ class to generate. StringRef getCppClassName() const; - // Returns the name of the C++ base class to use when generating this def. + /// Returns the name of the C++ base class to use when generating this def. StringRef getCppBaseClassName() const; - // Returns the name of the storage class for this def. + /// Returns the name of the storage class for this def. StringRef getStorageClassName() const; - // Returns the C++ namespace for this def's storage class. + /// Returns the C++ namespace for this def's storage class. StringRef getStorageNamespace() const; - // Returns true if we should generate the storage class. + /// Returns true if we should generate the storage class. bool genStorageClass() const; - // Indicates whether or not to generate the storage class constructor. + /// Indicates whether or not to generate the storage class constructor. bool hasStorageCustomConstructor() const; /// Get the parameters of this attribute or type. ArrayRef getParameters() const { return parameters; } - // Return the number of parameters + /// Return the number of parameters. unsigned getNumParameters() const; - // Return the keyword/mnemonic to use in the printer/parser methods if we are - // supposed to auto-generate them. + /// Return the keyword/mnemonic to use in the printer/parser methods if we are + /// supposed to auto-generate them. Optional getMnemonic() const; - // Returns the code to use as the types printer method. If not specified, - // return a non-value. + /// Returns the code to use as the printer method. If not specified, + /// return None.
Otherwise, return the contents of that code block. Optional getPrinterCode() const; - // Returns the code to use as the parser method. If not specified, returns - // None. Otherwise, returns the contents of that code block. + /// Returns the code to use as the parser method. If not specified, returns + /// None. Otherwise, returns the contents of that code block. Optional getParserCode() const; - // Returns the custom assembly format, if one was specified. + /// Returns the custom assembly format, if one was specified. Optional getAssemblyFormat() const; - // An attribute or type with parameters needs a parser. + /// An attribute or type with parameters needs a parser. bool needsParserPrinter() const { return getNumParameters() != 0; } - // Returns true if this attribute or type has a generated parser. + /// Returns true if this attribute or type has a generated parser. bool hasGeneratedParser() const { return getParserCode() || getAssemblyFormat(); } - // Returns true if this attribute or type has a generated printer. + /// Returns true if this attribute or type has a generated printer. bool hasGeneratedPrinter() const { return getPrinterCode() || getAssemblyFormat(); } - // Returns true if the accessors based on the parameters should be generated. + /// Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; - // Return true if we need to generate the verify declaration and getChecked - // method. + /// Return true if we need to generate the verify declaration and getChecked + /// method. bool genVerifyDecl() const; - // Returns the def's extra class declaration code. + /// Returns the def's extra class declaration code. Optional getExtraDecls() const; - // Get the code location (for error printing). + /// Get the code location (for error printing). ArrayRef getLoc() const; - // Returns true if the default get/getChecked methods should be skipped during - // generation. 
+ /// Returns true if the default get/getChecked methods should be skipped + /// during generation. bool skipDefaultBuilders() const; - // Returns the builders of this def. + /// Returns the builders of this def. ArrayRef getBuilders() const { return builders; } - // Returns the traits of this def. + /// Returns the traits of this def. ArrayRef getTraits() const { return traits; } - // Returns whether two AttrOrTypeDefs are equal by checking the equality of - // the underlying record. + /// Returns whether two AttrOrTypeDefs are equal by checking the equality of + /// the underlying record. bool operator==(const AttrOrTypeDef &other) const; - // Compares two AttrOrTypeDefs by comparing the names of the dialects. + /// Compares two AttrOrTypeDefs by comparing the names of the dialects. bool operator<(const AttrOrTypeDef &other) const; - // Returns whether the AttrOrTypeDef is defined. + /// Returns whether the AttrOrTypeDef is defined. operator bool() const { return def != nullptr; } - // Return the underlying def. + /// Return the underlying def. const llvm::Record *getDef() const { return def; } protected: const llvm::Record *def; - // The builders of this definition. + /// The builders of this definition. SmallVector builders; - // The traits of this definition. + /// The traits of this definition. SmallVector traits; /// The parameters of this attribute or type. @@ -243,8 +251,8 @@ public: using AttrOrTypeDef::AttrOrTypeDef; - // Returns the attributes value type builder code block, or None if it doesn't - // have one. + /// Returns the attributes value type builder code block, or None if it + /// doesn't have one. 
Optional getTypeBuilder() const; static bool classof(const AttrOrTypeDef *def); diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -177,32 +177,30 @@ // AttrOrTypeParameter //===----------------------------------------------------------------------===// +template +auto AttrOrTypeParameter::getDefValue(StringRef name) const { + Optional().getValue())> result; + if (auto *param = dyn_cast(getDef())) + if (auto *init = param->getDef()->getValue(name)) + if (auto *value = dyn_cast_or_null(init->getValue())) + result = value->getValue(); + return result; +} + StringRef AttrOrTypeParameter::getName() const { return def->getArgName(index)->getValue(); } Optional AttrOrTypeParameter::getAllocator() const { - llvm::Init *parameterType = def->getArg(index); - if (isa(parameterType)) - return Optional(); - if (auto *param = dyn_cast(parameterType)) - return param->getDef()->getValueAsOptionalString("allocator"); - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from AttrOrTypeParameter\n"); + return getDefValue("allocator"); } Optional AttrOrTypeParameter::getComparator() const { - llvm::Init *parameterType = def->getArg(index); - if (isa(parameterType)) - return Optional(); - if (auto *param = dyn_cast(parameterType)) - return param->getDef()->getValueAsOptionalString("comparator"); - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from AttrOrTypeParameter\n"); + return getDefValue("comparator"); } StringRef AttrOrTypeParameter::getCppType() const { - auto *parameterType = def->getArg(index); + llvm::Init *parameterType = getDef(); if (auto *stringType = dyn_cast(parameterType)) return stringType->getValue(); if (auto *param = dyn_cast(parameterType)) @@ -213,74 +211,45 @@ } StringRef AttrOrTypeParameter::getCppAccessorType() const { - if (auto *param = 
dyn_cast(def->getArg(index))) { - if (Optional type = - param->getDef()->getValueAsOptionalString("cppAccessorType")) - return *type; - } - return getCppType(); + return getDefValue("cppAccessorType") + .getValueOr(getCppType()); } StringRef AttrOrTypeParameter::getCppStorageType() const { - if (auto *param = dyn_cast(def->getArg(index))) { - if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType")) - return *type; - } - return getCppType(); + return getDefValue("cppStorageType") + .getValueOr(getCppType()); } Optional AttrOrTypeParameter::getParser() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - if (auto parser = param->getDef()->getValueAsOptionalString("parser")) - return *parser; - } - return {}; + return getDefValue("parser"); } Optional AttrOrTypeParameter::getPrinter() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - if (auto printer = param->getDef()->getValueAsOptionalString("printer")) - return *printer; - } - return {}; + return getDefValue("printer"); } Optional AttrOrTypeParameter::getSummary() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - const auto *desc = param->getDef()->getValue("summary"); - if (llvm::StringInit *ci = dyn_cast(desc->getValue())) - return ci->getValue(); - } - return Optional(); + return getDefValue("summary"); } StringRef AttrOrTypeParameter::getSyntax() const { - auto *parameterType = def->getArg(index); - if (auto *stringType = dyn_cast(parameterType)) + if (auto *stringType = dyn_cast(getDef())) return stringType->getValue(); - if (auto *param = dyn_cast(parameterType)) { - const auto *syntax = param->getDef()->getValue("syntax"); - if (syntax && isa(syntax->getValue())) - return cast(syntax->getValue())->getValue(); - return getCppType(); - } - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from 
AttrOrTypeParameter"); + return getDefValue("syntax").getValueOr(getCppType()); } -const llvm::Init *AttrOrTypeParameter::getDef() const { - return def->getArg(index); +bool AttrOrTypeParameter::isOptional() const { + return getDefValue("isOptional").getValueOr(false); } +llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } + //===----------------------------------------------------------------------===// // AttributeSelfTypeParameter //===----------------------------------------------------------------------===// bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) { - const llvm::Init *paramDef = param->getDef(); + llvm::Init *paramDef = param->getDef(); if (auto *paramDefInit = dyn_cast(paramDef)) return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter"); return false; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -22,7 +22,7 @@ // All of the types will extend this class. class Test_Type traits = []> - : TypeDef; + : TypeDef; def SimpleTypeA : Test_Type<"SimpleA"> { let mnemonic = "smpla"; @@ -42,7 +42,7 @@ ArrayRefParameter< "int", // The parameter C++ type. "An example of an array of ints" // Parameter description. - >: $arrayOfInts + >:$arrayOfInts ); let extraClassDeclaration = [{ @@ -110,14 +110,14 @@ // The parser is defined here also. 
let parser = [{ - if (parser.parseLess()) return Type(); + if ($_parser.parseLess()) return Type(); SignednessSemantics signedness; - if (parseSignedness($_parser, signedness)) return mlir::Type(); + if (parseSignedness($_parser, signedness)) return Type(); if ($_parser.parseComma()) return Type(); int width; if ($_parser.parseInteger(width)) return Type(); if ($_parser.parseGreater()) return Type(); - ::mlir::Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); + Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); return getChecked(loc, loc.getContext(), width, signedness); }]; @@ -262,4 +262,66 @@ let assemblyFormat = "`<` struct(params) `>`"; } +def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, "int":$b); + let mnemonic = "optional_param"; + let assemblyFormat = "`<` $a `,` $b `>`"; +} + +def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + StringRefParameter<>:$b); + let mnemonic = "optional_params"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeOptionalParamsAfterRequired + : Test_Type<"TestTypeOptionalParamsAfterRequired"> { + let parameters = (ins StringRefParameter<>:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_params_after"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeOptionalStruct : Test_Type<"TestTypeOptionalStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + StringRefParameter<>:$b); + let mnemonic = "optional_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TestTypeAllOptionalParams : Test_Type<"TestTypeAllOptionalParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "all_optional_params"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeAllOptionalStruct : 
Test_Type<"TestTypeAllOptionalStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "all_optional_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> { + let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group"; + let assemblyFormat = "`<` (`(` $b^ `)`) : (`x`)? $a `>`"; +} + +def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group_params"; + let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`"; +} + +def TestTypeOptionalGroupStruct : Test_Type<"TestTypeOptionalGroupStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group_struct"; + let assemblyFormat = "`<` (`(` struct(params)^ `)`) : (`x`)? `>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -64,11 +64,28 @@ return test::CustomParam{value.getValue()}; } }; + inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer, test::CustomParam param) { return printer << param.value; } +/// Overload the attribute parameter parser for optional integers. 
+template <> +struct FieldParser> { + static FailureOr> parse(AsmParser &parser) { + Optional value; + value.emplace(); + OptionalParseResult result = parser.parseOptionalInteger(*value); + if (result.hasValue()) { + if (succeeded(*result)) + return value; + return failure(); + } + value.reset(); + return value; + } +}; } // namespace mlir #include "TestTypeInterfaces.h.inc" diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -11,28 +11,28 @@ let mnemonic = asm; } -/// Test format is missing a parameter capture. +// Test format is missing a parameter capture. def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: format is missing reference to parameter: v1 let assemblyFormat = "`<` $v0 `>`"; } -/// Test format has duplicate parameter captures. +// Test format has duplicate parameter captures. def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`"; } -/// Test format has invalid syntax. +// Test format has invalid syntax. def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: expected literal, variable, directive, or optional group let assemblyFormat = "`<` $v0, $v1 `>`"; } -/// Test struct directive has invalid syntax. +// Test struct directive has invalid syntax. def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> { let parameters = (ins "int":$v0); // CHECK: literals may only be used in the top-level section of the format @@ -40,37 +40,70 @@ let assemblyFormat = "`<` struct($v0, `,`) `>`"; } -/// Test struct directive cannot capture zero parameters. +// Test struct directive cannot capture zero parameters. 
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> { let parameters = (ins "int":$v0); // CHECK: `struct` argument list expected a variable or directive let assemblyFormat = "`<` struct() $v0 `>`"; } -/// Test capture parameter that does not exist. +// Test capture parameter that does not exist. def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> { let parameters = (ins "int":$v0); // CHECK: InvalidTypeF has no parameter named 'v1' let assemblyFormat = "`<` $v0 $v1 `>`"; } -/// Test duplicate capture of parameter in capture-all struct. +// Test duplicate capture of parameter in capture-all struct. def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> { let parameters = (ins "int":$v0, "int":$v1, "int":$v2); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` struct(params) $v0 `>`"; } -/// Test capture-all struct duplicate capture. +// Test capture-all struct duplicate capture. def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> { let parameters = (ins "int":$v0, "int":$v1, "int":$v2); // CHECK: `params` captures duplicate parameter: v0 let assemblyFormat = "`<` $v0 struct(params) `>`"; } -/// Test capture of parameter after `params` directive. +// Test capture of parameter after `params` directive. def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> { let parameters = (ins "int":$v0); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` params $v0 `>`"; } + +// Test `struct` with optional parameter followed by comma. +def InvalidTypeJ : InvalidType<"InvalidTypeJ", "invalid_j"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: directive with optional parameters cannot be followed by a comma literal + let assemblyFormat = "struct($a) `,` $b"; +} + +// Test `struct` in optional group must have all optional parameters. 
+def InvalidTypeK : InvalidType<"InvalidTypeK", "invalid_k"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: is only allowed in an optional group if all captured parameters are optional + let assemblyFormat = "(`(` struct(params)^ `)`)?"; +} + +// Test `params` in optional group must have all optional parameters. +def InvalidTypeL : InvalidType<"InvalidTypeL", "invalid_l"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: directive allowed in optional group only if all parameters are optional + let assemblyFormat = "(`(` params^ `)`)?"; +} + +def InvalidTypeM : InvalidType<"InvalidTypeM", "invalid_m"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: parameters in an optional group must be optional + let assemblyFormat = "(`(` $a^ `,` $b `)`)?"; +} + +def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> { + let parameters = (ins OptionalParameter<"int">:$a); + // CHECK: optional group anchor must be a parameter or directive + let assemblyFormat = "(`(` $a `)`^)?"; +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -20,4 +20,52 @@ // CHECK-LABEL: @test_roundtrip_default_parsers_struct // CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> // CHECK: !test.struct_capture_all -func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all +// CHECK: !test.optional_param<, 6> +// CHECK: !test.optional_param<5, 6> +// CHECK: !test.optional_params<"a"> +// CHECK: !test.optional_params<5, "a"> +// CHECK: !test.optional_struct +// CHECK: !test.optional_struct +// CHECK: !test.optional_params_after<"a"> +// CHECK: !test.optional_params_after<"a", 5> +// CHECK: !test.all_optional_params<> +// CHECK:
!test.all_optional_params<5> +// CHECK: !test.all_optional_params<5, 6> +// CHECK: !test.all_optional_struct<> +// CHECK: !test.all_optional_struct +// CHECK: !test.all_optional_struct +// CHECK: !test.optional_group<(5) 6> +// CHECK: !test.optional_group +// CHECK: !test.optional_group_params +// CHECK: !test.optional_group_params<(5)> +// CHECK: !test.optional_group_params<(5, 6)> +// CHECK: !test.optional_group_struct +// CHECK: !test.optional_group_struct<(b = 5)> +// CHECK: !test.optional_group_struct<(a = 10, b = 5)> +func private @test_roundtrip_default_parsers_struct( + !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> +) -> ( + !test.struct_capture_all, + !test.optional_param<, 6>, + !test.optional_param<5, 6>, + !test.optional_params<"a">, + !test.optional_params<5, "a">, + !test.optional_struct, + !test.optional_struct, + !test.optional_params_after<"a">, + !test.optional_params_after<"a", 5>, + !test.all_optional_params<>, + !test.all_optional_params<5>, + !test.all_optional_params<5, 6>, + !test.all_optional_struct<>, + !test.all_optional_struct, + !test.all_optional_struct, + !test.optional_group<(5) 6>, + !test.optional_group, + !test.optional_group_params, + !test.optional_group_params<(5)>, + !test.optional_group_params<(5, 6)>, + !test.optional_group_struct, + !test.optional_group_struct<(b = 5)>, + !test.optional_group_struct<(b = 5, a = 10)> +) diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -35,38 +35,38 @@ /// Check simple attribute parser and printer are generated correctly. 
-// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type type) { +// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &odsParser, +// ATTR: ::mlir::Type odsType) { // ATTR: FailureOr _result_value; // ATTR: FailureOr _result_complex; -// ATTR: if (parser.parseKeyword("hello")) +// ATTR: if (odsParser.parseKeyword("hello")) // ATTR: return {}; -// ATTR: if (parser.parseEqual()) +// ATTR: if (odsParser.parseEqual()) // ATTR: return {}; -// ATTR: _result_value = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_value)) +// ATTR: _result_value = ::mlir::FieldParser::parse(odsParser); +// ATTR: if (::mlir::failed(_result_value)) // ATTR: return {}; -// ATTR: if (parser.parseComma()) +// ATTR: if (odsParser.parseComma()) // ATTR: return {}; -// ATTR: _result_complex = ::parseAttrParamA(parser, type); -// ATTR: if (failed(_result_complex)) +// ATTR: _result_complex = ::parseAttrParamA(odsParser, odsType); +// ATTR: if (::mlir::failed(_result_complex)) // ATTR: return {}; -// ATTR: if (parser.parseRParen()) +// ATTR: if (odsParser.parseRParen()) // ATTR: return {}; -// ATTR: return TestAAttr::get(parser.getContext(), -// ATTR: _result_value.getValue(), -// ATTR: _result_complex.getValue()); +// ATTR: return TestAAttr::get(odsParser.getContext(), +// ATTR: *_result_value, +// ATTR: *_result_complex); // ATTR: } -// ATTR: void TestAAttr::print(::mlir::AsmPrinter &printer) const { -// ATTR: printer << ' ' << "hello"; -// ATTR: printer << ' ' << "="; -// ATTR: printer << ' '; -// ATTR: printer.printStrippedAttrOrType(getValue()); -// ATTR: printer << ","; -// ATTR: printer << ' '; -// ATTR: ::printAttrParamA(printer, getComplex()); -// ATTR: printer << ")"; +// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const { +// ATTR: odsPrinter << ' ' << "hello"; +// ATTR: odsPrinter << ' ' << "="; +// ATTR: odsPrinter << ' '; +// ATTR: odsPrinter.printStrippedAttrOrType(getValue()); +// ATTR: odsPrinter << 
","; +// ATTR: odsPrinter << ' '; +// ATTR: ::printAttrParamA(odsPrinter, getComplex()); +// ATTR: odsPrinter << ")"; // ATTR: } def AttrA : TestAttr<"TestA"> { @@ -81,47 +81,48 @@ /// Test simple struct parser and printer are generated correctly. -// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type type) { +// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &odsParser, +// ATTR: ::mlir::Type odsType) { // ATTR: bool _seen_v0 = false; // ATTR: bool _seen_v1 = false; -// ATTR: for (unsigned _index = 0; _index < 2; ++_index) { -// ATTR: StringRef _paramKey; -// ATTR: if (parser.parseKeyword(&_paramKey)) -// ATTR: return {}; -// ATTR: if (parser.parseEqual()) +// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { +// ATTR: if (odsParser.parseEqual()) // ATTR: return {}; // ATTR: if (!_seen_v0 && _paramKey == "v0") { // ATTR: _seen_v0 = true; -// ATTR: _result_v0 = ::parseAttrParamA(parser, type); -// ATTR: if (failed(_result_v0)) +// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType); +// ATTR: if (::mlir::failed(_result_v0)) // ATTR: return {}; // ATTR: } else if (!_seen_v1 && _paramKey == "v1") { // ATTR: _seen_v1 = true; -// ATTR: _result_v1 = type ? ::parseAttrWithType(parser, type) : ::parseAttrWithout(parser); -// ATTR: if (failed(_result_v1)) +// ATTR: _result_v1 = odsType ? 
::parseAttrWithType(odsParser, odsType) : +// ATTR-SAME: ::parseAttrWithout(odsParser); +// ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; // ATTR: } else { // ATTR: return {}; // ATTR: } -// ATTR: if ((_index != 2 - 1) && parser.parseComma()) +// ATTR: return true; +// ATTR: } +// ATTR: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) { +// ATTR: StringRef _paramKey; +// ATTR: if (odsParser.parseKeyword(&_paramKey)) +// ATTR: return {}; +// ATTR: if (!_loop_body(_paramKey)) return {}; +// ATTR: if ((odsStructIndex != 2 - 1) && odsParser.parseComma()) // ATTR: return {}; // ATTR: } -// ATTR: return TestBAttr::get(parser.getContext(), -// ATTR: _result_v0.getValue(), -// ATTR: _result_v1.getValue()); +// ATTR: return TestBAttr::get(odsParser.getContext(), +// ATTR: *_result_v0, +// ATTR: *_result_v1); // ATTR: } -// ATTR: void TestBAttr::print(::mlir::AsmPrinter &printer) const { -// ATTR: printer << "v0"; -// ATTR: printer << ' ' << "="; -// ATTR: printer << ' '; -// ATTR: ::printAttrParamA(printer, getV0()); -// ATTR: printer << ","; -// ATTR: printer << ' ' << "v1"; -// ATTR: printer << ' ' << "="; -// ATTR: printer << ' '; -// ATTR: ::printAttrB(printer, getV1()); +// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const { +// ATTR: odsPrinter << "v0 = "; +// ATTR: ::printAttrParamA(odsPrinter, getV0()); +// ATTR: odsPrinter << ", "; +// ATTR: odsPrinter << "v1 = "; +// ATTR: ::printAttrB(odsPrinter, getV1()); // ATTR: } def AttrB : TestAttr<"TestB"> { @@ -136,29 +137,21 @@ /// Test attribute with capture-all params has correct parser and printer. 
-// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type type) { +// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &odsParser, +// ATTR: ::mlir::Type odsType) { // ATTR: ::mlir::FailureOr _result_v0; // ATTR: ::mlir::FailureOr _result_v1; -// ATTR: _result_v0 = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_v0)) +// ATTR: _result_v0 = ::mlir::FieldParser::parse(odsParser); +// ATTR: if (::mlir::failed(_result_v0)) // ATTR: return {}; -// ATTR: if (parser.parseComma()) +// ATTR: if (odsParser.parseComma()) // ATTR: return {}; -// ATTR: _result_v1 = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_v1)) +// ATTR: _result_v1 = ::mlir::FieldParser::parse(odsParser); +// ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; -// ATTR: return TestFAttr::get(parser.getContext(), -// ATTR: _result_v0.getValue(), -// ATTR: _result_v1.getValue()); -// ATTR: } - -// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const { -// ATTR: printer << ' '; -// ATTR: printer.printStrippedAttrOrType(getV0()); -// ATTR: printer << ","; -// ATTR: printer << ' '; -// ATTR: printer.printStrippedAttrOrType(getV1()); +// ATTR: return TestFAttr::get(odsParser.getContext(), +// ATTR: *_result_v0, +// ATTR: *_result_v1); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -171,55 +164,57 @@ /// Test type parser and printer that mix variables and struct are generated /// correctly. 
-// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &parser) { +// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &odsParser) { // TYPE: FailureOr _result_value; // TYPE: FailureOr _result_complex; -// TYPE: if (parser.parseKeyword("foo")) +// TYPE: if (odsParser.parseKeyword("foo")) // TYPE: return {}; -// TYPE: if (parser.parseComma()) +// TYPE: if (odsParser.parseComma()) // TYPE: return {}; -// TYPE: if (parser.parseColon()) +// TYPE: if (odsParser.parseColon()) // TYPE: return {}; -// TYPE: if (parser.parseKeyword("bob")) +// TYPE: if (odsParser.parseKeyword("bob")) // TYPE: return {}; -// TYPE: if (parser.parseKeyword("bar")) +// TYPE: if (odsParser.parseKeyword("bar")) // TYPE: return {}; -// TYPE: _result_value = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_value)) +// TYPE: _result_value = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_value)) // TYPE: return {}; // TYPE: bool _seen_complex = false; -// TYPE: for (unsigned _index = 0; _index < 1; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { // TYPE: if (!_seen_complex && _paramKey == "complex") { // TYPE: _seen_complex = true; -// TYPE: _result_complex = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_complex)) +// TYPE: _result_complex = ::parseTypeParamC(odsParser); +// TYPE: if (::mlir::failed(_result_complex)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } -// TYPE: if ((_index != 1 - 1) && parser.parseComma()) +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 1; ++odsStructIndex) { +// TYPE: StringRef _paramKey; +// TYPE: if (odsParser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: if ((odsStructIndex != 1 - 1) && odsParser.parseComma()) // TYPE: return {}; // TYPE: } 
-// TYPE: if (parser.parseRParen()) +// TYPE: if (odsParser.parseRParen()) // TYPE: return {}; // TYPE: } -// TYPE: void TestCType::print(::mlir::AsmPrinter &printer) const { -// TYPE: printer << ' ' << "foo"; -// TYPE: printer << ","; -// TYPE: printer << ' ' << ":"; -// TYPE: printer << ' ' << "bob"; -// TYPE: printer << ' ' << "bar"; -// TYPE: printer << ' '; -// TYPE: printer.printStrippedAttrOrType(getValue()); -// TYPE: printer << ' ' << "complex"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer << getComplex(); -// TYPE: printer << ")"; +// TYPE: void TestCType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: odsPrinter << ' ' << "foo"; +// TYPE: odsPrinter << ","; +// TYPE: odsPrinter << ' ' << ":"; +// TYPE: odsPrinter << ' ' << "bob"; +// TYPE: odsPrinter << ' ' << "bar"; +// TYPE: odsPrinter << ' '; +// TYPE: odsPrinter.printStrippedAttrOrType(getValue()); +// TYPE: odsPrinter << "complex = "; +// TYPE: odsPrinter << getComplex(); +// TYPE: odsPrinter << ")"; // TYPE: } def TypeA : TestType<"TestC"> { @@ -235,51 +230,53 @@ /// Test type parser and printer with mix of variables and struct are generated /// correctly. 
-// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &parser) { -// TYPE: _result_v0 = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_v0)) +// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &odsParser) { +// TYPE: _result_v0 = ::parseTypeParamC(odsParser); +// TYPE: if (::mlir::failed(_result_v0)) // TYPE: return {}; // TYPE: bool _seen_v1 = false; // TYPE: bool _seen_v2 = false; -// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; -// TYPE: if (parser.parseEqual()) +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { +// TYPE: if (odsParser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v1 && _paramKey == "v1") { // TYPE: _seen_v1 = true; // TYPE: _result_v1 = someFcnCall(); -// TYPE: if (failed(_result_v1)) +// TYPE: if (::mlir::failed(_result_v1)) // TYPE: return {}; // TYPE: } else if (!_seen_v2 && _paramKey == "v2") { // TYPE: _seen_v2 = true; -// TYPE: _result_v2 = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_v2)) +// TYPE: _result_v2 = ::parseTypeParamC(odsParser); +// TYPE: if (::mlir::failed(_result_v2)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } -// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) { +// TYPE: StringRef _paramKey; +// TYPE: if (odsParser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma()) // TYPE: return {}; // TYPE: } // TYPE: _result_v3 = someFcnCall(); -// TYPE: if (failed(_result_v3)) +// TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; -// TYPE: return TestDType::get(parser.getContext(), -// TYPE: _result_v0.getValue(), -// TYPE: _result_v1.getValue(), -// TYPE: _result_v2.getValue(), -// TYPE: 
_result_v3.getValue()); +// TYPE: return TestDType::get(odsParser.getContext(), +// TYPE: *_result_v0, +// TYPE: *_result_v1, +// TYPE: *_result_v2, +// TYPE: *_result_v3); // TYPE: } -// TYPE: void TestDType::print(::mlir::AsmPrinter &printer) const { -// TYPE: printer << getV0(); +// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: odsPrinter << getV0(); // TYPE: myPrinter(getV1()); -// TYPE: printer << ' ' << "v2"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer << getV2(); +// TYPE: odsPrinter << "v2 = "; +// TYPE: odsPrinter << getV2(); // TYPE: myPrinter(getV3()); // TYPE: } @@ -298,85 +295,86 @@ /// Type test with two struct directives has correctly generated parser and /// printer. -// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &parser) { +// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &odsParser) { // TYPE: FailureOr _result_v0; // TYPE: FailureOr _result_v1; // TYPE: FailureOr _result_v2; // TYPE: FailureOr _result_v3; // TYPE: bool _seen_v0 = false; // TYPE: bool _seen_v2 = false; -// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; -// TYPE: if (parser.parseEqual()) +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { +// TYPE: if (odsParser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v0 && _paramKey == "v0") { // TYPE: _seen_v0 = true; -// TYPE: _result_v0 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v0)) +// TYPE: _result_v0 = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_v0)) // TYPE: return {}; // TYPE: } else if (!_seen_v2 && _paramKey == "v2") { // TYPE: _seen_v2 = true; -// TYPE: _result_v2 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v2)) +// TYPE: _result_v2 = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_v2)) // TYPE: return 
{}; // TYPE: } else { // TYPE: return {}; // TYPE: } -// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) { +// TYPE: StringRef _paramKey; +// TYPE: if (odsParser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma()) // TYPE: return {}; // TYPE: } // TYPE: bool _seen_v1 = false; // TYPE: bool _seen_v3 = false; -// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; -// TYPE: if (parser.parseEqual()) +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { +// TYPE: if (odsParser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v1 && _paramKey == "v1") { // TYPE: _seen_v1 = true; -// TYPE: _result_v1 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v1)) +// TYPE: _result_v1 = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_v1)) // TYPE: return {}; // TYPE: } else if (!_seen_v3 && _paramKey == "v3") { // TYPE: _seen_v3 = true; -// TYPE: _result_v3 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v3)) +// TYPE: _result_v3 = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } -// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) { +// TYPE: StringRef _paramKey; +// TYPE: if (odsParser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma()) // TYPE: return {}; // TYPE: } -// TYPE: return TestEType::get(parser.getContext(), -// TYPE: 
_result_v0.getValue(), -// TYPE: _result_v1.getValue(), -// TYPE: _result_v2.getValue(), -// TYPE: _result_v3.getValue()); +// TYPE: return TestEType::get(odsParser.getContext(), +// TYPE: *_result_v0, +// TYPE: *_result_v1, +// TYPE: *_result_v2, +// TYPE: *_result_v3); // TYPE: } -// TYPE: void TestEType::print(::mlir::AsmPrinter &printer) const { -// TYPE: printer << "v0"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer.printStrippedAttrOrType(getV0()); -// TYPE: printer << ","; -// TYPE: printer << ' ' << "v2"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer.printStrippedAttrOrType(getV2()); -// TYPE: printer << "v1"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer.printStrippedAttrOrType(getV1()); -// TYPE: printer << ","; -// TYPE: printer << ' ' << "v3"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; -// TYPE: printer.printStrippedAttrOrType(getV3()); +// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: odsPrinter << "v0 = "; +// TYPE: odsPrinter.printStrippedAttrOrType(getV0()); +// TYPE: odsPrinter << ", "; +// TYPE: odsPrinter << "v2 = "; +// TYPE: odsPrinter.printStrippedAttrOrType(getV2()); +// TYPE: odsPrinter << ", "; +// TYPE: odsPrinter << "v1 = "; +// TYPE: odsPrinter.printStrippedAttrOrType(getV1()); +// TYPE: odsPrinter << ", "; +// TYPE: odsPrinter << "v3 = "; +// TYPE: odsPrinter.printStrippedAttrOrType(getV3()); // TYPE: } def TypeC : TestType<"TestE"> { @@ -390,3 +388,99 @@ let mnemonic = "type_e"; let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`"; } + +// TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE if (getA()) { +// TYPE printer << ' '; +// TYPE printer.printStrippedAttrOrType(getA()); +def TypeD : TestType<"TestF"> { + let parameters = (ins OptionalParameter<"int">:$a); + let mnemonic = "type_f"; + let assemblyFormat = "$a"; +} + +// TYPE: ::mlir::Type 
TestGType::parse(::mlir::AsmParser &odsParser) { +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; +// TYPE: if (::mlir::succeeded(_result_a) && *_result_a) +// TYPE: if (odsParser.parseComma()) +// TYPE: return {}; + +// TYPE: if (getA()) +// TYPE: odsPrinter.printStrippedAttrOrType(getA()); +// TYPE: odsPrinter << ", "; +// TYPE: odsPrinter.printStrippedAttrOrType(getB()); + +def TypeE : TestType<"TestG"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + let mnemonic = "type_g"; + let assemblyFormat = "params"; +} + + +// TYPE: ::mlir::Type TestHType::parse(::mlir::AsmParser &odsParser) { +// TYPE: do { +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: } while(!odsParser.parseOptionalComma()); +// TYPE: if (!_seen_b) +// TYPE: return {}; + +// TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: if (getA()) { +// TYPE: odsPrinter << "a = "; +// TYPE: odsPrinter.printStrippedAttrOrType(getA()); +// TYPE: odsPrinter << ", "; +// TYPE: } + +def TypeF : TestType<"TestH"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + let mnemonic = "type_h"; + let assemblyFormat = "struct(params)"; +} + + +// TYPE: do { +// TYPE: _result_a = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; +// TYPE: if (odsParser.parseOptionalComma()) break; +// TYPE: _result_b = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_b)) +// TYPE: return {}; +// TYPE: } while(false); + +def TypeG : TestType<"TestI"> { + let parameters = (ins "int":$a, OptionalParameter<"int">:$b); + let mnemonic = "type_i"; + let assemblyFormat = "params"; +} + +// TYPE: ::mlir::Type TestJType::parse(::mlir::AsmParser &odsParser) { +// TYPE: if (odsParser.parseOptionalLParen()) { +// TYPE: if (odsParser.parseKeyword("x")) return {}; +// TYPE: } else { +// TYPE: _result_b = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_b)) +// TYPE: 
return {}; +// TYPE: if (odsParser.parseRParen()) return {}; +// TYPE: } +// TYPE: _result_a = ::mlir::FieldParser::parse(odsParser); +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; + +// TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const { +// TYPE: if (getB()) { +// TYPE: odsPrinter << "("; +// TYPE: if (getB()) +// TYPE: odsPrinter.printStrippedAttrOrType(getB()); +// TYPE: odsPrinter << ")"; +// TYPE: } else { +// TYPE: odsPrinter << ' ' << "x"; +// TYPE: } +// TYPE: odsPrinter.printStrippedAttrOrType(getA()); + +def TypeH : TestType<"TestJ"> { + let parameters = (ins "int":$a, OptionalParameter<"int">:$b); + let mnemonic = "type_j"; + let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a"; +} diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -67,8 +67,8 @@ // DECL: return {"cmpnd_a"}; // DECL: } // DECL: static ::mlir::Attribute parse( -// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type); -// DECL: void print(::mlir::AsmPrinter &printer) const; +// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType); +// DECL: void print(::mlir::AsmPrinter &odsPrinter) const; // DECL: int getWidthOfSomething() const; // DECL: ::test::SimpleTypeA getExampleTdType() const; // DECL: ::llvm::APFloat getApFloat() const; @@ -107,8 +107,8 @@ // DECL: return {"index"}; // DECL: } // DECL: static ::mlir::Attribute parse( -// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type); -// DECL: void print(::mlir::AsmPrinter &printer) const; +// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType); +// DECL: void print(::mlir::AsmPrinter &odsPrinter) const; } def D_SingleParameterAttr : TestAttr<"SingleParameter"> { diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -70,8 +70,8 @@ // DECL: static constexpr 
::llvm::StringLiteral getMnemonic() { // DECL: return {"cmpnd_a"}; // DECL: } -// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser); -// DECL: void print(::mlir::AsmPrinter &printer) const; +// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser); +// DECL: void print(::mlir::AsmPrinter &odsPrinter) const; // DECL: int getWidthOfSomething() const; // DECL: ::test::SimpleTypeA getExampleTdType() const; // DECL: SomeCppStruct getExampleCppType() const; @@ -89,8 +89,8 @@ // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { // DECL: return {"index"}; // DECL: } -// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser); -// DECL: void print(::mlir::AsmPrinter &printer) const; +// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser); +// DECL: void print(::mlir::AsmPrinter &odsPrinter) const; } def D_SingleParameterType : TestType<"SingleParameter"> { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -266,9 +266,9 @@ // Declare the parser. SmallVector parserParams; - parserParams.emplace_back("::mlir::AsmParser &", "parser"); + parserParams.emplace_back("::mlir::AsmParser &", "odsParser"); if (isa(&def)) - parserParams.emplace_back("::mlir::Type", "type"); + parserParams.emplace_back("::mlir::Type", "odsType"); auto *parser = defCls.addMethod( strfmt("::mlir::{0}", valueType), "parse", def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration, @@ -278,7 +278,7 @@ def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration; Method *printer = defCls.addMethod("void", "print", props, - MethodParameter("::mlir::AsmPrinter &", "printer")); + MethodParameter("::mlir::AsmPrinter &", "odsPrinter")); // Emit the bodies. 
emitParserPrinterBody(parser->body(), printer->body()); } @@ -431,14 +431,15 @@ if (asmFormat) return generateAttrOrTypeFormat(def, parser, printer); - FmtContext ctx = FmtContext( - {{"_parser", "parser"}, {"_printer", "printer"}, {"_type", "type"}}); + FmtContext ctx = FmtContext({{"_parser", "odsParser"}, + {"_printer", "odsPrinter"}, + {"_type", "odsType"}}); if (parserCode) { - ctx.addSubst("_ctxt", "parser.getContext()"); + ctx.addSubst("_ctxt", "odsParser.getContext()"); parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str()); } if (printerCode) { - ctx.addSubst("_ctxt", "printer.getContext()"); + ctx.addSubst("_ctxt", "odsPrinter.getContext()"); printer.indent().getStream().printReindented( tgfmt(*printerCode, &ctx).str()); } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -16,7 +16,9 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -48,36 +50,47 @@ shouldBeQualifiedFlag = qualified; } + /// Returns true if the element contains an optional parameter. + bool isOptional() const { return param.isOptional(); } + + /// Returns the name of the parameter. + StringRef getName() const { return param.getName(); } + private: bool shouldBeQualifiedFlag = false; AttrOrTypeParameter param; }; +/// Shorthand functions that can be used with range-based conditions. +static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } +static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } + /// Base class for a directive that contains references to multiple variables.
template class ParamsDirectiveBase : public DirectiveElementBase { public: using Base = ParamsDirectiveBase; - ParamsDirectiveBase(std::vector &¶ms) + ParamsDirectiveBase(std::vector &¶ms) : params(std::move(params)) {} /// Get the parameters contained in this directive. - auto getParams() const { - return llvm::map_range(params, [](FormatElement *el) { - return cast(el)->getParam(); - }); - } + ArrayRef getParams() const { return params; } /// Get the number of parameters. unsigned getNumParams() const { return params.size(); } /// Take all of the parameters from this directive. - std::vector takeParams() { return std::move(params); } + std::vector takeParams() { return std::move(params); } + + /// Returns true if there are optional parameters present. + bool hasOptionalParams() const { + return llvm::any_of(getParams(), paramIsOptional); + } private: /// The parameters captured by this directive. - std::vector params; + std::vector params; }; /// This class represents a `params` directive that refers to all parameters @@ -125,36 +138,9 @@ /// Print an error when failing to parse an element. /// /// $0: The parameter C++ class name. -static const char *const parseErrorStr = +static const char *const parserErrorStr = "$_parser.emitError($_parser.getCurrentLocation(), "; -/// Loop declaration for struct parser. -/// -/// $0: Number of expected parameters. -static const char *const structParseLoopStart = R"( - for (unsigned _index = 0; _index < $0; ++_index) { - StringRef _paramKey; - if ($_parser.parseKeyword(&_paramKey)) { - $_parser.emitError($_parser.getCurrentLocation(), - "expected a parameter name in struct"); - return {}; - } -)"; - -/// Terminator code segment for the struct parser loop. Check for duplicate or -/// unknown parameters. Parse a comma except on the last element. -/// -/// {0}: Code template for printing an error. -/// {1}: Number of elements in the struct. 
-static const char *const structParseLoopEnd = R"({{ - {0}"duplicate or unknown struct parameter name: ") << _paramKey; - return {{}; - } - if ((_index != {1} - 1) && parser.parseComma()) - return {{}; -} -)"; - /// Code format to parse a variable. Separate by lines because variable parsers /// may be generated inside other directives, which requires indentation. /// @@ -166,21 +152,20 @@ static const char *const variableParser = R"( // Parse variable '{0}' _result_{0} = {1}; -if (failed(_result_{0})) {{ +if (::mlir::failed(_result_{0})) {{ {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`"); return {{}; } )"; //===----------------------------------------------------------------------===// -// AttrOrTypeFormat +// DefFormat //===----------------------------------------------------------------------===// namespace { -class AttrOrTypeFormat { +class DefFormat { public: - AttrOrTypeFormat(const AttrOrTypeDef &def, - std::vector &&elements) + DefFormat(const AttrOrTypeDef &def, std::vector &&elements) : def(def), elements(std::move(elements)) {} /// Generate the attribute or type parser. @@ -192,26 +177,36 @@ /// Generate the parser code for a specific format element. void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a literal. - void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os); + void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os, + bool isOptional = false); /// Generate the parser code for a variable. - void genVariableParser(const AttrOrTypeParameter ¶m, FmtContext &ctx, - MethodBody &os); + void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `params` directive. void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `struct` directive. 
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the parser code for an optional group. + void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, + MethodBody &os); /// Generate the printer code for a specific format element. void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a literal. void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. - void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx, - MethodBody &os, bool printQualified = false); + void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, + bool skipGuard = false); + /// Generate a printer for comma-separated parameters. + void genCommaSeparatedPrinter(ArrayRef params, + FmtContext &ctx, MethodBody &os, + function_ref extra); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the printer code for an optional group. + void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, + MethodBody &os); /// The ODS definition of the attribute or type whose format is being used to /// generate a parser and printer. @@ -230,35 +225,44 @@ // ParserGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genParser(MethodBody &os) { +void DefFormat::genParser(MethodBody &os) { FmtContext ctx; - ctx.addSubst("_parser", "parser"); + ctx.addSubst("_parser", "odsParser"); if (isa(def)) - ctx.addSubst("_type", "type"); + ctx.addSubst("_type", "odsType"); os.indent(); - /// Declare variables to store all of the parameters. Allocated parameters - /// such as `ArrayRef` and `StringRef` must provide a `storageType`. 
Store - /// FailureOr to defer type construction for parameters that are parsed in - /// a loop (parsers return FailureOr anyways). + // Declare variables to store all of the parameters. Allocated parameters + // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store + // FailureOr to defer type construction for parameters that are parsed in + // a loop (parsers return FailureOr anyway). ArrayRef params = def.getParameters(); for (const AttrOrTypeParameter &param : params) { - os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n", + os << formatv("::mlir::FailureOr<{0}> _result_{1};\n", param.getCppStorageType(), param.getName()); } - /// Store the initial location of the parser. - ctx.addSubst("_loc", "loc"); + // Store the initial location of the parser. + ctx.addSubst("_loc", "odsLoc"); os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" "(void) $_loc;\n", &ctx); - /// Generate call to each parameter parser. + // Generate call to each parameter parser. for (FormatElement *el : elements) genElementParser(el, ctx, os); - /// Generate call to the attribute or type builder. Use the checked getter - /// if one was generated. + // Emit an assert for each mandatory parameter. Triggering an assert means + // the generated parser is incorrect (i.e. there is a bug in this code). + for (const AttrOrTypeParameter &param : params) { + if (!param.isOptional()) { + os << formatv("assert(::mlir::succeeded(_result_{0}));\n", + param.getName()); + } + } + + // Generate call to the attribute or type builder. Use the checked getter + // if one was generated.
if (def.genVerifyDecl()) { os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); @@ -266,29 +270,38 @@ os << tgfmt("return $0::get($_parser.getContext()", &ctx, def.getCppClassName()); } - for (const AttrOrTypeParameter &param : params) - os << formatv(",\n _result_{0}.getValue()", param.getName()); + for (const AttrOrTypeParameter &param : params) { + if (param.isOptional()) + os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(), + param.getCppStorageType()); + else + os << formatv(",\n *_result_{0}", param.getName()); + } os << ");"; } -void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx, + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralParser(literal->getSpelling(), ctx, os); if (auto *var = dyn_cast(el)) - return genVariableParser(var->getParam(), ctx, os); + return genVariableParser(var, ctx, os); if (auto *params = dyn_cast(el)) return genParamsParser(params, ctx, os); if (auto *strct = dyn_cast(el)) return genStructParser(strct, ctx, os); + if (auto *optional = dyn_cast(el)) + return genOptionalGroupParser(optional, ctx, os); llvm_unreachable("unknown format element"); } -void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, + MethodBody &os, bool isOptional) { os << "// Parse literal '" << value << "'\n"; os << tgfmt("if ($_parser.parse", &ctx); + if (isOptional) + os << "Optional"; if (value.front() == '_' || isalpha(value.front())) { os << "Keyword(\"" << value << "\")"; } else { @@ -310,70 +323,275 @@ .Case("*", "Star") << "()"; } - os << ")\n"; + if (isOptional) { + // Leave the `if` unclosed to guard optional groups.
+ return; + } // Parser will emit an error - os << " return {};\n"; + os << ") return {};\n"; } -void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param, - FmtContext &ctx, MethodBody &os) { - /// Check for a custom parser. Use the default attribute parser otherwise. +void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx, + MethodBody &os) { + // Check for a custom parser. Use the default attribute parser otherwise. + const AttrOrTypeParameter &param = el->getParam(); auto customParser = param.getParser(); auto parser = customParser ? *customParser : StringRef(defaultParameterParser); os << formatv(variableParser, param.getName(), tgfmt(parser, &ctx, param.getCppStorageType()), - tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType()); + tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType()); } -void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, + MethodBody &os) { os << "// Parse parameter list\n"; - llvm::interleave( - el->getParams(), - [&](auto param) { this->genVariableParser(param, ctx, os); }, - [&]() { this->genLiteralParser(",", ctx, os); }); + + // If there are optional parameters, we need to switch to `parseOptionalComma` + // if there are no more required parameters after a certain point. + bool hasOptional = el->hasOptionalParams(); + if (hasOptional) { + // Wrap everything in a do-while so that we can `break`. + os << "do {\n"; + os.indent(); + } + + ArrayRef params = el->getParams(); + using IteratorT = ParameterElement *const *; + IteratorT it = params.begin(); + + // Find the last required parameter. Commas become optional afterwards. + // Note: IteratorT's copy assignment is deleted. + ParameterElement *lastReq = nullptr; + for (ParameterElement *param : params) + if (!param->isOptional()) + lastReq = param; + IteratorT lastReqIt = lastReq ?
llvm::find(params, lastReq) : params.begin(); + + auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); }; + auto betweenFn = [&](IteratorT it) { + ParameterElement *el = *std::prev(it); + // Parse a comma if the last optional parameter had a value. + if (el->isOptional()) { + os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n", + el->getName()); + os.indent(); + } + if (it <= lastReqIt) { + genLiteralParser(",", ctx, os); + } else { + genLiteralParser(",", ctx, os, /*isOptional=*/true); + os << ") break;\n"; + } + if (el->isOptional()) + os.unindent() << "}\n"; + }; + + // llvm::interleave + if (it != params.end()) { + eachFn(*it++); + for (IteratorT e = params.end(); it != e; ++it) { + betweenFn(it); + eachFn(*it); + } + } + + if (hasOptional) + os.unindent() << "} while(false);\n"; } -void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, + MethodBody &os) { + // Loop declaration for struct parser with only required parameters. + // + // $0: Number of expected parameters. + const char *const loopHeader = R"( + for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) { +)"; + + // Loop body start for struct parser. + const char *const loopStart = R"( + ::llvm::StringRef _paramKey; + if ($_parser.parseKeyword(&_paramKey)) { + $_parser.emitError($_parser.getCurrentLocation(), + "expected a parameter name in struct"); + return {}; + } + if (!_loop_body(_paramKey)) return {}; +)"; + + // Struct parser loop end. Check for duplicate or unknown struct parameters. + // + // {0}: Code template for printing an error. + const char *const loopEnd = R"({{ + {0}"duplicate or unknown struct parameter name: ") << _paramKey; + return {{}; +} +)"; + + // Struct parser loop terminator. Parse a comma except on the last element. + // + // {0}: Number of elements in the struct. 
+ const char *const loopTerminator = R"( + if ((odsStructIndex != {0} - 1) && odsParser.parseComma()) + return {{}; +} +)"; + + // Check that a mandatory parameter was parsed. + // + // {0}: Name of the parameter. + // {1}: Code template for printing an error. + const char *const checkParam = R"( + if (!_seen_{0}) {{ + {1}"struct is missing required parameter: ") << "{0}"; + return {{}; + } +)"; + + // Optional parameters in a struct must be parsed successfully if the + // keyword is present. + // + // {0}: Name of the parameter. + // {1}: Code template for printing an error. + const char *const checkOptionalParam = R"( + if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{ + {1}"expected a value for parameter '{0}'"); + return {{}; + } +)"; + + // First iteration of the loop parsing an optional struct. + const char *const optionalStructFirst = R"( + ::llvm::StringRef _paramKey; + if (!$_parser.parseOptionalKeyword(&_paramKey)) { + if (!_loop_body(_paramKey)) return {}; + while (!$_parser.parseOptionalComma()) { +)"; + os << "// Parse parameter struct\n"; - /// Declare a "seen" variable for each key. - for (const AttrOrTypeParameter &param : el->getParams()) - os << formatv("bool _seen_{0} = false;\n", param.getName()); + // Declare a "seen" variable for each key. + for (ParameterElement *param : el->getParams()) + os << formatv("bool _seen_{0} = false;\n", param->getName()); - /// Generate the parsing loop. - os.getStream().printReindented( - tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str()); - os.indent(); - genLiteralParser("=", ctx, os); - for (const AttrOrTypeParameter &param : el->getParams()) { + // Generate the body of the parsing loop inside a lambda.
+ os << "{\n"; + os.indent() + << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; + genLiteralParser("=", ctx, os.indent()); + for (ParameterElement *param : el->getParams()) { os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" " _seen_{0} = true;\n", - param.getName()); + param->getName()); genVariableParser(param, ctx, os.indent()); + if (param->isOptional()) { + os.getStream().printReindented(strfmt(checkOptionalParam, + param->getName(), + tgfmt(parserErrorStr, &ctx).str())); + } os.unindent() << "} else "; + // Print the check for duplicate or unknown parameter. } + os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx))); + os << "return true;\n"; + os.unindent() << "};\n"; + + // Generate the parsing loop. If optional parameters are present, then the + // parse loop is guarded by commas. + unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional); + if (numOptional) { + // If the struct itself is optional, pull out the first iteration. + if (numOptional == el->getNumParams()) { + os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str()); + os.indent(); + } else { + os << "do {\n"; + } + } else { + os.getStream().printReindented( + tgfmt(loopHeader, &ctx, el->getNumParams()).str()); + } + os.indent(); + os.getStream().printReindented(tgfmt(loopStart, &ctx).str()); os.unindent(); - /// Duplicate or unknown parameter. - os.getStream().printReindented(strfmt( - structParseLoopEnd, tgfmt(parseErrorStr, &ctx), el->getNumParams())); + // Print the loop terminator. For optional parameters, we have to check that + // all mandatory parameters have been parsed. + // The whole struct is optional if all its parameters are optional. 
+ if (numOptional) { + if (numOptional == el->getNumParams()) { + os << "}\n"; + os.unindent() << "}\n"; + } else { + os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx); + for (ParameterElement *param : el->getParams()) { + if (param->isOptional()) + continue; + os.getStream().printReindented( + strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx))); + } + } + } else { + // Because the loop runs N times and each non-failing iteration sets 1 of + // N flags, successfully exiting the loop means that all parameters have + // been seen. `parseOptionalComma` would cause issues with any formats that + // use "struct(...) `,`" because structs aren't surrounded by braces. + os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams())); + } + os.unindent() << "}\n"; +} + +void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, + MethodBody &os) { + ArrayRef elements = + el->getThenElements().drop_front(el->getParseStart()); + + FormatElement *first = elements.front(); + const auto guardOn = [&](auto params) { + os << "if (!("; + llvm::interleave( + params, os, + [&](ParameterElement *el) { + os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})", + el->getName()); + }, + " || "); + os << ")) {\n"; + }; + if (auto *literal = dyn_cast(first)) { + genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true); + os << ") {\n"; + } else if (auto *param = dyn_cast(first)) { + genVariableParser(param, ctx, os); + guardOn(llvm::makeArrayRef(param)); + } else if (auto *params = dyn_cast(first)) { + genParamsParser(params, ctx, os); + guardOn(params->getParams()); + } else { + auto *strct = cast(first); + genStructParser(strct, ctx, os); + guardOn(strct->getParams()); + } + os.indent(); - /// Because the loop loops N times and each non-failing iteration sets 1 of - /// N flags, successfully exiting the loop means that all parameters have been - /// seen.
`parseOptionalComma` would cause issues with any formats that use - /// "struct(...) `,`" beacuse structs aren't sounded by braces. + // Inside the guard the group is absent: parse the `else` elements. + for (FormatElement *element : el->getElseElements()) + genElementParser(element, ctx, os); + os.unindent() << "} else {\n"; + os.indent(); + // Otherwise the group is present: parse the rest of its elements. + for (FormatElement *element : elements.drop_front()) + genElementParser(element, ctx, os); + os.unindent() << "}\n"; } //===----------------------------------------------------------------------===// // PrinterGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genPrinter(MethodBody &os) { +void DefFormat::genPrinter(MethodBody &os) { FmtContext ctx; - ctx.addSubst("_printer", "printer"); + ctx.addSubst("_printer", "odsPrinter"); + os.indent(); /// Generate printers. shouldEmitSpace = true; @@ -382,8 +600,8 @@ genElementPrinter(el, ctx, os); } -void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralPrinter(literal->getSpelling(), ctx, os); if (auto *params = dyn_cast(el)) @@ -391,63 +609,147 @@ if (auto *strct = dyn_cast(el)) return genStructPrinter(strct, ctx, os); if (auto *var = dyn_cast(el)) - return genVariablePrinter(var->getParam(), ctx, os, - var->shouldBeQualified()); + return genVariablePrinter(var, ctx, os); + if (auto *optional = dyn_cast(el)) + return genOptionalGroupPrinter(optional, ctx, os); - llvm_unreachable("unknown format element"); + llvm::PrintFatalError("unsupported format element"); } -void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, - MethodBody &os) { - /// Don't insert a space before certain punctuation. +void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, + MethodBody &os) { + // Don't insert a space before certain punctuation.
bool needSpace = shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); - os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", + os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", value); - /// Update the flags. + // Update the flags. shouldEmitSpace = value.size() != 1 || !StringRef("<({[").contains(value.front()); lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); } -void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param, - FmtContext &ctx, MethodBody &os, - bool printQualified) { - /// Insert a space before the next parameter, if necessary. +void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, + MethodBody &os, bool skipGuard) { + const AttrOrTypeParameter &param = el->getParam(); + ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); + + // Guard the printer on the presence of optional parameters. + if (el->isOptional() && !skipGuard) { + os << tgfmt("if ($_self) {\n", &ctx); + os.indent(); + } + + // Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation) - os << tgfmt(" $_printer << ' ';\n", &ctx); + os << tgfmt("$_printer << ' ';\n", &ctx); shouldEmitSpace = true; lastWasPunctuation = false; - ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); - os << " "; - if (printQualified) + if (el->shouldBeQualified()) os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n"; else if (auto printer = param.getPrinter()) os << tgfmt(*printer, &ctx) << ";\n"; else os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; + + if (el->isOptional() && !skipGuard) + os.unindent() << "}\n"; } -void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, - MethodBody &os) { - llvm::interleave( - el->getParams(), - [&](auto param) { this->genVariablePrinter(param, ctx, os); }, - [&] { this->genLiteralPrinter(",", ctx, os); }); +void DefFormat::genCommaSeparatedPrinter( + ArrayRef params, FmtContext &ctx, MethodBody &os, + function_ref extra) { + // Emit a space if necessary, but only if at least one parameter is present. + if (shouldEmitSpace || !lastWasPunctuation) { + bool allOptional = llvm::all_of(params, paramIsOptional); + if (allOptional) { + os << "if ("; + llvm::interleave( + params, os, + [&](ParameterElement *param) { + os << getParameterAccessorName(param->getName()) << "()"; + }, + " || "); + os << ") {\n"; + os.indent(); + } + os << tgfmt("$_printer << ' ';\n", &ctx); + if (allOptional) + os.unindent() << "}\n"; + } + + // The first printed element does not need to emit a comma.
+ os << "{\n"; + os.indent() << "bool _firstPrinted = true;\n"; + for (ParameterElement *param : params) { + if (param->isOptional()) { + os << tgfmt("if ($_self()) {\n", + &ctx.withSelf(getParameterAccessorName(param->getName()))); + os.indent(); + } + os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx); + os << "_firstPrinted = false;\n"; + extra(param); + shouldEmitSpace = false; + lastWasPunctuation = true; + genVariablePrinter(param, ctx, os); + if (param->isOptional()) + os.unindent() << "}\n"; + } + os.unindent() << "}\n"; +} + +void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, + MethodBody &os) { + genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os, + [&](ParameterElement *param) {}); +} + +void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, + MethodBody &os) { + genCommaSeparatedPrinter( + llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) { + os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); + }); } -void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, +void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, MethodBody &os) { - llvm::interleave( - el->getParams(), - [&](auto param) { - this->genLiteralPrinter(param.getName(), ctx, os); - this->genLiteralPrinter("=", ctx, os); - this->genVariablePrinter(param, ctx, os); - }, - [&] { this->genLiteralPrinter(",", ctx, os); }); + // Emit the check on whether the group should be printed. 
+ const auto guardOn = [&](auto params) { + os << "if ("; + llvm::interleave( + params, os, + [&](ParameterElement *el) { + os << getParameterAccessorName(el->getName()) << "()"; + }, + " || "); + os << ") {\n"; + os.indent(); + }; + FormatElement *anchor = el->getAnchor(); + if (auto *param = dyn_cast(anchor)) { + guardOn(llvm::makeArrayRef(param)); + } else if (auto *params = dyn_cast(anchor)) { + guardOn(params->getParams()); + } else { + auto *strct = cast(anchor); + guardOn(strct->getParams()); + } + // Generate the printer for the contained elements. + { + llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace); + llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation); + for (FormatElement *element : el->getThenElements()) + genElementPrinter(element, ctx, os); + } + os.unindent() << "} else {\n"; + os.indent(); + for (FormatElement *element : el->getElseElements()) + genElementPrinter(element, ctx, os); + os.unindent() << "}\n"; } //===----------------------------------------------------------------------===// @@ -462,7 +764,7 @@ seenParams(def.getNumParameters()) {} /// Parse the attribute or type format and create the format elements. - FailureOr parse(); + FailureOr parse(); protected: /// Verify the parsed elements. @@ -476,9 +778,7 @@ /// Verify the elements of an optional group. LogicalResult verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, - Optional anchorIndex) override { - return emitError(loc, "optional groups not (yet) supported"); - } + Optional anchorIndex) override; /// Parse an attribute or type variable. FailureOr parseVariableImpl(SMLoc loc, StringRef name, @@ -505,30 +805,76 @@ LogicalResult DefFormatParser::verify(SMLoc loc, ArrayRef elements) { + // Check that all required (non-optional) parameters are referenced in the format.
for (auto &it : llvm::enumerate(def.getParameters())) { - if (!seenParams.test(it.index())) { + if (!it.value().isOptional() && !seenParams.test(it.index())) { return emitError(loc, "format is missing reference to parameter: " + it.value().getName()); } } + // A `struct` directive that contains optional parameters cannot be followed + // by a comma literal, which is ambiguous. + for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) { + auto *structEl = dyn_cast(std::get<0>(it)); + auto *literalEl = dyn_cast(std::get<1>(it)); + if (!structEl || !literalEl) + continue; + if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) { + return emitError(loc, "`struct` directive with optional parameters " + "cannot be followed by a comma literal"); + } + } + return success(); +} + +LogicalResult +DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, + ArrayRef elements, + Optional anchorIndex) { + // `params` and `struct` directives are allowed only if all the contained + // parameters are optional. + for (FormatElement *el : elements) { + if (auto *param = dyn_cast(el)) { + if (!param->isOptional()) { + return emitError(loc, + "parameters in an optional group must be optional"); + } + } else if (auto *params = dyn_cast(el)) { + if (llvm::any_of(params->getParams(), paramNotOptional)) { + return emitError(loc, "`params` directive allowed in optional group " + "only if all parameters are optional"); + } + } else if (auto *strct = dyn_cast(el)) { + if (llvm::any_of(strct->getParams(), paramNotOptional)) { + return emitError(loc, "`struct` is only allowed in an optional group " + "if all captured parameters are optional"); + } + } + } + // The anchor must be a parameter or one of the aforementioned directives. 
+ if (anchorIndex && !isa( + elements[*anchorIndex])) { + return emitError(loc, + "optional group anchor must be a parameter or directive"); + } return success(); } -FailureOr DefFormatParser::parse() { +FailureOr DefFormatParser::parse() { FailureOr> elements = FormatParser::parse(); if (failed(elements)) return failure(); - return AttrOrTypeFormat(def, std::move(*elements)); + return DefFormat(def, std::move(*elements)); } FailureOr DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) { - /// Lookup the parameter. + // Look up the parameter. ArrayRef params = def.getParameters(); auto *it = llvm::find_if( params, [&](auto &param) { return param.getName() == name; }); - /// Check that the parameter reference is valid. + // Check that the parameter reference is valid. if (it == params.end()) { return emitError(loc, def.getName() + " has no parameter named '" + name + "'"); @@ -581,9 +927,9 @@ } FailureOr DefFormatParser::parseParamsDirective(SMLoc loc) { - /// Collect all of the attribute's or type's parameters. - std::vector vars; - /// Ensure that none of the parameters have already been captured. + // Collect all of the attribute's or type's parameters. + std::vector vars; + // Ensure that none of the parameters have already been captured. for (const auto &it : llvm::enumerate(def.getParameters())) { if (seenParams.test(it.index())) { return emitError(loc, "`params` captures duplicate parameter: " + @@ -600,27 +946,27 @@ "expected '(' before `struct` argument list"))) return failure(); - /// Parse variables captured by `struct`. - std::vector vars; + // Parse variables captured by `struct`. - std::vector vars; - /// Parse first captured parameter or a `params` directive. + // Parse first captured parameter or a `params` directive.
FailureOr var = parseElement(StructDirectiveContext); if (failed(var) || !isa(*var)) { return emitError(loc, "`struct` argument list expected a variable or directive"); } if (isa(*var)) { - /// Parse any other parameters. - vars.push_back(std::move(*var)); + // Parse any other parameters. + vars.push_back(cast(*var)); while (peekToken().is(FormatToken::comma)) { consumeToken(); var = parseElement(StructDirectiveContext); if (failed(var) || !isa(*var)) return emitError(loc, "expected a variable in `struct` argument list"); - vars.push_back(std::move(*var)); + vars.push_back(cast(*var)); } } else { - /// `struct(params)` captures all parameters in the attribute or type. + // `struct(params)` captures all parameters in the attribute or type. vars = cast(*var)->takeParams(); } @@ -642,16 +988,16 @@ mgr.AddNewSourceBuffer( llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc()); - /// Parse the custom assembly format> + // Parse the custom assembly format. DefFormatParser fmtParser(mgr, def); - FailureOr format = fmtParser.parse(); + FailureOr format = fmtParser.parse(); if (failed(format)) { if (formatErrorIsFatal) PrintFatalError(def.getLoc(), "failed to parse assembly format"); return; } - /// Generate the parser and printer. + // Generate the parser and printer. format->genParser(parser); format->genPrinter(printer); }
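The comma handling that `genParamsParser` emits for a `params` directive with optional parameters can be modelled in isolation. This is a minimal standalone C++17 sketch under stated assumptions, not the generated code itself. `ParamSpec`, `TokenStream` and `parseParams` are hypothetical names. Integers stand in for parameter values, and an optional parameter counts as absent when the next token is not a number. Commas are mandatory up to the last required parameter. After that, a missing comma ends the list, like `if (parseOptionalComma()) break;` inside the generated `do { ... } while(false)`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical description of one parameter captured by `params`.
struct ParamSpec {
  std::string name;
  bool optional;
};

// Minimal token cursor standing in for the real parser (hypothetical).
class TokenStream {
public:
  explicit TokenStream(const std::vector<std::string> &toks) : toks_(toks) {}

  // Consume `tok` if it is next, like a successful parseOptionalComma().
  bool consumeIf(const std::string &tok) {
    if (pos_ < toks_.size() && toks_[pos_] == tok) {
      ++pos_;
      return true;
    }
    return false;
  }

  // Parse an integer if one is next; otherwise report "no value".
  std::optional<int> parseOptionalInt() {
    if (pos_ >= toks_.size() || toks_[pos_].empty() ||
        !std::isdigit(static_cast<unsigned char>(toks_[pos_][0])))
      return std::nullopt;
    return std::stoi(toks_[pos_++]);
  }

private:
  std::vector<std::string> toks_;
  std::size_t pos_ = 0;
};

// Returns one slot per parameter, or std::nullopt on a parse error.
std::optional<std::vector<std::optional<int>>>
parseParams(const std::vector<ParamSpec> &specs,
            const std::vector<std::string> &tokens) {
  TokenStream ts(tokens);
  std::vector<std::optional<int>> results(specs.size());

  // Index of the last required parameter (0 if there is none). The comma in
  // front of any parameter up to this index is mandatory.
  std::size_t lastReq = 0;
  for (std::size_t i = 0; i < specs.size(); ++i)
    if (!specs[i].optional)
      lastReq = i;

  for (std::size_t i = 0; i < specs.size(); ++i) {
    if (i > 0) {
      // After an optional parameter, only expect a comma if it had a value.
      bool prevHasValue = !specs[i - 1].optional || results[i - 1].has_value();
      if (prevHasValue) {
        if (i <= lastReq) {
          if (!ts.consumeIf(","))
            return std::nullopt; // a mandatory comma is missing
        } else if (!ts.consumeIf(",")) {
          break; // like `if (parseOptionalComma()) break;`
        }
      }
    }
    results[i] = ts.parseOptionalInt();
    if (!specs[i].optional && !results[i])
      return std::nullopt; // a required parameter is missing
  }
  return results;
}
```

Suppose `width` is required and `height` is optional. Then both `4, 8` and a bare `4` parse, and in the second case `height` is left empty. A lone `,` fails because `width` is missing.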
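The keyed loop that `genStructParser` emits can be sketched the same way. It covers the case where a `struct` directive mixes required and optional parameters: the `do { ... } while(!parseOptionalComma())` form, followed by one `checkParam` check per required parameter. This is a hypothetical standalone model. `FieldSpec` and `parseStruct` are illustrative names, and values are plain integers. It does not model the all-required form, which loops exactly N times with mandatory commas, or the all-optional form.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical description of one member of a `struct` directive.
struct FieldSpec {
  std::string name;
  bool optional;
};

// Parses `key = value` pairs in any order, separated by commas, and returns
// the values by name, or std::nullopt on a parse error.
std::optional<std::map<std::string, int>>
parseStruct(const std::vector<FieldSpec> &specs,
            const std::vector<std::string> &toks) {
  std::size_t pos = 0;
  auto consumeIf = [&](const std::string &tok) {
    if (pos < toks.size() && toks[pos] == tok) {
      ++pos;
      return true;
    }
    return false;
  };

  std::map<std::string, int> seen;
  do {
    // A parameter name, then a literal `=`.
    if (pos >= toks.size())
      return std::nullopt; // expected a parameter name
    std::string key = toks[pos++];
    if (!consumeIf("="))
      return std::nullopt;

    // Reject names that are unknown or were already seen.
    bool known = false;
    for (const FieldSpec &f : specs)
      known = known || f.name == key;
    if (!known || seen.count(key) != 0)
      return std::nullopt; // duplicate or unknown struct parameter

    // The value itself.
    if (pos >= toks.size() || toks[pos].empty() ||
        !std::isdigit(static_cast<unsigned char>(toks[pos][0])))
      return std::nullopt;
    seen[key] = std::stoi(toks[pos++]);
  } while (consumeIf(","));

  // Every required member must have appeared (duplicates were rejected).
  for (const FieldSpec &f : specs)
    if (!f.optional && seen.count(f.name) == 0)
      return std::nullopt; // struct is missing a required parameter
  return seen;
}
```

The per-key "seen" check is what makes out-of-order keys legal while still rejecting repeats. The final pass over required members corresponds to the `checkParam` template.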
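`genStructPrinter` handles the printing side through `genCommaSeparatedPrinter`. That function guards each optional parameter on its accessor and uses a `_firstPrinted` flag, so a `", "` separator appears only between entries that are actually printed. The following hypothetical standalone sketch reproduces that output logic. `PrintedParam` and `printStruct` are illustrative names, and an empty `std::optional` models an absent optional parameter.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical (name, value) pair; an empty value models an absent optional
// parameter.
struct PrintedParam {
  std::string name;
  std::optional<int> value;
};

// Prints `name = value` entries separated by ", ", skipping absent values.
std::string printStruct(const std::vector<PrintedParam> &params) {
  std::ostringstream os;
  bool firstPrinted = true; // plays the role of `_firstPrinted`
  for (const PrintedParam &p : params) {
    if (!p.value)
      continue; // the `if (getX()) { ... }` guard on optional parameters
    if (!firstPrinted)
      os << ", ";
    firstPrinted = false;
    os << p.name << " = " << *p.value;
  }
  return os.str();
}
```

A skipped parameter produces neither a separator nor a name. The printed form therefore stays parseable by the keyed loop, which accepts any subset of the optional keys.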