diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -105,6 +105,24 @@ } }; +/// Parse an Optional attribute. Yields std::nullopt when no attribute is present, the parsed attribute on success, and failure when an attribute is present but fails to parse. +template +struct FieldParser< + std::optional, + std::enable_if_t::value, + std::optional>> { + static FailureOr> parse(AsmParser &parser) { + AttributeT attr; + OptionalParseResult result = parser.parseOptionalAttribute(attr); + if (result.has_value()) { + if (succeeded(*result)) + return {std::optional(attr)}; + return failure(); + } + return {std::nullopt}; + } +}; + /// Parse an Optional integer. template struct FieldParser< diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -218,9 +218,14 @@ } def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> { - let parameters = (ins OptionalParameter<"mlir::Optional">:$a, "int":$b); + let parameters = (ins + OptionalParameter<"mlir::Optional">:$a, + "int":$b, + DefaultValuedParameter<"std::optional<::mlir::Attribute>", + "std::nullopt">:$c + ); let mnemonic = "optional_param"; - let assemblyFormat = "`<` $a `,` $b `>`"; + let assemblyFormat = "`<` $a `,` $b ( `,` $c^)? 
`>`"; } def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> { diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -34,6 +34,8 @@ // CHECK: !test.struct_capture_all // CHECK: !test.optional_param<, 6> // CHECK: !test.optional_param<5, 6> +// CHECK: !test.optional_param<5, 6, "foo"> +// CHECK: !test.optional_param<5, 6, {foo = "bar"}> // CHECK: !test.optional_params<"a"> // CHECK: !test.optional_params<5, "a"> // CHECK: !test.optional_struct @@ -72,6 +74,8 @@ !test.struct_capture_all, !test.optional_param<, 6>, !test.optional_param<5, 6>, + !test.optional_param<5, 6, "foo">, + !test.optional_param<5, 6, {foo = "bar"}>, !test.optional_params<"a">, !test.optional_params<5, "a">, !test.optional_struct,