diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -485,6 +485,37 @@ whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for these parameters are expected to return `FailureOr<$cppStorageType>`. +#### Optional Parameters + +Optional parameters in the assembly format can be indicated by setting +`isOptional`. The C++ type of an optional parameter must satisfy the +following requirements: + +* is default-constructible +* is contextually convertible to `bool` +* only the default-constructed value converts to `false` + +The parameter parser should return the default-constructed value to indicate "no +value present". The printer will guard on the presence of a value to print the +parameter. + +If a parameter is not parsed, e.g. it is contained in an optional group that was +not present or was not parsed as part of a `struct` directive, then the +attribute or type will be built with its default-constructed value. + +Optional groups can be used with optional parameters. An optional group is a set +of elements optionally printed based on the presence of an anchor. Suppose +parameter `a` is an `IntegerAttr`. + +``` +( `(` $a^ `)` ) : (`x`)? +``` + +In the above assembly format, if `a` is present (non-null), then it will be +printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that +are used inside optional groups are allowed only if all captured parameters are +also optional. + ### Assembly Format Directives Attribute and type assembly formats have the following directives: @@ -547,12 +578,16 @@ !my_dialect.outer_qual> ``` +Optional parameters that are not present are omitted when printing the +parameter list.
+ #### `struct` Directive The `struct` directive accepts a list of variables to capture and will generate -a parser and printer for a comma-separated list of key-value pairs. The -variables are printed in the order they are specified in the argument list **but -can be parsed in any order**. For example: +a parser and printer for a comma-separated list of key-value pairs. If an +optional parameter is included in the `struct`, it can be elided. The variables +are printed in the order they are specified in the argument list **but can be +parsed in any order**. For example: ```tablegen def MyStructType : TypeDef { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -3116,15 +3116,19 @@ string summary = desc; // The format string for the asm syntax (documentation only). string syntax = ?; - // The default parameter parser is `::mlir::parseField($_parser)`, which - // returns `FailureOr`. Overload `parseField` to support parsing for your - // type. Or you can provide a customer printer. For attributes, "$_type" will - // be replaced with the required attribute type. + // The default parameter parser is `::mlir::FieldParser::parse($_parser)`, + // which returns `FailureOr`. Specialize `FieldParser` to support parsing + // for your type. Or you can provide a custom parser. For attributes, + // "$_type" will be replaced with the required attribute type. string parser = ?; // The default parameter printer is `$_printer << $_self`. Overload the stream // operator of `AsmPrinter` as necessary to print your type. Or you can // provide a custom printer. string printer = ?; + // Mark a parameter as optional. The C++ type of parameters marked as optional + // must be default constructible and be contextually convertible to `bool`. + // Any `Optional` and any attribute type satisfy these requirements.
+ bit isOptional = 0; } class AttrParameter : AttrOrTypeParameter; @@ -3169,6 +3173,12 @@ }]; } +// An optional parameter. +class OptionalParameter : + AttrOrTypeParameter { + let isOptional = 1; +} + // This is a special parameter used for AttrDefs that represents a `mlir::Type` // that is also used as the value `Type` of the attribute. Only one parameter // of the attribute may be of this type. diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -45,47 +45,50 @@ // AttrOrTypeParameter //===----------------------------------------------------------------------===// -// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to -// AttrOrTypeDefs to parameterize them. +/// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to +/// AttrOrTypeDefs to parameterize them. class AttrOrTypeParameter { public: explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) : def(def), index(index) {} - // Get the parameter name. + /// Get the parameter name. StringRef getName() const; - // If specified, get the custom allocator code for this parameter. + /// If specified, get the custom allocator code for this parameter. Optional getAllocator() const; - // If specified, get the custom comparator code for this parameter. + /// If specified, get the custom comparator code for this parameter. Optional getComparator() const; - // Get the C++ type of this parameter. + /// Get the C++ type of this parameter. StringRef getCppType() const; - // Get the C++ accessor type of this parameter. + /// Get the C++ accessor type of this parameter. StringRef getCppAccessorType() const; - // Get the C++ storage type of this parameter. + /// Get the C++ storage type of this parameter. StringRef getCppStorageType() const; - // Get an optional C++ parameter parser. 
+ /// Get an optional C++ parameter parser. Optional getParser() const; - // Get an optional C++ parameter printer. + /// Get an optional C++ parameter printer. Optional getPrinter() const; - // Get a description of this parameter for documentation purposes. + /// Get a description of this parameter for documentation purposes. Optional getSummary() const; - // Get the assembly syntax documentation. + /// Get the assembly syntax documentation. StringRef getSyntax() const; - // Return the underlying def of this parameter. - const llvm::Init *getDef() const; + /// Returns true if the parameter is optional. + bool isOptional() const; - // The parameter is pointer-comparable. + /// Return the underlying def of this parameter. + llvm::Init *getDef() const; + + /// The parameter is pointer-comparable. bool operator==(const AttrOrTypeParameter &other) const { return def == other.def && index == other.index; } @@ -94,6 +97,11 @@ } private: + /// A parameter can be either a string or a def. Get a potentially null value + /// from the def. + template + auto getDefValue(StringRef name) const; + /// The underlying tablegen parameter list this parameter is a part of. const llvm::DagInit *def; /// The index of the parameter within the parameter list (`def`). @@ -121,113 +129,113 @@ public: explicit AttrOrTypeDef(const llvm::Record *def); - // Get the dialect for which this def belongs. + /// Get the dialect for which this def belongs. Dialect getDialect() const; - // Returns the name of this AttrOrTypeDef record. + /// Returns the name of this AttrOrTypeDef record. StringRef getName() const; - // Query functions for the documentation of the def. + /// Query functions for the documentation of the def. bool hasDescription() const; StringRef getDescription() const; bool hasSummary() const; StringRef getSummary() const; - // Returns the name of the C++ class to generate. + /// Returns the name of the C++ class to generate. 
StringRef getCppClassName() const; - // Returns the name of the C++ base class to use when generating this def. + /// Returns the name of the C++ base class to use when generating this def. StringRef getCppBaseClassName() const; - // Returns the name of the storage class for this def. + /// Returns the name of the storage class for this def. StringRef getStorageClassName() const; - // Returns the C++ namespace for this def's storage class. + /// Returns the C++ namespace for this def's storage class. StringRef getStorageNamespace() const; - // Returns true if we should generate the storage class. + /// Returns true if we should generate the storage class. bool genStorageClass() const; - // Indicates whether or not to generate the storage class constructor. + /// Indicates whether or not to generate the storage class constructor. bool hasStorageCustomConstructor() const; /// Get the parameters of this attribute or type. ArrayRef getParameters() const { return parameters; } - // Return the number of parameters + /// Return the number of parameters. unsigned getNumParameters() const; - // Return the keyword/mnemonic to use in the printer/parser methods if we are - // supposed to auto-generate them. + /// Return the keyword/mnemonic to use in the printer/parser methods if we are + /// supposed to auto-generate them. Optional getMnemonic() const; - // Returns the code to use as the types printer method. If not specified, - // return a non-value. Otherwise, return the contents of that code block. + /// Returns the code to use as the printer method. If not specified, returns + /// None. Otherwise, returns the contents of that code block. Optional getPrinterCode() const; - // Returns the code to use as the parser method. If not specified, returns - // None. Otherwise, returns the contents of that code block. + /// Returns the code to use as the parser method. If not specified, returns + /// None. Otherwise, returns the contents of that code block.
Optional getParserCode() const; - // Returns the custom assembly format, if one was specified. + /// Returns the custom assembly format, if one was specified. Optional getAssemblyFormat() const; - // An attribute or type with parameters needs a parser. + /// An attribute or type with parameters needs a parser. bool needsParserPrinter() const { return getNumParameters() != 0; } - // Returns true if this attribute or type has a generated parser. + /// Returns true if this attribute or type has a generated parser. bool hasGeneratedParser() const { return getParserCode() || getAssemblyFormat(); } - // Returns true if this attribute or type has a generated printer. + /// Returns true if this attribute or type has a generated printer. bool hasGeneratedPrinter() const { return getPrinterCode() || getAssemblyFormat(); } - // Returns true if the accessors based on the parameters should be generated. + /// Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; - // Return true if we need to generate the verify declaration and getChecked - // method. + /// Return true if we need to generate the verify declaration and getChecked + /// method. bool genVerifyDecl() const; - // Returns the def's extra class declaration code. + /// Returns the def's extra class declaration code. Optional getExtraDecls() const; - // Get the code location (for error printing). + /// Get the code location (for error printing). ArrayRef getLoc() const; - // Returns true if the default get/getChecked methods should be skipped during - // generation. + /// Returns true if the default get/getChecked methods should be skipped + /// during generation. bool skipDefaultBuilders() const; - // Returns the builders of this def. + /// Returns the builders of this def. ArrayRef getBuilders() const { return builders; } - // Returns the traits of this def. + /// Returns the traits of this def. 
ArrayRef getTraits() const { return traits; } - // Returns whether two AttrOrTypeDefs are equal by checking the equality of - // the underlying record. + /// Returns whether two AttrOrTypeDefs are equal by checking the equality of + /// the underlying record. bool operator==(const AttrOrTypeDef &other) const; - // Compares two AttrOrTypeDefs by comparing the names of the dialects. + /// Compares two AttrOrTypeDefs by comparing the names of the dialects. bool operator<(const AttrOrTypeDef &other) const; - // Returns whether the AttrOrTypeDef is defined. + /// Returns whether the AttrOrTypeDef is defined. operator bool() const { return def != nullptr; } - // Return the underlying def. + /// Return the underlying def. const llvm::Record *getDef() const { return def; } protected: const llvm::Record *def; - // The builders of this definition. + /// The builders of this definition. SmallVector builders; - // The traits of this definition. + /// The traits of this definition. SmallVector traits; /// The parameters of this attribute or type. @@ -243,8 +251,8 @@ public: using AttrOrTypeDef::AttrOrTypeDef; - // Returns the attributes value type builder code block, or None if it doesn't - // have one. + /// Returns the attributes value type builder code block, or None if it + /// doesn't have one. 
Optional getTypeBuilder() const; static bool classof(const AttrOrTypeDef *def); diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -177,32 +177,30 @@ // AttrOrTypeParameter //===----------------------------------------------------------------------===// +template +auto AttrOrTypeParameter::getDefValue(StringRef name) const { + Optional().getValue())> result; + if (auto *param = dyn_cast(getDef())) + if (auto *init = param->getDef()->getValue(name)) + if (auto *value = dyn_cast_or_null(init->getValue())) + result = value->getValue(); + return result; +} + StringRef AttrOrTypeParameter::getName() const { return def->getArgName(index)->getValue(); } Optional AttrOrTypeParameter::getAllocator() const { - llvm::Init *parameterType = def->getArg(index); - if (isa(parameterType)) - return Optional(); - if (auto *param = dyn_cast(parameterType)) - return param->getDef()->getValueAsOptionalString("allocator"); - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from AttrOrTypeParameter\n"); + return getDefValue("allocator"); } Optional AttrOrTypeParameter::getComparator() const { - llvm::Init *parameterType = def->getArg(index); - if (isa(parameterType)) - return Optional(); - if (auto *param = dyn_cast(parameterType)) - return param->getDef()->getValueAsOptionalString("comparator"); - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from AttrOrTypeParameter\n"); + return getDefValue("comparator"); } StringRef AttrOrTypeParameter::getCppType() const { - auto *parameterType = def->getArg(index); + llvm::Init *parameterType = getDef(); if (auto *stringType = dyn_cast(parameterType)) return stringType->getValue(); if (auto *param = dyn_cast(parameterType)) @@ -213,74 +211,45 @@ } StringRef AttrOrTypeParameter::getCppAccessorType() const { - if (auto *param = 
dyn_cast(def->getArg(index))) { - if (Optional type = - param->getDef()->getValueAsOptionalString("cppAccessorType")) - return *type; - } - return getCppType(); + return getDefValue("cppAccessorType") + .getValueOr(getCppType()); } StringRef AttrOrTypeParameter::getCppStorageType() const { - if (auto *param = dyn_cast(def->getArg(index))) { - if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType")) - return *type; - } - return getCppType(); + return getDefValue("cppStorageType") + .getValueOr(getCppType()); } Optional AttrOrTypeParameter::getParser() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - if (auto parser = param->getDef()->getValueAsOptionalString("parser")) - return *parser; - } - return {}; + return getDefValue("parser"); } Optional AttrOrTypeParameter::getPrinter() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - if (auto printer = param->getDef()->getValueAsOptionalString("printer")) - return *printer; - } - return {}; + return getDefValue("printer"); } Optional AttrOrTypeParameter::getSummary() const { - auto *parameterType = def->getArg(index); - if (auto *param = dyn_cast(parameterType)) { - const auto *desc = param->getDef()->getValue("summary"); - if (llvm::StringInit *ci = dyn_cast(desc->getValue())) - return ci->getValue(); - } - return Optional(); + return getDefValue("summary"); } StringRef AttrOrTypeParameter::getSyntax() const { - auto *parameterType = def->getArg(index); - if (auto *stringType = dyn_cast(parameterType)) + if (auto *stringType = dyn_cast(getDef())) return stringType->getValue(); - if (auto *param = dyn_cast(parameterType)) { - const auto *syntax = param->getDef()->getValue("syntax"); - if (syntax && isa(syntax->getValue())) - return cast(syntax->getValue())->getValue(); - return getCppType(); - } - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from 
AttrOrTypeParameter"); + return getDefValue("syntax").getValueOr(getCppType()); } -const llvm::Init *AttrOrTypeParameter::getDef() const { - return def->getArg(index); +bool AttrOrTypeParameter::isOptional() const { + return getDefValue("isOptional").getValueOr(false); } +llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } + //===----------------------------------------------------------------------===// // AttributeSelfTypeParameter //===----------------------------------------------------------------------===// bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) { - const llvm::Init *paramDef = param->getDef(); + llvm::Init *paramDef = param->getDef(); if (auto *paramDefInit = dyn_cast(paramDef)) return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter"); return false; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -22,7 +22,7 @@ // All of the types will extend this class. class Test_Type traits = []> - : TypeDef; + : TypeDef; def SimpleTypeA : Test_Type<"SimpleA"> { let mnemonic = "smpla"; @@ -42,7 +42,7 @@ ArrayRefParameter< "int", // The parameter C++ type. "An example of an array of ints" // Parameter description. 
- >: $arrayOfInts + >:$arrayOfInts ); let extraClassDeclaration = [{ @@ -262,4 +262,66 @@ let assemblyFormat = "`<` struct(params) `>`"; } +def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, "int":$b); + let mnemonic = "optional_param"; + let assemblyFormat = "`<` $a `,` $b `>`"; +} + +def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + StringRefParameter<>:$b); + let mnemonic = "optional_params"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeOptionalParamsAfterRequired + : Test_Type<"TestTypeOptionalParamsAfterRequired"> { + let parameters = (ins StringRefParameter<>:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_params_after"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeOptionalStruct : Test_Type<"TestTypeOptionalStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + StringRefParameter<>:$b); + let mnemonic = "optional_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TestTypeAllOptionalParams : Test_Type<"TestTypeAllOptionalParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "all_optional_params"; + let assemblyFormat = "`<` params `>`"; +} + +def TestTypeAllOptionalStruct : Test_Type<"TestTypeAllOptionalStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "all_optional_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> { + let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group"; + let assemblyFormat = "`<` (`(` $b^ `)`) : (`x`)? 
$a `>`"; +} + +def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group_params"; + let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`"; +} + +def TestTypeOptionalGroupStruct : Test_Type<"TestTypeOptionalGroupStruct"> { + let parameters = (ins OptionalParameter<"mlir::Optional">:$a, + OptionalParameter<"mlir::Optional">:$b); + let mnemonic = "optional_group_struct"; + let assemblyFormat = "`<` (`(` struct(params)^ `)`) : (`x`)? `>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -64,11 +64,28 @@ return test::CustomParam{value.getValue()}; } }; + inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer, test::CustomParam param) { return printer << param.value; } +/// Overload the attribute parameter parser for optional integers. +template <> +struct FieldParser> { + static FailureOr> parse(AsmParser &parser) { + Optional value; + value.emplace(); + OptionalParseResult result = parser.parseOptionalInteger(*value); + if (result.hasValue()) { + if (succeeded(*result)) + return value; + return failure(); + } + value.reset(); + return value; + } +}; } // namespace mlir #include "TestTypeInterfaces.h.inc" diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -11,28 +11,28 @@ let mnemonic = asm; } -/// Test format is missing a parameter capture. +// Test format is missing a parameter capture. 
def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: format is missing reference to parameter: v1 let assemblyFormat = "`<` $v0 `>`"; } -/// Test format has duplicate parameter captures. +// Test format has duplicate parameter captures. def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`"; } -/// Test format has invalid syntax. +// Test format has invalid syntax. def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> { let parameters = (ins "int":$v0, "int":$v1); // CHECK: expected literal, variable, directive, or optional group let assemblyFormat = "`<` $v0, $v1 `>`"; } -/// Test struct directive has invalid syntax. +// Test struct directive has invalid syntax. def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> { let parameters = (ins "int":$v0); // CHECK: literals may only be used in the top-level section of the format @@ -40,37 +40,70 @@ let assemblyFormat = "`<` struct($v0, `,`) `>`"; } -/// Test struct directive cannot capture zero parameters. +// Test struct directive cannot capture zero parameters. def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> { let parameters = (ins "int":$v0); // CHECK: `struct` argument list expected a variable or directive let assemblyFormat = "`<` struct() $v0 `>`"; } -/// Test capture parameter that does not exist. +// Test capture parameter that does not exist. def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> { let parameters = (ins "int":$v0); // CHECK: InvalidTypeF has no parameter named 'v1' let assemblyFormat = "`<` $v0 $v1 `>`"; } -/// Test duplicate capture of parameter in capture-all struct. +// Test duplicate capture of parameter in capture-all struct. 
def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> { let parameters = (ins "int":$v0, "int":$v1, "int":$v2); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` struct(params) $v0 `>`"; } -/// Test capture-all struct duplicate capture. +// Test capture-all struct duplicate capture. def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> { let parameters = (ins "int":$v0, "int":$v1, "int":$v2); // CHECK: `params` captures duplicate parameter: v0 let assemblyFormat = "`<` $v0 struct(params) `>`"; } -/// Test capture of parameter after `params` directive. +// Test capture of parameter after `params` directive. def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> { let parameters = (ins "int":$v0); // CHECK: duplicate parameter 'v0' let assemblyFormat = "`<` params $v0 `>`"; } + +// Test `struct` with optional parameter followed by comma. +def InvalidTypeJ : InvalidType<"InvalidTypeJ", "invalid_j"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: directive with optional parameters cannot be followed by a comma literal + let assemblyFormat = "struct($a) `,` $b"; +} + +// Test `struct` in optional group must have all optional parameters. +def InvalidTypeK : InvalidType<"InvalidTypeK", "invalid_k"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: is only allowed in an optional group if all captured parameters are optional + let assemblyFormat = "(`(` struct(params)^ `)`)?"; +} + +// Test `params` in optional group must have all optional parameters.
+def InvalidTypeL : InvalidType<"InvalidTypeL", "invalid_l"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: directive allowed in optional group only if all parameters are optional + let assemblyFormat = "(`(` params^ `)`)?"; +} + +def InvalidTypeM : InvalidType<"InvalidTypeM", "invalid_m"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: parameters in an optional group must be optional + let assemblyFormat = "(`(` $a^ `,` $b `)`)?"; +} + +def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> { + let parameters = (ins OptionalParameter<"int">:$a); + // CHECK: optional group anchor must be a parameter or directive + let assemblyFormat = "(`(` $a `)`^)?"; +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -20,4 +20,52 @@ // CHECK-LABEL: @test_roundtrip_default_parsers_struct // CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> // CHECK: !test.struct_capture_all -func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all +// CHECK: !test.optional_param<, 6> +// CHECK: !test.optional_param<5, 6> +// CHECK: !test.optional_params<"a"> +// CHECK: !test.optional_params<5, "a"> +// CHECK: !test.optional_struct +// CHECK: !test.optional_struct +// CHECK: !test.optional_params_after<"a"> +// CHECK: !test.optional_params_after<"a", 5> +// CHECK: !test.all_optional_params<> +// CHECK: !test.all_optional_params<5> +// CHECK: !test.all_optional_params<5, 6> +// CHECK: !test.all_optional_struct<> +// CHECK: !test.all_optional_struct +// CHECK: !test.all_optional_struct +// CHECK: !test.optional_group<(5) 6> +// CHECK: !test.optional_group +// CHECK: !test.optional_group_params +// CHECK: !test.optional_group_params<(5)> +// CHECK: 
!test.optional_group_params<(5, 6)> +// CHECK: !test.optional_group_struct +// CHECK: !test.optional_group_struct<(b = 5)> +// CHECK: !test.optional_group_struct<(a = 10, b = 5)> +func private @test_roundtrip_default_parsers_struct( + !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> +) -> ( + !test.struct_capture_all, + !test.optional_param<, 6>, + !test.optional_param<5, 6>, + !test.optional_params<"a">, + !test.optional_params<5, "a">, + !test.optional_struct, + !test.optional_struct, + !test.optional_params_after<"a">, + !test.optional_params_after<"a", 5>, + !test.all_optional_params<>, + !test.all_optional_params<5>, + !test.all_optional_params<5, 6>, + !test.all_optional_struct<>, + !test.all_optional_struct, + !test.all_optional_struct, + !test.optional_group<(5) 6>, + !test.optional_group, + !test.optional_group_params, + !test.optional_group_params<(5)>, + !test.optional_group_params<(5, 6)>, + !test.optional_group_struct, + !test.optional_group_struct<(b = 5)>, + !test.optional_group_struct<(b = 5, a = 10)> +) diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -44,18 +44,18 @@ // ATTR: if (parser.parseEqual()) // ATTR: return {}; // ATTR: _result_value = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_value)) +// ATTR: if (::mlir::failed(_result_value)) // ATTR: return {}; // ATTR: if (parser.parseComma()) // ATTR: return {}; // ATTR: _result_complex = ::parseAttrParamA(parser, type); -// ATTR: if (failed(_result_complex)) +// ATTR: if (::mlir::failed(_result_complex)) // ATTR: return {}; // ATTR: if (parser.parseRParen()) // ATTR: return {}; // ATTR: return TestAAttr::get(parser.getContext(), -// ATTR: _result_value.getValue(), -// ATTR: _result_complex.getValue()); +// ATTR: *_result_value, +// ATTR: *_result_complex); // ATTR: } // ATTR: void TestAAttr::print(::mlir::AsmPrinter 
&printer) const { @@ -85,42 +85,42 @@ // ATTR: ::mlir::Type type) { // ATTR: bool _seen_v0 = false; // ATTR: bool _seen_v1 = false; -// ATTR: for (unsigned _index = 0; _index < 2; ++_index) { -// ATTR: StringRef _paramKey; -// ATTR: if (parser.parseKeyword(&_paramKey)) -// ATTR: return {}; +// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { // ATTR: if (parser.parseEqual()) // ATTR: return {}; // ATTR: if (!_seen_v0 && _paramKey == "v0") { // ATTR: _seen_v0 = true; // ATTR: _result_v0 = ::parseAttrParamA(parser, type); -// ATTR: if (failed(_result_v0)) +// ATTR: if (::mlir::failed(_result_v0)) // ATTR: return {}; // ATTR: } else if (!_seen_v1 && _paramKey == "v1") { // ATTR: _seen_v1 = true; // ATTR: _result_v1 = type ? ::parseAttrWithType(parser, type) : ::parseAttrWithout(parser); -// ATTR: if (failed(_result_v1)) +// ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; // ATTR: } else { // ATTR: return {}; // ATTR: } +// ATTR: return true; +// ATTR: } +// ATTR: for (unsigned _index = 0; _index < 2; ++_index) { +// ATTR: StringRef _paramKey; +// ATTR: if (parser.parseKeyword(&_paramKey)) +// ATTR: return {}; +// ATTR: if (!_loop_body(_paramKey)) return {}; // ATTR: if ((_index != 2 - 1) && parser.parseComma()) // ATTR: return {}; // ATTR: } // ATTR: return TestBAttr::get(parser.getContext(), -// ATTR: _result_v0.getValue(), -// ATTR: _result_v1.getValue()); +// ATTR: *_result_v0, +// ATTR: *_result_v1); // ATTR: } // ATTR: void TestBAttr::print(::mlir::AsmPrinter &printer) const { -// ATTR: printer << "v0"; -// ATTR: printer << ' ' << "="; -// ATTR: printer << ' '; +// ATTR: printer << "v0 = "; // ATTR: ::printAttrParamA(printer, getV0()); -// ATTR: printer << ","; -// ATTR: printer << ' ' << "v1"; -// ATTR: printer << ' ' << "="; -// ATTR: printer << ' '; +// ATTR: printer << ", "; +// ATTR: printer << "v1 = "; // ATTR: ::printAttrB(printer, getV1()); // ATTR: } @@ -141,24 +141,16 @@ // ATTR: ::mlir::FailureOr _result_v0; // ATTR: 
::mlir::FailureOr _result_v1; // ATTR: _result_v0 = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_v0)) +// ATTR: if (::mlir::failed(_result_v0)) // ATTR: return {}; // ATTR: if (parser.parseComma()) // ATTR: return {}; // ATTR: _result_v1 = ::mlir::FieldParser::parse(parser); -// ATTR: if (failed(_result_v1)) +// ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; // ATTR: return TestFAttr::get(parser.getContext(), -// ATTR: _result_v0.getValue(), -// ATTR: _result_v1.getValue()); -// ATTR: } - -// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const { -// ATTR: printer << ' '; -// ATTR: printer.printStrippedAttrOrType(getV0()); -// ATTR: printer << ","; -// ATTR: printer << ' '; -// ATTR: printer.printStrippedAttrOrType(getV1()); +// ATTR: *_result_v0, +// ATTR: *_result_v1); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -185,21 +177,25 @@ // TYPE: if (parser.parseKeyword("bar")) // TYPE: return {}; // TYPE: _result_value = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_value)) +// TYPE: if (::mlir::failed(_result_value)) // TYPE: return {}; // TYPE: bool _seen_complex = false; -// TYPE: for (unsigned _index = 0; _index < 1; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { // TYPE: if (!_seen_complex && _paramKey == "complex") { // TYPE: _seen_complex = true; // TYPE: _result_complex = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_complex)) +// TYPE: if (::mlir::failed(_result_complex)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned _index = 0; _index < 1; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; // TYPE: if ((_index != 1 - 1) && parser.parseComma()) // TYPE: return {}; 
// TYPE: } @@ -215,9 +211,7 @@ // TYPE: printer << ' ' << "bar"; // TYPE: printer << ' '; // TYPE: printer.printStrippedAttrOrType(getValue()); -// TYPE: printer << ' ' << "complex"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << "complex = "; // TYPE: printer << getComplex(); // TYPE: printer << ")"; // TYPE: } @@ -237,48 +231,50 @@ // TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &parser) { // TYPE: _result_v0 = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_v0)) +// TYPE: if (::mlir::failed(_result_v0)) // TYPE: return {}; // TYPE: bool _seen_v1 = false; // TYPE: bool _seen_v2 = false; -// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { // TYPE: if (parser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v1 && _paramKey == "v1") { // TYPE: _seen_v1 = true; // TYPE: _result_v1 = someFcnCall(); -// TYPE: if (failed(_result_v1)) +// TYPE: if (::mlir::failed(_result_v1)) // TYPE: return {}; // TYPE: } else if (!_seen_v2 && _paramKey == "v2") { // TYPE: _seen_v2 = true; // TYPE: _result_v2 = ::parseTypeParamC(parser); -// TYPE: if (failed(_result_v2)) +// TYPE: if (::mlir::failed(_result_v2)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; // TYPE: if ((_index != 2 - 1) && parser.parseComma()) // TYPE: return {}; // TYPE: } // TYPE: _result_v3 = someFcnCall(); -// TYPE: if (failed(_result_v3)) +// TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; // TYPE: return TestDType::get(parser.getContext(), -// TYPE: _result_v0.getValue(), -// TYPE: _result_v1.getValue(), -// 
TYPE: _result_v2.getValue(), -// TYPE: _result_v3.getValue()); +// TYPE: *_result_v0, +// TYPE: *_result_v1, +// TYPE: *_result_v2, +// TYPE: *_result_v3); // TYPE: } // TYPE: void TestDType::print(::mlir::AsmPrinter &printer) const { // TYPE: printer << getV0(); // TYPE: myPrinter(getV1()); -// TYPE: printer << ' ' << "v2"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << "v2 = "; // TYPE: printer << getV2(); // TYPE: myPrinter(getV3()); // TYPE: } @@ -305,77 +301,78 @@ // TYPE: FailureOr _result_v3; // TYPE: bool _seen_v0 = false; // TYPE: bool _seen_v2 = false; -// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { -// TYPE: StringRef _paramKey; -// TYPE: if (parser.parseKeyword(&_paramKey)) -// TYPE: return {}; +// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { // TYPE: if (parser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v0 && _paramKey == "v0") { // TYPE: _seen_v0 = true; // TYPE: _result_v0 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v0)) +// TYPE: if (::mlir::failed(_result_v0)) // TYPE: return {}; // TYPE: } else if (!_seen_v2 && _paramKey == "v2") { // TYPE: _seen_v2 = true; // TYPE: _result_v2 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v2)) +// TYPE: if (::mlir::failed(_result_v2)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } -// TYPE: if ((_index != 2 - 1) && parser.parseComma()) -// TYPE: return {}; +// TYPE: return true; // TYPE: } -// TYPE: bool _seen_v1 = false; -// TYPE: bool _seen_v3 = false; // TYPE: for (unsigned _index = 0; _index < 2; ++_index) { // TYPE: StringRef _paramKey; // TYPE: if (parser.parseKeyword(&_paramKey)) // TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return {}; +// TYPE: } +// TYPE: bool _seen_v1 = false; +// TYPE: bool _seen_v3 = false; +// TYPE: const auto _loop_body = [&](::llvm::StringRef 
_paramKey) -> bool { // TYPE: if (parser.parseEqual()) // TYPE: return {}; // TYPE: if (!_seen_v1 && _paramKey == "v1") { // TYPE: _seen_v1 = true; // TYPE: _result_v1 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v1)) +// TYPE: if (::mlir::failed(_result_v1)) // TYPE: return {}; // TYPE: } else if (!_seen_v3 && _paramKey == "v3") { // TYPE: _seen_v3 = true; // TYPE: _result_v3 = ::mlir::FieldParser::parse(parser); -// TYPE: if (failed(_result_v3)) +// TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; // TYPE: } else { // TYPE: return {}; // TYPE: } +// TYPE: return true; +// TYPE: } +// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_loop_body(_paramKey)) return {}; // TYPE: if ((_index != 2 - 1) && parser.parseComma()) // TYPE: return {}; // TYPE: } // TYPE: return TestEType::get(parser.getContext(), -// TYPE: _result_v0.getValue(), -// TYPE: _result_v1.getValue(), -// TYPE: _result_v2.getValue(), -// TYPE: _result_v3.getValue()); +// TYPE: *_result_v0, +// TYPE: *_result_v1, +// TYPE: *_result_v2, +// TYPE: *_result_v3); // TYPE: } // TYPE: void TestEType::print(::mlir::AsmPrinter &printer) const { -// TYPE: printer << "v0"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << "v0 = "; // TYPE: printer.printStrippedAttrOrType(getV0()); -// TYPE: printer << ","; -// TYPE: printer << ' ' << "v2"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << ", "; +// TYPE: printer << "v2 = "; // TYPE: printer.printStrippedAttrOrType(getV2()); -// TYPE: printer << "v1"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << ", "; +// TYPE: printer << "v1 = "; // TYPE: printer.printStrippedAttrOrType(getV1()); -// TYPE: printer << ","; -// TYPE: printer << ' ' << "v3"; -// TYPE: printer << ' ' << "="; -// TYPE: printer << ' '; +// TYPE: printer << 
", "; +// TYPE: printer << "v3 = "; // TYPE: printer.printStrippedAttrOrType(getV3()); // TYPE: } @@ -390,3 +387,99 @@ let mnemonic = "type_e"; let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`"; } + +// TYPE: void TestFType::print(::mlir::AsmPrinter &printer) const { +// TYPE if (getA()) { +// TYPE printer << ' '; +// TYPE printer.printStrippedAttrOrType(getA()); +def TypeD : TestType<"TestF"> { + let parameters = (ins OptionalParameter<"int">:$a); + let mnemonic = "type_f"; + let assemblyFormat = "$a"; +} + +// TYPE: ::mlir::Type TestGType::parse(::mlir::AsmParser &parser) { +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; +// TYPE: if (::mlir::succeeded(_result_a) && *_result_a) +// TYPE: if (parser.parseComma()) +// TYPE: return {}; + +// TYPE: if (getA()) +// TYPE: printer.printStrippedAttrOrType(getA()); +// TYPE: printer << ", "; +// TYPE: printer.printStrippedAttrOrType(getB()); + +def TypeE : TestType<"TestG"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + let mnemonic = "type_g"; + let assemblyFormat = "params"; +} + + +// TYPE: ::mlir::Type TestHType::parse(::mlir::AsmParser &parser) { +// TYPE: do { +// TYPE: if (!_loop_body(_paramKey)) return {}; +// TYPE: } while(!parser.parseOptionalComma()); +// TYPE: if (!_seen_b) +// TYPE: return {}; + +// TYPE: void TestHType::print(::mlir::AsmPrinter &printer) const { +// TYPE: if (getA()) { +// TYPE: printer << "a = "; +// TYPE: printer.printStrippedAttrOrType(getA()); +// TYPE: printer << ", "; +// TYPE: } + +def TypeF : TestType<"TestH"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + let mnemonic = "type_h"; + let assemblyFormat = "struct(params)"; +} + + +// TYPE: do { +// TYPE: _result_a = ::mlir::FieldParser::parse(parser); +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; +// TYPE: if (parser.parseOptionalComma()) break; +// TYPE: _result_b = ::mlir::FieldParser::parse(parser); +// TYPE: if (::mlir::failed(_result_b)) 
+// TYPE: return {}; +// TYPE: } while(false); + +def TypeG : TestType<"TestI"> { + let parameters = (ins "int":$a, OptionalParameter<"int">:$b); + let mnemonic = "type_i"; + let assemblyFormat = "params"; +} + +// TYPE: ::mlir::Type TestJType::parse(::mlir::AsmParser &parser) { +// TYPE: if (parser.parseOptionalLParen()) { +// TYPE: if (parser.parseKeyword("x")) return {}; +// TYPE: } else { +// TYPE: _result_b = ::mlir::FieldParser::parse(parser); +// TYPE: if (::mlir::failed(_result_b)) +// TYPE: return {}; +// TYPE: if (parser.parseRParen()) return {}; +// TYPE: } +// TYPE: _result_a = ::mlir::FieldParser::parse(parser); +// TYPE: if (::mlir::failed(_result_a)) +// TYPE: return {}; + +// TYPE: void TestJType::print(::mlir::AsmPrinter &printer) const { +// TYPE: if (getB()) { +// TYPE: printer << "("; +// TYPE: if (getB()) +// TYPE: printer.printStrippedAttrOrType(getB()); +// TYPE: printer << ")"; +// TYPE: } else { +// TYPE: printer << ' ' << "x"; +// TYPE: } +// TYPE: printer.printStrippedAttrOrType(getA()); + +def TypeH : TestType<"TestJ"> { + let parameters = (ins "int":$a, OptionalParameter<"int">:$b); + let mnemonic = "type_j"; + let assemblyFormat = "(`(` $b^ `)`) : (`x`)? 
$a"; +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -14,9 +14,11 @@ #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/BitVector.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -48,11 +50,21 @@ shouldBeQualifiedFlag = qualified; } + /// Returns true if the element contains an optional parameter. + bool isOptional() const { return param.isOptional(); } + + /// Returns the name of the parameter. + StringRef getName() const { return param.getName(); } + private: bool shouldBeQualifiedFlag = false; AttrOrTypeParameter param; }; +/// Shorthand functions that can be used with ranged-based conditions. +static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } +static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } + /// Base class for a directive that contains references to multiple variables. template class ParamsDirectiveBase : public DirectiveElementBase { @@ -64,17 +76,19 @@ /// Get the parameters contained in this directive. auto getParams() const { - return llvm::map_range(params, [](FormatElement *el) { - return cast(el)->getParam(); - }); + return llvm::map_range( + params, [](FormatElement *el) { return cast(el); }); } /// Get the number of parameters. unsigned getNumParams() const { return params.size(); } /// Take all of the parameters from this directive. 
- std::vector takeParams() { - return std::move(params); + std::vector takeParams() { return std::move(params); } + + /// Returns true if there are optional parameters present. + bool hasOptionalParams() const { + return llvm::any_of(getParams(), paramIsOptional); } private: @@ -127,36 +141,9 @@ /// Print an error when failing to parse an element. /// /// $0: The parameter C++ class name. -static const char *const parseErrorStr = +static const char *const parserErrorStr = "$_parser.emitError($_parser.getCurrentLocation(), "; -/// Loop declaration for struct parser. -/// -/// $0: Number of expected parameters. -static const char *const structParseLoopStart = R"( - for (unsigned _index = 0; _index < $0; ++_index) { - StringRef _paramKey; - if ($_parser.parseKeyword(&_paramKey)) { - $_parser.emitError($_parser.getCurrentLocation(), - "expected a parameter name in struct"); - return {}; - } -)"; - -/// Terminator code segment for the struct parser loop. Check for duplicate or -/// unknown parameters. Parse a comma except on the last element. -/// -/// {0}: Code template for printing an error. -/// {1}: Number of elements in the struct. -static const char *const structParseLoopEnd = R"({{ - {0}"duplicate or unknown struct parameter name: ") << _paramKey; - return {{}; - } - if ((_index != {1} - 1) && parser.parseComma()) - return {{}; -} -)"; - /// Code format to parse a variable. Separate by lines because variable parsers /// may be generated inside other directives, which requires indentation. 
/// @@ -168,21 +155,20 @@ static const char *const variableParser = R"( // Parse variable '{0}' _result_{0} = {1}; -if (failed(_result_{0})) {{ +if (::mlir::failed(_result_{0})) {{ {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`"); return {{}; } )"; //===----------------------------------------------------------------------===// -// AttrOrTypeFormat +// DefFormat //===----------------------------------------------------------------------===// namespace { -class AttrOrTypeFormat { +class DefFormat { public: - AttrOrTypeFormat(const AttrOrTypeDef &def, - std::vector &&elements) + DefFormat(const AttrOrTypeDef &def, std::vector &&elements) : def(def), elements(std::move(elements)) {} /// Generate the attribute or type parser. @@ -194,26 +180,36 @@ /// Generate the parser code for a specific format element. void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a literal. - void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os); + void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os, + bool isOptional = false); /// Generate the parser code for a variable. - void genVariableParser(const AttrOrTypeParameter ¶m, FmtContext &ctx, - MethodBody &os); + void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `params` directive. void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `struct` directive. void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the parser code for an optional group. + void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, + MethodBody &os); /// Generate the printer code for a specific format element. void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a literal. 
void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. - void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx, - MethodBody &os, bool printQualified = false); + void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, + bool skipGuard = false); + /// Generate a printer for comma-separated parameters. + void genCommaSeparatedPrinter(ArrayRef params, + FmtContext &ctx, MethodBody &os, + function_ref extra); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); + /// Generate the printer code for an optional group. + void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, + MethodBody &os); /// The ODS definition of the attribute or type whose format is being used to /// generate a parser and printer. @@ -232,35 +228,43 @@ // ParserGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genParser(MethodBody &os) { +void DefFormat::genParser(MethodBody &os) { FmtContext ctx; ctx.addSubst("_parser", "parser"); if (isa(def)) ctx.addSubst("_type", "type"); os.indent(); - /// Declare variables to store all of the parameters. Allocated parameters - /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store - /// FailureOr to defer type construction for parameters that are parsed in - /// a loop (parsers return FailureOr anyways). + // Declare variables to store all of the parameters. Allocated parameters + // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store + // FailureOr to defer type construction for parameters that are parsed in + // a loop (parsers return FailureOr anyways). 
ArrayRef params = def.getParameters(); for (const AttrOrTypeParameter ¶m : params) { - os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n", + os << formatv("::mlir::FailureOr<{0}> _result_{1};\n", param.getCppStorageType(), param.getName()); } - /// Store the initial location of the parser. + // Store the initial location of the parser. ctx.addSubst("_loc", "loc"); os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" "(void) $_loc;\n", &ctx); - /// Generate call to each parameter parser. + // Generate call to each parameter parser. for (FormatElement *el : elements) genElementParser(el, ctx, os); - /// Generate call to the attribute or type builder. Use the checked getter - /// if one was generated. + // Emit an assert for each mandatory parameter. Triggering an assert means + // the generated parser is incorrect (i.e. there is a bug in this code). + for (const AttrOrTypeParameter ¶m : params) { + if (param.isOptional()) + continue; + os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName()); + } + + // Generate call to the attribute or type builder. Use the checked getter + // if one was generated. 
if (def.genVerifyDecl()) { os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); @@ -268,29 +272,38 @@ os << tgfmt("return $0::get($_parser.getContext()", &ctx, def.getCppClassName()); } - for (const AttrOrTypeParameter ¶m : params) - os << formatv(",\n _result_{0}.getValue()", param.getName()); + for (const AttrOrTypeParameter ¶m : params) { + if (param.isOptional()) + os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(), + param.getCppStorageType()); + else + os << formatv(",\n *_result_{0}", param.getName()); + } os << ");"; } -void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx, + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralParser(literal->getSpelling(), ctx, os); if (auto *var = dyn_cast(el)) - return genVariableParser(var->getParam(), ctx, os); + return genVariableParser(var, ctx, os); if (auto *params = dyn_cast(el)) return genParamsParser(params, ctx, os); if (auto *strct = dyn_cast(el)) return genStructParser(strct, ctx, os); + if (auto *optional = dyn_cast(el)) + return genOptionalGroupParser(optional, ctx, os); llvm_unreachable("unknown format element"); } -void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, + MethodBody &os, bool isOptional) { os << "// Parse literal '" << value << "'\n"; os << tgfmt("if ($_parser.parse", &ctx); + if (isOptional) + os << "Optional"; if (value.front() == '_' || isalpha(value.front())) { os << "Keyword(\"" << value << "\")"; } else { @@ -312,70 +325,265 @@ .Case("*", "Star") << "()"; } - os << ")\n"; + if (isOptional) + // Leave the `if` unclosed to guard optional groups. 
+ return; // Parser will emit an error - os << " return {};\n"; + os << ") return {};\n"; } -void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter ¶m, - FmtContext &ctx, MethodBody &os) { - /// Check for a custom parser. Use the default attribute parser otherwise. +void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx, + MethodBody &os) { + // Check for a custom parser. Use the default attribute parser otherwise. + const AttrOrTypeParameter ¶m = el->getParam(); auto customParser = param.getParser(); auto parser = customParser ? *customParser : StringRef(defaultParameterParser); os << formatv(variableParser, param.getName(), tgfmt(parser, &ctx, param.getCppStorageType()), - tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType()); + tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType()); } -void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, + MethodBody &os) { os << "// Parse parameter list\n"; - llvm::interleave( - el->getParams(), - [&](auto param) { this->genVariableParser(param, ctx, os); }, - [&]() { this->genLiteralParser(",", ctx, os); }); + + // If there are optional parameters, we need to switch to `parseOptionalComma` + // if there are no more required parameters after a certain point. + bool hasOptional = el->hasOptionalParams(); + if (hasOptional) { + // Wrap everything in a do-while so that we can `break`. + os << "do {\n"; + os.indent(); + } + + auto params = el->getParams(); + auto it = params.begin(); + using IteratorT = decltype(it); + + auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); }; + auto betweenFn = [&](IteratorT it, IteratorT end) { + ParameterElement *el = *std::prev(it); + // Parse a comma if the last optional parameter had a value. 
+ if (el->isOptional()) { + os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n", + el->getName()); + os.indent(); + } + if (std::any_of(it, end, paramNotOptional)) { + genLiteralParser(",", ctx, os); + } else { + genLiteralParser(",", ctx, os, /*isOptional=*/true); + os << ") break;\n"; + } + if (el->isOptional()) + os.unindent() << "}\n"; + }; + + // llvm::interleave + if (it != params.end()) { + eachFn(*it++); + for (auto e = params.end(); it != e; ++it) { + betweenFn(it, e); + eachFn(*it); + } + } + + if (hasOptional) + os.unindent() << "} while(false);\n"; } -void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, + MethodBody &os) { + // Loop declaration for struct parser with only required parameters. + // + // $0: Number of expected parameters. + const char *const loopHeader = R"( + for (unsigned _index = 0; _index < $0; ++_index) { +)"; + + // Loop body start for struct parser. + const char *const loopStart = R"( + ::llvm::StringRef _paramKey; + if ($_parser.parseKeyword(&_paramKey)) { + $_parser.emitError($_parser.getCurrentLocation(), + "expected a parameter name in struct"); + return {}; + } + if (!_loop_body(_paramKey)) return {}; +)"; + + // Struct parser loop end. Check for duplicate or unknown struct parameters. + // + // {0}: Code template for printing an error. + const char *const loopEnd = R"({{ + {0}"duplicate or unknown struct parameter name: ") << _paramKey; + return {{}; +} +)"; + + // Struct parser loop terminator. Parse a comma except on the last element. + // + // {0}: Number of elements in the struct. + const char *const loopTerminator = R"( + if ((_index != {0} - 1) && parser.parseComma()) + return {{}; +} +)"; + + // Check that a mandatory parameter was parse. + // + // {0}: Name of the parameter. 
+ const char *const checkParam = R"( + if (!_seen_{0}) {{ + {1}"struct is missing required parameter: ") << "{0}"; + return {{}; + } +)"; + + // Optional parameters in a struct must be parsed successfully if the + // keyword is present. + // + // {0}: Name of the parameter. + // {1}: Emit error string + const char *const checkOptionalParam = R"( + if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{ + {1}"expected a value for parameter '{0}'"); + return {{}; + } +)"; + + // First iteration of the loop parsing an optional struct. + const char *const optionalStructFirst = R"( + ::llvm::StringRef _paramKey; + if (!$_parser.parseOptionalKeyword(&_paramKey)) { + if (!_loop_body(_paramKey)) return {}; + while (!$_parser.parseOptionalComma()) { +)"; + os << "// Parse parameter struct\n"; - /// Declare a "seen" variable for each key. - for (const AttrOrTypeParameter ¶m : el->getParams()) - os << formatv("bool _seen_{0} = false;\n", param.getName()); + // Declare a "seen" variable for each key. + for (ParameterElement *param : el->getParams()) + os << formatv("bool _seen_{0} = false;\n", param->getName()); - /// Generate the parsing loop. - os.getStream().printReindented( - tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str()); - os.indent(); - genLiteralParser("=", ctx, os); - for (const AttrOrTypeParameter ¶m : el->getParams()) { + // Generate the body of the parsing loop inside a lambda.
+ os << "{\n"; + os.indent() + << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; + genLiteralParser("=", ctx, os.indent()); + for (ParameterElement *param : el->getParams()) { os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" " _seen_{0} = true;\n", - param.getName()); + param->getName()); genVariableParser(param, ctx, os.indent()); + if (param->isOptional()) { + os.getStream().printReindented(strfmt(checkOptionalParam, + param->getName(), + tgfmt(parserErrorStr, &ctx).str())); + } os.unindent() << "} else "; + // Print the check for duplicate or unknown parameter. } + os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx))); + os << "return true;\n"; + os.unindent() << "};\n"; + + // Generate the parsing loop. If optional parameters are present, then the + // parse loop is guarded by commas. + unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional); + if (numOptional) { + // If the struct itself is optional, pull out the first iteration. + if (numOptional == el->getNumParams()) { + os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str()); + os.indent(); + } else { + os << "do {\n"; + } + } else { + os.getStream().printReindented( + tgfmt(loopHeader, &ctx, el->getNumParams()).str()); + } + os.indent(); + os.getStream().printReindented(tgfmt(loopStart, &ctx).str()); os.unindent(); - /// Duplicate or unknown parameter. - os.getStream().printReindented(strfmt( - structParseLoopEnd, tgfmt(parseErrorStr, &ctx), el->getNumParams())); + // Print the loop terminator. For optional parameters, we have to check that + // all mandatory parameters have been parsed. + // The whole struct is optional if all its parameters are optional. 
+ if (numOptional) { + if (numOptional == el->getNumParams()) { + os << "}\n"; + os.unindent() << "}\n"; + } else { + os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx); + for (ParameterElement *param : el->getParams()) { + if (param->isOptional()) + continue; + os.getStream().printReindented( + strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx))); + } + } + } else { + // Because the loop loops N times and each non-failing iteration sets 1 of + // N flags, successfully exiting the loop means that all parameters have + // been seen. `parseOptionalComma` would cause issues with any formats that + // use "struct(...) `,`" because structs aren't surrounded by braces. + os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams())); + } + os.unindent() << "}\n"; +} - /// Because the loop loops N times and each non-failing iteration sets 1 of - /// N flags, successfully exiting the loop means that all parameters have been - /// seen. `parseOptionalComma` would cause issues with any formats that use - /// "struct(...) `,`" beacuse structs aren't sounded by braces.
+void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, + MethodBody &os) { + ArrayRef elements = + el->getThenElements().drop_front(el->getParseStart()); + + FormatElement *first = elements.front(); + const auto guardOn = [&](auto params) { + os << "if (!("; + llvm::interleave( + params, os, + [&](ParameterElement *el) { + os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})", + el->getName()); + }, + " || "); + os << ")) {\n"; + }; + if (auto *literal = dyn_cast(first)) { + genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true); + os << ") {\n"; + } else if (auto *param = dyn_cast(first)) { + genVariableParser(param, ctx, os); + guardOn(llvm::makeArrayRef(param)); + } else if (auto *params = dyn_cast(first)) { + genParamsParser(params, ctx, os); + guardOn(params->getParams()); + } else { + auto *strct = cast(first); + genStructParser(strct, ctx, os); + guardOn(strct->getParams()); + } + os.indent(); + // Group absent: parse the `else` elements; present: parse the rest of it. + for (FormatElement *element : el->getElseElements()) + genElementParser(element, ctx, os); + os.unindent() << "} else {\n"; + os.indent(); + for (FormatElement *element : elements.drop_front()) + genElementParser(element, ctx, os); + os.unindent() << "}\n"; } //===----------------------------------------------------------------------===// // PrinterGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genPrinter(MethodBody &os) { +void DefFormat::genPrinter(MethodBody &os) { FmtContext ctx; ctx.addSubst("_printer", "printer"); + os.indent(); /// Generate printers.
shouldEmitSpace = true; @@ -384,8 +592,8 @@ genElementPrinter(el, ctx, os); } -void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, - MethodBody &os) { +void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx, + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralPrinter(literal->getSpelling(), ctx, os); if (auto *params = dyn_cast(el)) @@ -393,63 +601,147 @@ if (auto *strct = dyn_cast(el)) return genStructPrinter(strct, ctx, os); if (auto *var = dyn_cast(el)) - return genVariablePrinter(var->getParam(), ctx, os, - var->shouldBeQualified()); + return genVariablePrinter(var, ctx, os); + if (auto *optional = dyn_cast(el)) + return genOptionalGroupPrinter(optional, ctx, os); - llvm_unreachable("unknown format element"); + llvm::PrintFatalError("unsupported format element"); } -void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, - MethodBody &os) { - /// Don't insert a space before certain punctuation. +void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, + MethodBody &os) { + // Don't insert a space before certain punctuation. bool needSpace = shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); - os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", + os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", value); - /// Update the flags. + // Update the flags. shouldEmitSpace = value.size() != 1 || !StringRef("<({[").contains(value.front()); lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); } -void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m, - FmtContext &ctx, MethodBody &os, - bool printQualified) { - /// Insert a space before the next parameter, if necessary. 
+void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, + MethodBody &os, bool skipGuard) { + const AttrOrTypeParameter ¶m = el->getParam(); + ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); + + // Guard the printer on the presence of optional parameters. + if (el->isOptional() && !skipGuard) { + os << tgfmt("if ($_self) {\n", &ctx); + os.indent(); + } + + // Insert a space before the next parameter, if necessary. if (shouldEmitSpace || !lastWasPunctuation) - os << tgfmt(" $_printer << ' ';\n", &ctx); + os << tgfmt("$_printer << ' ';\n", &ctx); shouldEmitSpace = true; lastWasPunctuation = false; - ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); - os << " "; - if (printQualified) + if (el->shouldBeQualified()) os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n"; else if (auto printer = param.getPrinter()) os << tgfmt(*printer, &ctx) << ";\n"; else os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; + + if (el->isOptional() && !skipGuard) + os.unindent() << "}\n"; } -void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, - MethodBody &os) { - llvm::interleave( - el->getParams(), - [&](auto param) { this->genVariablePrinter(param, ctx, os); }, - [&] { this->genLiteralPrinter(",", ctx, os); }); +void DefFormat::genCommaSeparatedPrinter( + ArrayRef params, FmtContext &ctx, MethodBody &os, + function_ref extra) { + // Emit a space if necessary, but only if the struct is present. + if (shouldEmitSpace || !lastWasPunctuation) { + bool allOptional = llvm::all_of(params, paramIsOptional); + if (allOptional) { + os << "if ("; + llvm::interleave( + params, os, + [&](ParameterElement *param) { + os << getParameterAccessorName(param->getName()) << "()"; + }, + " || "); + os << ") {\n"; + os.indent(); + } + os << tgfmt("$_printer << ' ';\n", &ctx); + if (allOptional) + os.unindent() << "}\n"; + } + + // The first printed element does not need to emit a comma. 
+ os << "{\n"; + os.indent() << "bool _firstPrinted = true;\n"; + for (ParameterElement *param : params) { + if (param->isOptional()) { + os << tgfmt("if ($_self()) {\n", + &ctx.withSelf(getParameterAccessorName(param->getName()))); + os.indent(); + } + os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx); + os << "_firstPrinted = false;\n"; + extra(param); + shouldEmitSpace = false; + lastWasPunctuation = true; + genVariablePrinter(param, ctx, os); + if (param->isOptional()) + os.unindent() << "}\n"; + } + os.unindent() << "}\n"; +} + +void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, + MethodBody &os) { + genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os, + [&](ParameterElement *param) {}); } -void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, +void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, + MethodBody &os) { + genCommaSeparatedPrinter( + llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) { + os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); + }); +} + +void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, MethodBody &os) { - llvm::interleave( - el->getParams(), - [&](auto param) { - this->genLiteralPrinter(param.getName(), ctx, os); - this->genLiteralPrinter("=", ctx, os); - this->genVariablePrinter(param, ctx, os); - }, - [&] { this->genLiteralPrinter(",", ctx, os); }); + // Emit the check on whether the group should be printed. 
+ const auto guardOn = [&](auto params) { + os << "if ("; + llvm::interleave( + params, os, + [&](ParameterElement *el) { + os << getParameterAccessorName(el->getName()) << "()"; + }, + " || "); + os << ") {\n"; + os.indent(); + }; + FormatElement *anchor = el->getAnchor(); + if (auto *param = dyn_cast(anchor)) { + guardOn(llvm::makeArrayRef(param)); + } else if (auto *params = dyn_cast(anchor)) { + guardOn(params->getParams()); + } else { + auto *strct = dyn_cast(anchor); + guardOn(strct->getParams()); + } + // Generate the printer for the contained elements. + { + llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace), + lastWasPunctuationFlag(lastWasPunctuation); + for (FormatElement *element : el->getThenElements()) + genElementPrinter(element, ctx, os); + } + os.unindent() << "} else {\n"; + os.indent(); + for (FormatElement *element : el->getElseElements()) + genElementPrinter(element, ctx, os); + os.unindent() << "}\n"; } //===----------------------------------------------------------------------===// @@ -464,7 +756,7 @@ seenParams(def.getNumParameters()) {} /// Parse the attribute or type format and create the format elements. - FailureOr parse(); + FailureOr parse(); protected: /// Verify the parsed elements. @@ -480,9 +772,7 @@ LogicalResult verifyOptionalGroupElements(llvm::SMLoc loc, ArrayRef elements, - Optional anchorIndex) override { - return emitError(loc, "optional groups not (yet) supported"); - } + Optional anchorIndex) override; /// Parse an attribute or type variable. FailureOr parseVariableImpl(llvm::SMLoc loc, StringRef name, @@ -496,8 +786,8 @@ /// Parse a `params` directive. FailureOr parseParamsDirective(llvm::SMLoc loc); /// Parse a `qualified` directive. - FailureOr - parseQualifiedDirective(llvm::SMLoc loc, Context ctx); + FailureOr parseQualifiedDirective(llvm::SMLoc loc, + Context ctx); /// Parse a `struct` directive. 
FailureOr parseStructDirective(llvm::SMLoc loc); @@ -511,30 +801,77 @@ LogicalResult DefFormatParser::verify(llvm::SMLoc loc, ArrayRef elements) { + // Check that all parameters are referenced in the format. for (auto &it : llvm::enumerate(def.getParameters())) { - if (!seenParams.test(it.index())) { + if (!it.value().isOptional() && !seenParams.test(it.index())) { return emitError(loc, "format is missing reference to parameter: " + - it.value().getName()); + it.value().getName()); + } + } + // A `struct` directive that contains optional parameters cannot be followed + // by a comma literal, which is ambiguous. + for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) { + auto *structEl = dyn_cast(std::get<0>(it)); + auto *literalEl = dyn_cast(std::get<1>(it)); + if (!structEl || !literalEl) + continue; + if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) { + return emitError(loc, "`struct` directive with optional parameters " + "cannot be followed by a comma literal"); + } + } + return success(); +} + +LogicalResult +DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, + ArrayRef elements, + Optional anchorIndex) { + // `params` and `struct` directives are allowed only if all the contained + // parameters are optional. 
+ for (FormatElement *el : elements) { + if (auto *param = dyn_cast(el)) { + if (!param->isOptional()) { + return emitError(loc, + "parameters in an optional group must be optional"); + } + } else if (auto *params = dyn_cast(el)) { + if (llvm::any_of(params->getParams(), paramNotOptional)) { + return emitError(loc, "`params` directive allowed in optional group " + "only if all parameters are optional"); + } + } else if (auto *strct = dyn_cast(el)) { + if (llvm::any_of(strct->getParams(), paramNotOptional)) { + return emitError(loc, "`struct` is only allowed in an optional group " + "if all captured parameters are optional"); + } } } + // The anchor must be a parameter or one of the aforementioned directives. + if (anchorIndex && !isa( + elements[*anchorIndex])) { + return emitError(loc, + "optional group anchor must be a parameter or directive"); + } return success(); } -FailureOr DefFormatParser::parse() { +FailureOr DefFormatParser::parse() { FailureOr> elements = FormatParser::parse(); - if (failed(elements)) return failure(); - return AttrOrTypeFormat(def, std::move(*elements)); + if (failed(elements)) + return failure(); + return DefFormat(def, std::move(*elements)); } FailureOr DefFormatParser::parseVariableImpl(llvm::SMLoc loc, StringRef name, Context ctx) { - /// Lookup the parameter. + // Lookup the parameter. ArrayRef params = def.getParameters(); auto *it = llvm::find_if( params, [&](auto ¶m) { return param.getName() == name; }); - /// Check that the parameter reference is valid. + // Check that the parameter reference is valid. 
if (it == params.end()) { return emitError(loc, def.getName() + " has no parameter named '" + name + "'"); @@ -558,7 +895,8 @@ return parseParamsDirective(loc); case FormatToken::kw_struct: if (ctx != TopLevelContext) { - return emitError(loc, + return emitError( + loc, "`struct` may only be used in the top-level section of the format"); } return parseStructDirective(loc); @@ -587,13 +925,13 @@ FailureOr DefFormatParser::parseParamsDirective(llvm::SMLoc loc) { - /// Collect all of the attribute's or type's parameters. + // Collect all of the attribute's or type's parameters. std::vector vars; - /// Ensure that none of the parameters have already been captured. + // Ensure that none of the parameters have already been captured. for (const auto &it : llvm::enumerate(def.getParameters())) { if (seenParams.test(it.index())) { return emitError(loc, "`params` captures duplicate parameter: " + - it.value().getName()); + it.value().getName()); } seenParams.set(it.index()); vars.push_back(create(it.value())); @@ -607,17 +945,17 @@ "expected '(' before `struct` argument list"))) return failure(); - /// Parse variables captured by `struct`. + // Parse variables captured by `struct`. std::vector vars; - /// Parse first captured parameter or a `params` directive. + // Parse first captured parameter or a `params` directive. FailureOr var = parseElement(StructDirectiveContext); if (failed(var) || !isa(*var)) { return emitError(loc, "`struct` argument list expected a variable or directive"); } if (isa(*var)) { - /// Parse any other parameters. + // Parse any other parameters. vars.push_back(std::move(*var)); while (peekToken().is(FormatToken::comma)) { consumeToken(); @@ -627,12 +965,12 @@ vars.push_back(std::move(*var)); } } else { - /// `struct(params)` captures all parameters in the attribute or type. + // `struct(params)` captures all parameters in the attribute or type. 
vars = cast(*var)->takeParams(); } if (failed(parseToken(FormatToken::r_paren, - "expected ')' at the end of an argument list"))) + "expected ')' at the end of an argument list"))) return failure(); return create(std::move(vars)); @@ -650,16 +988,16 @@ llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), llvm::SMLoc()); - /// Parse the custom assembly format> + // Parse the custom assembly format> DefFormatParser fmtParser(mgr, def); - FailureOr format = fmtParser.parse(); + FailureOr format = fmtParser.parse(); if (failed(format)) { if (formatErrorIsFatal) PrintFatalError(def.getLoc(), "failed to parse assembly format"); return; } - /// Generate the parser and printer. + // Generate the parser and printer. format->genParser(parser); format->genPrinter(printer); }