diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -382,3 +382,172 @@ Aside from that, all of the interfaces for uniquing and storage construction are all the same. + +## Defining Custom Parsers and Printers using Assembly Formats + +Attributes and types defined in ODS with a mnemonic can define an +`assemblyFormat` to declaratively describe custom parsers and printers. The +assembly format consists of literals, variables, and directives. + +* A literal is a keyword or valid punctuation enclosed in backticks, e.g. + `` `keyword` `` or `` `<` ``. +* A variable is a parameter name preceded by a dollar sign, e.g. `$param0`, + which captures one attribute or type parameter. +* A directive is a keyword followed by an optional argument list that defines + special parser and printer behaviour. + +```tablegen +// An example type with an assembly format. +def MyType : TypeDef { + // Define a mnemonic to allow the dialect's parser hook to call into the + // generated parser. + let mnemonic = "my_type"; + + // Define two parameters whose C++ types are indicated in string literals. + let parameters = (ins "int":$count, "AffineMap":$map); + + // Define the assembly format. Surround the format with less-than `<` and + // greater-than `>` so that MLIR's printers use the pretty format. + let assemblyFormat = "`<` $count `,` `map` `=` $map `>`"; +} +``` + +The declarative assembly format for `MyType` results in the following format +in the IR: + +```mlir +!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>> +``` + +### Parameter Parsing and Printing + +For many basic parameter types, no additional work is needed to define how +these parameters are parsed or printed.
+ +* The default printer for any parameter is `$_printer << $_self`, + where `$_self` is the C++ value of the parameter and `$_printer` is a + `DialectAsmPrinter`. +* The default parser for a parameter is + `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ storage + type of the parameter (its C++ type unless a `cppStorageType` is set) and + `$_parser` is a `DialectAsmParser`. + +Printing and parsing behaviour can be added to additional C++ types by +overloading the `DialectAsmPrinter` stream operator and specializing +`FieldParser`, or by defining a `parser` and `printer` in an ODS parameter +class. + +Example of overloading and specialization: + +```c++ +using MyParameter = std::pair<int, int>; + +DialectAsmPrinter &operator<<(DialectAsmPrinter &printer, MyParameter param) { + printer << param.first << " * " << param.second; + return printer; +} + +template <> struct FieldParser<MyParameter> { + static FailureOr<MyParameter> parse(DialectAsmParser &parser) { + int a, b; + if (parser.parseInteger(a) || parser.parseStar() || + parser.parseInteger(b)) + return failure(); + return MyParameter(a, b); + } +}; +``` + +Example of using ODS parameter classes: + +```tablegen +def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> { + let printer = [{ $_printer << $_self.first << " * " << $_self.second }]; + let parser = [{ [&]() -> FailureOr<std::pair<int, int>> { + int a, b; + if ($_parser.parseInteger(a) || $_parser.parseStar() || + $_parser.parseInteger(b)) + return failure(); + return std::make_pair(a, b); + }() }]; +} +``` + +A type using this parameter with the assembly format `` `<` $myParam `>` `` +will look as follows in the IR: + +```mlir +!my_dialect.my_type<42 * 24> +``` + +#### Non-POD Parameters + +Parameters that aren't plain-old-data (e.g. references) may need to define a +`cppStorageType` to contain the data until it is copied into the allocator. 
+For example, `StringRefParameter` uses `std::string` as its storage type, +whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers +for these parameters are expected to return `FailureOr<$cppStorageType>`.
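The `FieldParser` dispatch described above can be modelled outside MLIR. Below is a minimal, self-contained C++17 sketch of the `FieldParser<T>::parse` pattern, not MLIR's implementation: `ToyParser` is a hypothetical stand-in for `DialectAsmParser`, and `std::optional` takes the place of `FailureOr`. It mirrors the integer, string, and container specializations that this patch adds to `DialectImplementation.h`, plus a user specialization for the `MyParameter` pair. The string specialization returns an owning `std::string`, in the spirit of a `cppStorageType`.

```cpp
// Sketch of the FieldParser<T>::parse dispatch pattern (C++17). ToyParser is a
// hypothetical stand-in for mlir::DialectAsmParser; std::optional replaces FailureOr.
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

class ToyParser {
public:
  explicit ToyParser(std::string text) : text_(std::move(text)) {}
  // Skips spaces, then consumes `c` if it is the next character.
  bool consume(char c) {
    skipSpaces();
    if (pos_ >= text_.size() || text_[pos_] != c)
      return false;
    ++pos_;
    return true;
  }
  // Parses an optionally negative decimal integer.
  std::optional<long long> parseInteger() {
    skipSpaces();
    std::size_t digits = pos_ + (pos_ < text_.size() && text_[pos_] == '-' ? 1 : 0);
    std::size_t end = digits;
    while (end < text_.size() &&
           std::isdigit(static_cast<unsigned char>(text_[end])))
      ++end;
    if (end == digits)
      return std::nullopt; // No digits: the cursor is not advanced.
    long long value = std::stoll(text_.substr(pos_, end - pos_));
    pos_ = end;
    return value;
  }
  // Parses a double-quoted string (no escape handling).
  std::optional<std::string> parseString() {
    std::size_t end = std::string::npos;
    if (!consume('"') || (end = text_.find('"', pos_)) == std::string::npos)
      return std::nullopt;
    std::string value = text_.substr(pos_, end - pos_);
    pos_ = end + 1;
    return value;
  }
private:
  void skipSpaces() {
    while (pos_ < text_.size() && text_[pos_] == ' ')
      ++pos_;
  }
  std::string text_;
  std::size_t pos_ = 0;
};

// Declared but never defined: a type with no specialization fails to compile.
template <typename T, typename = void> struct FieldParser;

// Any integral type.
template <typename IntT>
struct FieldParser<IntT, std::enable_if_t<std::is_integral<IntT>::value>> {
  static std::optional<IntT> parse(ToyParser &p) {
    if (std::optional<long long> v = p.parseInteger())
      return static_cast<IntT>(*v);
    return std::nullopt;
  }
};

// Strings parse into an owning std::string, i.e. a storage type, not a view.
template <> struct FieldParser<std::string> {
  static std::optional<std::string> parse(ToyParser &p) { return p.parseString(); }
};

// A comma-separated list of elements. The patch accepts any back-insertable
// container; this sketch handles only std::vector for brevity.
template <typename ElementT> struct FieldParser<std::vector<ElementT>> {
  static std::optional<std::vector<ElementT>> parse(ToyParser &p) {
    std::vector<ElementT> elements;
    do {
      std::optional<ElementT> element = FieldParser<ElementT>::parse(p);
      if (!element)
        return std::nullopt;
      elements.push_back(std::move(*element));
    } while (p.consume(','));
    return elements;
  }
};

// User-provided specialization for the `MyParameter` pair from the example.
using MyParameter = std::pair<int, int>;
template <> struct FieldParser<MyParameter> {
  static std::optional<MyParameter> parse(ToyParser &p) {
    std::optional<int> a = FieldParser<int>::parse(p);
    if (!a || !p.consume('*'))
      return std::nullopt;
    if (std::optional<int> b = FieldParser<int>::parse(p))
      return MyParameter(*a, *b);
    return std::nullopt;
  }
};
```

For instance, `FieldParser<std::vector<long>>::parse` on the input `1, -2, 3` yields three elements, and `FieldParser<MyParameter>::parse` on `42 * 24` yields the pair (42, 24). Because the primary template in this sketch is declared but never defined, requesting a parameter type that has no specialization fails at compile time rather than at parse time.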
+ +### Assembly Format Directives + +Attribute and type assembly formats have the following directives: + +* `params`: capture all parameters of an attribute or type. +* `struct`: generate a "struct-like" parser and printer for a list of key-value + pairs. + +#### `params` Directive + +This directive is used to refer to all parameters of an attribute or type. +When used as a top-level directive, `params` generates a parser and printer for +a comma-separated list of the parameters. For example: + +```tablegen +def MyPairType : TypeDef { + let parameters = (ins "int":$a, "int":$b); + let mnemonic = "pair"; + let assemblyFormat = "`<` params `>`"; +} +``` + +In the IR, this type will appear as: + +```mlir +!my_dialect.pair<42, 24> +``` + +The `params` directive can also be passed to other directives, such as `struct`, +as an argument that refers to all parameters in place of explicitly listing all +parameters as variables. + +#### `struct` Directive + +The `struct` directive accepts a list of variables to capture and will generate +a parser and printer for a comma-separated list of key-value pairs. The +variables are printed in the order they are specified in the argument list **but +can be parsed in any order**. For example: + +```tablegen +def MyStructType : TypeDef { + let parameters = (ins StringRefParameter<>:$sym_name, + "int":$a, "int":$b, "int":$c); + let mnemonic = "struct"; + let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`"; +} +``` + +In the IR, this type can appear with any permutation of the order of the +parameters captured in the directive. + +```mlir +!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3> +!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1> +``` + +Passing `params` as the only argument to `struct` makes the directive capture +all the parameters of the attribute or type. 
For the same type above, an +assembly format of `` `<` struct(params) `>` `` will result in: + +```mlir +!my_dialect.struct +``` + +The order in which the parameters are printed is the order in which they are +declared in the attribute's or type's `parameters` list. diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -47,6 +47,74 @@ virtual StringRef getFullSymbolSpec() const = 0; }; +//===----------------------------------------------------------------------===// +// Parse Fields +//===----------------------------------------------------------------------===// + +/// Provide a template class that can be specialized by users to dispatch to +/// parsers. Auto-generated parsers generate calls to `FieldParser::parse`, +/// where `T` is the parameter storage type, to parse custom types. +template +struct FieldParser; + +/// Parse an attribute. +template +struct FieldParser< + AttributeT, std::enable_if_t::value, + AttributeT>> { + static FailureOr parse(DialectAsmParser &parser) { + AttributeT value; + if (parser.parseAttribute(value)) + return failure(); + return value; + } +}; + +/// Parse any integer. +template +struct FieldParser::value, IntT>> { + static FailureOr parse(DialectAsmParser &parser) { + IntT value; + if (parser.parseInteger(value)) + return failure(); + return value; + } +}; + +/// Parse a string. +template <> +struct FieldParser { + static FailureOr parse(DialectAsmParser &parser) { + std::string value; + if (parser.parseString(&value)) + return failure(); + return value; + } +}; + +/// Parse any container that supports back insertion as a list.
+template +struct FieldParser< + ContainerT, std::enable_if_t::value, + ContainerT>> { + using ElementT = typename ContainerT::value_type; + static FailureOr parse(DialectAsmParser &parser) { + ContainerT elements; + auto elementParser = [&]() { + auto element = FieldParser::parse(parser); + if (failed(element)) + return failure(); + elements.push_back(element.getValue()); + return success(); + }; + if (parser.parseCommaSeparatedList(elementParser)) + return failure(); + return elements; + } +}; + } // end namespace mlir -#endif +#endif // MLIR_IR_DIALECTIMPLEMENTATION_H diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2876,6 +2876,11 @@ code printer = ?; code parser = ?; + // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be + // specified at the same time as either 'printer' or 'parser'. The generated + // printer requires 'genAccessors' to be true. + string assemblyFormat = ?; + // If set, generate accessors for each parameter. bit genAccessors = 1; @@ -2954,10 +2959,22 @@ string cppType = type; // The C++ type of the accessor for this parameter. string cppAccessorType = !if(!empty(accessorType), type, accessorType); + // The C++ storage type of this parameter if it is a reference, e.g. + // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`. + string cppStorageType = ?; // One-line human-readable description of the argument. string summary = desc; // The format string for the asm syntax (documentation only). string syntax = ?; + // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`, + // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing + // for your type, or provide a custom parser. For attributes, "$_type" will + // be replaced with the required attribute type. + string parser = ?; + // The default parameter printer is `$_printer << $_self`.
Overload the stream + // operator of `DialectAsmPrinter` as necessary to print your type. Or you can + // provide a custom printer. + string printer = ?; } class AttrParameter : AttrOrTypeParameter; @@ -2968,6 +2985,8 @@ class StringRefParameter : AttrOrTypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; + let printer = [{$_printer << '"' << $_self << '"';}]; + let cppStorageType = "std::string"; } // For APFloats, which require comparison. @@ -2980,6 +2999,7 @@ class ArrayRefParameter : AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; + let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">"; } // For classes which require allocation and have their own allocateInto method. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -182,10 +182,10 @@ llvm::interleaveComma(types, p); return p; } -template +template inline std::enable_if_t::value, AsmPrinterT &> -operator<<(AsmPrinterT &p, ArrayRef types) { +operator<<(AsmPrinterT &p, ArrayRef types) { llvm::interleaveComma(types, p); return p; } diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -101,6 +101,9 @@ // None. Otherwise, returns the contents of that code block. Optional getParserCode() const; + // Returns the custom assembly format, if one was specified. + Optional getAssemblyFormat() const; + // Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; @@ -199,6 +202,15 @@ // Get the C++ accessor type of this parameter. StringRef getCppAccessorType() const; + // Get the C++ storage type of this parameter. 
+ StringRef getCppStorageType() const; + + // Get an optional C++ parameter parser. + Optional getParser() const; + + // Get an optional C++ parameter printer. + Optional getPrinter() const; + // Get a description of this parameter for documentation purposes. Optional getSummary() const; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -1,3 +1,4 @@ +//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -132,6 +132,10 @@ return def->getValueAsOptionalString("parser"); } +Optional AttrOrTypeDef::getAssemblyFormat() const { + return def->getValueAsOptionalString("assemblyFormat"); +} + bool AttrOrTypeDef::genAccessors() const { return def->getValueAsBit("genAccessors"); } @@ -219,6 +223,32 @@ return getCppType(); } +StringRef AttrOrTypeParameter::getCppStorageType() const { + if (auto *param = dyn_cast(def->getArg(index))) { + if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType")) + return *type; + } + return getCppType(); +} + +Optional AttrOrTypeParameter::getParser() const { + auto *parameterType = def->getArg(index); + if (auto *param = dyn_cast(parameterType)) { + if (auto parser = param->getDef()->getValueAsOptionalString("parser")) + return *parser; + } + return {}; +} + +Optional AttrOrTypeParameter::getPrinter() const { + auto *parameterType = def->getArg(index); + if (auto *param = dyn_cast(parameterType)) { + if (auto printer = param->getDef()->getValueAsOptionalString("printer")) + return *printer; + } + return {}; +} + Optional AttrOrTypeParameter::getSummary() const 
{ auto *parameterType = def->getArg(index); if (auto *param = dyn_cast(parameterType)) { diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -101,4 +101,44 @@ let genVerifyDecl = 1; } +def TestParamOne : AttrParameter<"int64_t", ""> {} + +def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> { + let printer = "$_printer << '\"' << $_self << '\"'"; +} + +def TestParamFour : ArrayRefParameter<"int", ""> { + let cppStorageType = "llvm::SmallVector"; + let parser = "::parseIntArray($_parser)"; + let printer = "::printIntArray($_printer, $_self)"; +} + +def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> { + let parameters = ( + ins + TestParamOne:$one, + TestParamTwo:$two, + "::mlir::IntegerAttr":$three, + TestParamFour:$four + ); + + let mnemonic = "attr_with_format"; + let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`"; + let genVerifyDecl = 1; +} + +def TestAttrUgly : Test_Attr<"TestAttrUgly"> { + let parameters = (ins "::mlir::Attribute":$attr); + + let mnemonic = "attr_ugly"; + let assemblyFormat = "`begin` $attr `end`"; +} + +def TestAttrParams: Test_Attr<"TestAttrParams"> { + let parameters = (ins "int":$v0, "int":$v1); + + let mnemonic = "attr_params"; + let assemblyFormat = "`<` params `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -16,9 +16,11 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/bit.h" using namespace mlir; using namespace test; @@ -127,6 +129,36 @@ 
return success(); } +LogicalResult +TestAttrWithFormatAttr::verify(function_ref emitError, + int64_t one, std::string two, IntegerAttr three, + ArrayRef four) { + if (four.size() != static_cast(one)) + return emitError() << "expected 'one' to equal 'four.size()'"; + return success(); +} + +//===----------------------------------------------------------------------===// +// Utility Functions for Generated Attributes +//===----------------------------------------------------------------------===// + +static FailureOr> parseIntArray(DialectAsmParser &parser) { + SmallVector ints; + if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() { + ints.push_back(0); + return parser.parseInteger(ints.back()); + }) || + parser.parseRSquare()) + return failure(); + return ints; +} + +static void printIntArray(DialectAsmPrinter &printer, ArrayRef ints) { + printer << '['; + llvm::interleaveComma(ints, printer); + printer << ']'; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -15,6 +15,7 @@ // To get the test dialect def. include "TestOps.td" +include "TestAttrDefs.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" @@ -189,4 +190,44 @@ let mnemonic = "test_type_with_trait"; } +// Type with assembly format. 
+def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> { + let parameters = ( + ins + TestParamOne:$one, + TestParamTwo:$two, + "::mlir::Attribute":$three + ); + + let mnemonic = "type_with_format"; + let assemblyFormat = "`<` $one `,` struct($three, $two) `>`"; +} + +// Test dispatch to parseField +def TestTypeNoParser : Test_Type<"TestTypeNoParser"> { + let parameters = ( + ins + "uint32_t":$one, + ArrayRefParameter<"int64_t">:$two, + StringRefParameter<>:$three, + "::test::CustomParam":$four + ); + + let mnemonic = "no_parser"; + let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`"; +} + +def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> { + let parameters = ( + ins + "int":$v0, + "int":$v1, + "int":$v2, + "int":$v3 + ); + + let mnemonic = "struct_capture_all"; + let assemblyFormat = "`<` struct(params) `>`"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -38,8 +38,38 @@ } }; +/// A custom type for a test type parameter. 
+struct CustomParam { + int value; + + bool operator==(const CustomParam &other) const { + return other.value == value; + } +}; + +inline llvm::hash_code hash_value(const test::CustomParam &param) { + return llvm::hash_value(param.value); +} + } // namespace test +namespace mlir { +template <> +struct FieldParser { + static FailureOr parse(DialectAsmParser &parser) { + auto value = FieldParser::parse(parser); + if (failed(value)) + return failure(); + return test::CustomParam{value.getValue()}; + } +}; +} // end namespace mlir + +inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer, + const test::CustomParam &param) { + return printer << param.value; +} + #include "TestTypeInterfaces.h.inc" #define GET_TYPEDEF_CLASSES @@ -52,17 +82,19 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage { using KeyTy = ::llvm::StringRef; - explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {} + explicit TestRecursiveTypeStorage(::llvm::StringRef key) + : name(key), body(::mlir::Type()) {} bool operator==(const KeyTy &other) const { return name == other; } - static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { + static TestRecursiveTypeStorage * + construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) TestRecursiveTypeStorage(allocator.copyInto(key)); } - ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) { + ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, + ::mlir::Type newBody) { // Cannot set a different body than before. if (body && body != newBody) return ::mlir::failure(); @@ -79,11 +111,13 @@ /// type, potentially itself. This requires the body to be mutated separately /// from type creation.
class TestRecursiveType - : public ::mlir::Type::TypeBase { + : public ::mlir::Type::TypeBase { public: using Base::Base; - static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) { + static TestRecursiveType get(::mlir::MLIRContext *ctx, + ::llvm::StringRef name) { return Base::get(ctx, name); } diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -0,0 +1,76 @@ +// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "TestDialect"; + let cppNamespace = "::test"; +} + +class InvalidType : TypeDef { + let mnemonic = asm; +} + +/// Test format is missing a parameter capture. +def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> { + let parameters = (ins "int":$v0, "int":$v1); + // CHECK: format is missing reference to parameter: v1 + let assemblyFormat = "`<` $v0 `>`"; +} + +/// Test format has duplicate parameter captures. +def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> { + let parameters = (ins "int":$v0, "int":$v1); + // CHECK: duplicate parameter 'v0' + let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`"; +} + +/// Test format has invalid syntax. +def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> { + let parameters = (ins "int":$v0, "int":$v1); + // CHECK: expected literal, directive, or variable + let assemblyFormat = "`<` $v0, $v1 `>`"; +} + +/// Test struct directive has invalid syntax. 
+def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> { + let parameters = (ins "int":$v0); + // CHECK: literals may only be used in the top-level section of the format + // CHECK: expected a variable in `struct` argument list + let assemblyFormat = "`<` struct($v0, `,`) `>`"; +} + +/// Test struct directive cannot capture zero parameters. +def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> { + let parameters = (ins "int":$v0); + // CHECK: `struct` argument list expected a variable or directive + let assemblyFormat = "`<` struct() $v0 `>`"; +} + +/// Test capture parameter that does not exist. +def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> { + let parameters = (ins "int":$v0); + // CHECK: InvalidTypeF has no parameter named 'v1' + let assemblyFormat = "`<` $v0 $v1 `>`"; +} + +/// Test duplicate capture of parameter in capture-all struct. +def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> { + let parameters = (ins "int":$v0, "int":$v1, "int":$v2); + // CHECK: duplicate parameter 'v0' + let assemblyFormat = "`<` struct(params) $v0 `>`"; +} + +/// Test capture-all struct duplicate capture. +def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> { + let parameters = (ins "int":$v0, "int":$v1, "int":$v2); + // CHECK: `params` captures duplicate parameter: v0 + let assemblyFormat = "`<` $v0 struct(params) `>`"; +} + +/// Test capture of parameter after `params` directive. 
+def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> { + let parameters = (ins "int":$v0); + // CHECK: duplicate parameter 'v0' + let assemblyFormat = "`<` params $v0 `>`"; +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: @test_roundtrip_parameter_parsers +// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo"> +// CHECK: !test.type_with_format<2147, three = "hi", two = "hi"> +func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi"> +attributes { + // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64> + attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>, + // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8> + attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>, + // CHECK: #test<"attr_ugly begin 5 : index end"> + attr2 = #test<"attr_ugly begin 5 : index end">, + // CHECK: #test.attr_params<42, 24> + attr3 = #test.attr_params<42, 24> +} + +// CHECK-LABEL: @test_roundtrip_default_parsers_struct +// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> +// CHECK: !test.struct_capture_all +func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.mlir b/mlir/test/mlir-tblgen/attr-or-type-format.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/attr-or-type-format.mlir @@ -0,0 +1,127 @@ +// RUN: mlir-opt --split-input-file %s --verify-diagnostics + +func 
private @test_ugly_attr_cannot_be_pretty() -> () attributes { + // expected-error@+1 {{expected 'begin'}} + attr = #test.attr_ugly +} + +// ----- + +func private @test_ugly_attr_no_mnemonic() -> () attributes { + // expected-error@+1 {{expected valid keyword}} + attr = #test<""> +} + +// ----- + +func private @test_ugly_attr_parser_dispatch() -> () attributes { + // expected-error@+1 {{expected 'begin'}} + attr = #test<"attr_ugly"> +} + +// ----- + +func private @test_ugly_attr_missing_parameter() -> () attributes { + // expected-error@+2 {{failed to parse TestAttrUgly parameter 'attr'}} + // expected-error@+1 {{expected non-function type}} + attr = #test<"attr_ugly begin"> +} + +// ----- + +func private @test_ugly_attr_missing_literal() -> () attributes { + // expected-error@+1 {{expected 'end'}} + attr = #test<"attr_ugly begin \"string_attr\""> +} + +// ----- + +func private @test_pretty_attr_expects_less() -> () attributes { + // expected-error@+1 {{expected '<'}} + attr = #test.attr_with_format +} + +// ----- + +func private @test_pretty_attr_missing_param() -> () attributes { + // expected-error@+2 {{expected integer value}} + // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}} + attr = #test.attr_with_format<> +} + +// ----- + +func private @test_parse_invalid_param() -> () attributes { + // Test parameter parser failure is propagated + // expected-error@+2 {{expected integer value}} + // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}} + attr = #test.attr_with_format<"hi"> +} + +// ----- + +func private @test_pretty_attr_invalid_syntax() -> () attributes { + // expected-error@+1 {{expected ':'}} + attr = #test.attr_with_format<42> +} + +// ----- + +func private @test_struct_missing_key() -> () attributes { + // expected-error@+2 {{expected valid keyword}} + // expected-error@+1 {{expected a parameter name in struct}} + attr = #test.attr_with_format<42 :> +} + +// ----- + +func private @test_struct_unknown_key() 
-> () attributes { + // expected-error@+1 {{duplicate or unknown struct parameter}} + attr = #test.attr_with_format<42 : nine = "foo"> +} + +// ----- + +func private @test_struct_duplicate_key() -> () attributes { + // expected-error@+1 {{duplicate or unknown struct parameter}} + attr = #test.attr_with_format<42 : two = "foo", two = "bar"> +} + +// ----- + +func private @test_struct_not_enough_values() -> () attributes { + // expected-error@+1 {{expected ','}} + attr = #test.attr_with_format<42 : two = "foo"> +} + +// ----- + +func private @test_parse_param_after_struct() -> () attributes { + // expected-error@+2 {{expected non-function type}} + // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'three'}} + attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : > +} + +// ----- + +// expected-error@+1 {{expected '<'}} +func private @test_invalid_type() -> !test.type_with_format + +// ----- + +// expected-error@+2 {{expected integer value}} +// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'one'}} +func private @test_pretty_type_invalid_param() -> !test.type_with_format<> + +// ----- + +// expected-error@+2 {{expected ':'}} +// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'three'}} +func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>> + +// ----- + +func private @test_verifier_fails() -> () attributes { + // expected-error@+1 {{expected 'one' to equal 'four.size()'}} + attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64> +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -0,0 +1,394 @@ +// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR +// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s 
--check-prefix=TYPE + +include "mlir/IR/OpBase.td" + +/// Test that attribute and type printers and parsers are correctly generated. +def Test_Dialect : Dialect { + let name = "TestDialect"; + let cppNamespace = "::test"; +} + +class TestAttr : AttrDef; +class TestType : TypeDef; + +def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> { + let parser = "::parseAttrParamA($_parser, $_type)"; + let printer = "::printAttrParamA($_printer, $_self)"; +} + +def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> { + let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)"; + let printer = "::printAttrB($_printer, $_self)"; +} + +def TypeParamA : TypeParameter<"TestParamC", "a type param C"> { + let parser = "::parseTypeParamC($_parser)"; + let printer = "$_printer << $_self"; +} + +def TypeParamB : TypeParameter<"TestParamD", "a type param D"> { + let parser = "someFcnCall()"; + let printer = "myPrinter($_self)"; +} + +/// Check simple attribute parser and printer are generated correctly. 
+ +// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::DialectAsmParser &parser, +// ATTR: ::mlir::Type attrType) { +// ATTR: FailureOr _result_value; +// ATTR: FailureOr _result_complex; +// ATTR: if (parser.parseKeyword("hello")) +// ATTR: return {}; +// ATTR: if (parser.parseEqual()) +// ATTR: return {}; +// ATTR: _result_value = ::mlir::FieldParser::parse(parser); +// ATTR: if (failed(_result_value)) +// ATTR: return {}; +// ATTR: if (parser.parseComma()) +// ATTR: return {}; +// ATTR: _result_complex = ::parseAttrParamA(parser, attrType); +// ATTR: if (failed(_result_complex)) +// ATTR: return {}; +// ATTR: if (parser.parseRParen()) +// ATTR: return {}; +// ATTR: return TestAAttr::get(parser.getContext(), +// ATTR: _result_value.getValue(), +// ATTR: _result_complex.getValue()); +// ATTR: } + +// ATTR: void TestAAttr::print(::mlir::DialectAsmPrinter &printer) const { +// ATTR: printer << "attr_a"; +// ATTR: printer << ' ' << "hello"; +// ATTR: printer << ' ' << "="; +// ATTR: printer << ' '; +// ATTR: printer << getValue(); +// ATTR: printer << ","; +// ATTR: printer << ' '; +// ATTR: ::printAttrParamA(printer, getComplex()); +// ATTR: printer << ")"; +// ATTR: } + +def AttrA : TestAttr<"TestA"> { + let parameters = (ins + "IntegerAttr":$value, + AttrParamA:$complex + ); + + let mnemonic = "attr_a"; + let assemblyFormat = "`hello` `=` $value `,` $complex `)`"; +} + +/// Test simple struct parser and printer are generated correctly. 
+ +// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::DialectAsmParser &parser, +// ATTR: ::mlir::Type attrType) { +// ATTR: bool _seen_v0 = false; +// ATTR: bool _seen_v1 = false; +// ATTR: for (unsigned _index = 0; _index < 2; ++_index) { +// ATTR: StringRef _paramKey; +// ATTR: if (parser.parseKeyword(&_paramKey)) +// ATTR: return {}; +// ATTR: if (parser.parseEqual()) +// ATTR: return {}; +// ATTR: if (!_seen_v0 && _paramKey == "v0") { +// ATTR: _seen_v0 = true; +// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType); +// ATTR: if (failed(_result_v0)) +// ATTR: return {}; +// ATTR: } else if (!_seen_v1 && _paramKey == "v1") { +// ATTR: _seen_v1 = true; +// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser); +// ATTR: if (failed(_result_v1)) +// ATTR: return {}; +// ATTR: } else { +// ATTR: return {}; +// ATTR: } +// ATTR: if ((_index != 2 - 1) && parser.parseComma()) +// ATTR: return {}; +// ATTR: } +// ATTR: return TestBAttr::get(parser.getContext(), +// ATTR: _result_v0.getValue(), +// ATTR: _result_v1.getValue()); +// ATTR: } + +// ATTR: void TestBAttr::print(::mlir::DialectAsmPrinter &printer) const { +// ATTR: printer << "v0"; +// ATTR: printer << ' ' << "="; +// ATTR: printer << ' '; +// ATTR: ::printAttrParamA(printer, getV0()); +// ATTR: printer << ","; +// ATTR: printer << ' ' << "v1"; +// ATTR: printer << ' ' << "="; +// ATTR: printer << ' '; +// ATTR: ::printAttrB(printer, getV1()); +// ATTR: } + +def AttrB : TestAttr<"TestB"> { + let parameters = (ins + AttrParamA:$v0, + AttrParamB:$v1 + ); + + let mnemonic = "attr_b"; + let assemblyFormat = "`{` struct($v0, $v1) `}`"; +} + +/// Test attribute with capture-all params has correct parser and printer. 
+ +// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::DialectAsmParser &parser, +// ATTR: ::mlir::Type attrType) { +// ATTR: ::mlir::FailureOr _result_v0; +// ATTR: ::mlir::FailureOr _result_v1; +// ATTR: _result_v0 = ::mlir::FieldParser::parse(parser); +// ATTR: if (failed(_result_v0)) +// ATTR: return {}; +// ATTR: if (parser.parseComma()) +// ATTR: return {}; +// ATTR: _result_v1 = ::mlir::FieldParser::parse(parser); +// ATTR: if (failed(_result_v1)) +// ATTR: return {}; +// ATTR: return TestFAttr::get(parser.getContext(), +// ATTR: _result_v0.getValue(), +// ATTR: _result_v1.getValue()); +// ATTR: } + +// ATTR: void TestFAttr::print(::mlir::DialectAsmPrinter &printer) const { +// ATTR: printer << "attr_c"; +// ATTR: printer << ' '; +// ATTR: printer << getV0(); +// ATTR: printer << ","; +// ATTR: printer << ' '; +// ATTR: printer << getV1(); +// ATTR: } + +def AttrC : TestAttr<"TestF"> { + let parameters = (ins "int":$v0, "int":$v1); + + let mnemonic = "attr_c"; + let assemblyFormat = "params"; +} + +/// Test type parser and printer that mix variables and struct are generated +/// correctly. 
+ +// TYPE: ::mlir::Type TestCType::parse(::mlir::DialectAsmParser &parser) { +// TYPE: FailureOr _result_value; +// TYPE: FailureOr _result_complex; +// TYPE: if (parser.parseKeyword("foo")) +// TYPE: return {}; +// TYPE: if (parser.parseComma()) +// TYPE: return {}; +// TYPE: if (parser.parseColon()) +// TYPE: return {}; +// TYPE: if (parser.parseKeyword("bob")) +// TYPE: return {}; +// TYPE: if (parser.parseKeyword("bar")) +// TYPE: return {}; +// TYPE: _result_value = ::mlir::FieldParser::parse(parser); +// TYPE: if (failed(_result_value)) +// TYPE: return {}; +// TYPE: bool _seen_complex = false; +// TYPE: for (unsigned _index = 0; _index < 1; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (!_seen_complex && _paramKey == "complex") { +// TYPE: _seen_complex = true; +// TYPE: _result_complex = ::parseTypeParamC(parser); +// TYPE: if (failed(_result_complex)) +// TYPE: return {}; +// TYPE: } else { +// TYPE: return {}; +// TYPE: } +// TYPE: if ((_index != 1 - 1) && parser.parseComma()) +// TYPE: return {}; +// TYPE: } +// TYPE: if (parser.parseRParen()) +// TYPE: return {}; +// TYPE: } + +// TYPE: void TestCType::print(::mlir::DialectAsmPrinter &printer) const { +// TYPE: printer << "type_c"; +// TYPE: printer << ' ' << "foo"; +// TYPE: printer << ","; +// TYPE: printer << ' ' << ":"; +// TYPE: printer << ' ' << "bob"; +// TYPE: printer << ' ' << "bar"; +// TYPE: printer << ' '; +// TYPE: printer << getValue(); +// TYPE: printer << ' ' << "complex"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getComplex(); +// TYPE: printer << ")"; +// TYPE: } + +def TypeA : TestType<"TestC"> { + let parameters = (ins + "IntegerAttr":$value, + TypeParamA:$complex + ); + + let mnemonic = "type_c"; + let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`"; +} + +/// Test type parser and printer with mix of variables and struct are generated +/// 
correctly. + +// TYPE: ::mlir::Type TestDType::parse(::mlir::DialectAsmParser &parser) { +// TYPE: _result_v0 = ::parseTypeParamC(parser); +// TYPE: if (failed(_result_v0)) +// TYPE: return {}; +// TYPE: bool _seen_v1 = false; +// TYPE: bool _seen_v2 = false; +// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (parser.parseEqual()) +// TYPE: return {}; +// TYPE: if (!_seen_v1 && _paramKey == "v1") { +// TYPE: _seen_v1 = true; +// TYPE: _result_v1 = someFcnCall(); +// TYPE: if (failed(_result_v1)) +// TYPE: return {}; +// TYPE: } else if (!_seen_v2 && _paramKey == "v2") { +// TYPE: _seen_v2 = true; +// TYPE: _result_v2 = ::parseTypeParamC(parser); +// TYPE: if (failed(_result_v2)) +// TYPE: return {}; +// TYPE: } else { +// TYPE: return {}; +// TYPE: } +// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return {}; +// TYPE: } +// TYPE: _result_v3 = someFcnCall(); +// TYPE: if (failed(_result_v3)) +// TYPE: return {}; +// TYPE: return TestDType::get(parser.getContext(), +// TYPE: _result_v0.getValue(), +// TYPE: _result_v1.getValue(), +// TYPE: _result_v2.getValue(), +// TYPE: _result_v3.getValue()); +// TYPE: } + +// TYPE: void TestDType::print(::mlir::DialectAsmPrinter &printer) const { +// TYPE: printer << getV0(); +// TYPE: myPrinter(getV1()); +// TYPE: printer << ' ' << "v2"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getV2(); +// TYPE: myPrinter(getV3()); +// TYPE: } + +def TypeB : TestType<"TestD"> { + let parameters = (ins + TypeParamA:$v0, + TypeParamB:$v1, + TypeParamA:$v2, + TypeParamB:$v3 + ); + + let mnemonic = "type_d"; + let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`"; +} + +/// Type test with two struct directives has correctly generated parser and +/// printer. 
+ +// TYPE: ::mlir::Type TestEType::parse(::mlir::DialectAsmParser &parser) { +// TYPE: FailureOr _result_v0; +// TYPE: FailureOr _result_v1; +// TYPE: FailureOr _result_v2; +// TYPE: FailureOr _result_v3; +// TYPE: bool _seen_v0 = false; +// TYPE: bool _seen_v2 = false; +// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (parser.parseEqual()) +// TYPE: return {}; +// TYPE: if (!_seen_v0 && _paramKey == "v0") { +// TYPE: _seen_v0 = true; +// TYPE: _result_v0 = ::mlir::FieldParser::parse(parser); +// TYPE: if (failed(_result_v0)) +// TYPE: return {}; +// TYPE: } else if (!_seen_v2 && _paramKey == "v2") { +// TYPE: _seen_v2 = true; +// TYPE: _result_v2 = ::mlir::FieldParser::parse(parser); +// TYPE: if (failed(_result_v2)) +// TYPE: return {}; +// TYPE: } else { +// TYPE: return {}; +// TYPE: } +// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return {}; +// TYPE: } +// TYPE: bool _seen_v1 = false; +// TYPE: bool _seen_v3 = false; +// TYPE: for (unsigned _index = 0; _index < 2; ++_index) { +// TYPE: StringRef _paramKey; +// TYPE: if (parser.parseKeyword(&_paramKey)) +// TYPE: return {}; +// TYPE: if (parser.parseEqual()) +// TYPE: return {}; +// TYPE: if (!_seen_v1 && _paramKey == "v1") { +// TYPE: _seen_v1 = true; +// TYPE: _result_v1 = ::mlir::FieldParser::parse(parser); +// TYPE: if (failed(_result_v1)) +// TYPE: return {}; +// TYPE: } else if (!_seen_v3 && _paramKey == "v3") { +// TYPE: _seen_v3 = true; +// TYPE: _result_v3 = ::mlir::FieldParser::parse(parser); +// TYPE: if (failed(_result_v3)) +// TYPE: return {}; +// TYPE: } else { +// TYPE: return {}; +// TYPE: } +// TYPE: if ((_index != 2 - 1) && parser.parseComma()) +// TYPE: return {}; +// TYPE: } +// TYPE: return TestEType::get(parser.getContext(), +// TYPE: _result_v0.getValue(), +// TYPE: _result_v1.getValue(), +// TYPE: _result_v2.getValue(), +// TYPE: 
_result_v3.getValue()); +// TYPE: } + +// TYPE: void TestEType::print(::mlir::DialectAsmPrinter &printer) const { +// TYPE: printer << "v0"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getV0(); +// TYPE: printer << ","; +// TYPE: printer << ' ' << "v2"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getV2(); +// TYPE: printer << "v1"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getV1(); +// TYPE: printer << ","; +// TYPE: printer << ' ' << "v3"; +// TYPE: printer << ' ' << "="; +// TYPE: printer << ' '; +// TYPE: printer << getV3(); +// TYPE: } + +def TypeC : TestType<"TestE"> { + let parameters = (ins + "IntegerAttr":$v0, + "IntegerAttr":$v1, + "IntegerAttr":$v2, + "IntegerAttr":$v3 + ); + + let mnemonic = "type_e"; + let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`"; +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "AttrOrTypeFormatGen.h" #include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/CodeGenHelpers.h" @@ -24,6 +25,17 @@ using namespace mlir; using namespace mlir::tblgen; +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +std::string mlir::tblgen::getParameterAccessorName(StringRef name) { + assert(!name.empty() && "parameter has empty name"); + auto ret = "get" + name.str(); + ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name + return ret; +} + /// Find all the AttrOrTypeDef for the specified dialect. 
If no dialect /// specified and can only find one dialect's defs, use that. static void collectAllDefs(StringRef selectedDialect, @@ -399,7 +411,8 @@ << " }\n"; // If mnemonic specified, emit print/parse declarations. - if (def.getParserCode() || def.getPrinterCode() || !params.empty()) { + if (def.getParserCode() || def.getPrinterCode() || + def.getAssemblyFormat() || !params.empty()) { os << llvm::formatv(defDeclParsePrintStr, valueType, isAttrGenerator ? ", ::mlir::Type type" : ""); } @@ -410,10 +423,8 @@ def.getParameters(parameters); for (AttrOrTypeParameter &parameter : parameters) { - SmallString<16> name = parameter.getName(); - name[0] = llvm::toUpper(name[0]); - os << formatv(" {0} get{1}() const;\n", parameter.getCppAccessorType(), - name); + os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(), + getParameterAccessorName(parameter.getName())); } } @@ -700,8 +711,32 @@ } void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) { + auto printerCode = def.getPrinterCode(); + auto parserCode = def.getParserCode(); + auto assemblyFormat = def.getAssemblyFormat(); + if (assemblyFormat && (printerCode || parserCode)) { + // Custom assembly format cannot be specified at the same time as either + // custom printer or parser code. + PrintFatalError(def.getLoc(), + def.getName() + ": assembly format cannot be specified at " + "the same time as printer or parser code"); + } + + // Generate a parser and printer based on the assembly format, if specified. + if (assemblyFormat) { + // A custom assembly format requires accessors to be generated for the + // generated printer. + if (!def.genAccessors()) { + PrintFatalError(def.getLoc(), + def.getName() + + ": the generated printer from 'assemblyFormat' " + "requires 'genAccessors' to be true"); + } + return generateAttrOrTypeFormat(def, os); + } + // Emit the printer code, if specified.
- if (Optional printerCode = def.getPrinterCode()) { + if (printerCode) { // Both the mnenomic and printerCode must be defined (for parity with // parserCode). os << "void " << def.getCppClassName() @@ -717,7 +752,7 @@ } // Emit the parser code, if specified. - if (Optional parserCode = def.getParserCode()) { + if (parserCode) { FmtContext fmtCtxt; fmtCtxt.addSubst("_parser", "parser") .addSubst("_ctxt", "parser.getContext()"); @@ -857,11 +892,10 @@ paramStorageName = param.getName(); } - SmallString<16> name = param.getName(); - name[0] = llvm::toUpper(name[0]); - os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n", - param.getCppAccessorType(), name, paramStorageName, - def.getCppClassName()); + os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n", + param.getCppAccessorType(), + getParameterAccessorName(param.getName()), + paramStorageName, def.getCppClassName()); } } } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h @@ -0,0 +1,32 @@ +//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_ +#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_ + +#include "llvm/Support/raw_ostream.h" + +#include + +namespace mlir { +namespace tblgen { +class AttrOrTypeDef; + +/// Generate a parser and printer based on a custom assembly format for an +/// attribute or type. +void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os); + +/// From the parameter name, get the name of the accessor function in camelcase. 
+/// The first letter of the parameter is upper-cased and prefixed with "get". +/// E.g. 'value' -> 'getValue'. +std::string getParameterAccessorName(llvm::StringRef name); + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_ diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -0,0 +1,781 @@ +//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "AttrOrTypeFormatGen.h" +#include "FormatGen.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace mlir; +using namespace mlir::tblgen; + +using llvm::formatv; + +//===----------------------------------------------------------------------===// +// Element +//===----------------------------------------------------------------------===// + +namespace { + +/// This class represents a single format element. +class Element { +public: + /// LLVM-style RTTI. + enum class Kind { + /// This element is a directive. + ParamsDirective, + StructDirective, + + /// This element is a literal. + Literal, + + /// This element is a variable. 
+ Variable, + }; + Element(Kind kind) : kind(kind) {} + virtual ~Element() = default; + + /// Return the kind of this element. + Kind getKind() const { return kind; } + +private: + /// The kind of this element. + Kind kind; +}; + +/// This class represents an instance of a literal element. +class LiteralElement : public Element { +public: + LiteralElement(StringRef literal) + : Element(Kind::Literal), literal(literal) {} + + static bool classof(const Element *el) { + return el->getKind() == Kind::Literal; + } + + /// Get the literal spelling. + StringRef getSpelling() const { return literal; } + +private: + /// The spelling of the literal for this element. + StringRef literal; +}; + +/// This class represents an instance of a variable element. A variable refers +/// to an attribute or type parameter. +class VariableElement : public Element { +public: + VariableElement(AttrOrTypeParameter param) + : Element(Kind::Variable), param(param) {} + + static bool classof(const Element *el) { + return el->getKind() == Kind::Variable; + } + + /// Get the parameter in the element. + const AttrOrTypeParameter &getParam() const { return param; } + +private: + AttrOrTypeParameter param; +}; + +/// Base class for a directive that contains references to multiple variables. +template +class ParamsDirectiveBase : public Element { +public: + using Base = ParamsDirectiveBase; + + ParamsDirectiveBase(SmallVector> &&params) + : Element(ElementKind), params(std::move(params)) {} + + static bool classof(const Element *el) { + return el->getKind() == ElementKind; + } + + /// Get the parameters contained in this directive. + auto getParams() const { + return llvm::map_range(params, [](auto &el) { + return cast(el.get())->getParam(); + }); + } + + /// Get the number of parameters. + unsigned getNumParams() const { return params.size(); } + + /// Take all of the parameters from this directive.
+ SmallVector> takeParams() { + return std::move(params); + } + +private: + /// The parameters captured by this directive. + SmallVector> params; +}; + +/// This class represents a `params` directive that refers to all parameters +/// of an attribute or type. When used as a top-level directive, it generates +/// a format of the form: +/// +/// (param-value (`,` param-value)*)? +/// +/// When used as an argument to another directive that accepts variables, +/// `params` can be used in place of manually listing all parameters of an +/// attribute or type. +class ParamsDirective + : public ParamsDirectiveBase { +public: + using Base::Base; +}; + +/// This class represents a `struct` directive that generates a struct format +/// of the form (surrounding braces must be written as literals): +/// +/// param-name `=` param-value (`,` param-name `=` param-value)* +/// +class StructDirective + : public ParamsDirectiveBase { +public: + using Base::Base; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Format Strings +//===----------------------------------------------------------------------===// + +/// Format for defining an attribute parser. +/// +/// $0: The attribute C++ class name. +static const char *const attrParserDefn = R"( +::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser, + ::mlir::Type $_type) { +)"; + +/// Format for defining a type parser. +/// +/// $0: The type C++ class name. +static const char *const typeParserDefn = R"( +::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) { +)"; + +/// Default parser for attribute or type parameters. +static const char *const defaultParameterParser = + "::mlir::FieldParser<$0>::parse($_parser)"; + +/// Default printer for attribute or type parameters. +static const char *const defaultParameterPrinter = "$_printer << $_self"; + +/// Prefix of the code that emits an error when failing to parse an element. +/// +/// The caller appends the error message and closes the `emitError` call.
+static const char *const parseErrorStr = + "$_parser.emitError($_parser.getCurrentLocation(), "; + +/// Format for defining an attribute or type printer. +/// +/// $0: The attribute or type C++ class name. +/// $1: The attribute or type mnemonic. +static const char *const attrOrTypePrinterDefn = R"( +void $0::print(::mlir::DialectAsmPrinter &$_printer) const { + $_printer << "$1"; +)"; + +/// Loop declaration for struct parser. +/// +/// $0: Number of expected parameters. +static const char *const structParseLoopStart = R"( + for (unsigned _index = 0; _index < $0; ++_index) { + StringRef _paramKey; + if ($_parser.parseKeyword(&_paramKey)) { + $_parser.emitError($_parser.getCurrentLocation(), + "expected a parameter name in struct"); + return {}; + } +)"; + +/// Terminator code segment for the struct parser loop. Check for duplicate or +/// unknown parameters. Parse a comma except on the last element. +/// +/// {0}: Code template for printing an error. +/// {1}: Number of elements in the struct. +static const char *const structParseLoopEnd = R"({{ + {0}"duplicate or unknown struct parameter name: ") << _paramKey; + return {{}; + } + if ((_index != {1} - 1) && parser.parseComma()) + return {{}; + } +)"; + +/// Code format to parse a variable. Split into lines because variable parsers +/// may be generated inside other directives, which requires indentation. +/// +/// {0}: The parameter name. +/// {1}: The parse code for the parameter. +/// {2}: Code template for printing an error. +/// {3}: Name of the attribute or type. +/// {4}: C++ class of the parameter.
+static const char *const variableParser[] = { + " // Parse variable '{0}'", + " _result_{0} = {1};", + " if (failed(_result_{0})) {{", + " {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");", + " return {{};", + " }", +}; + +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// + +/// Get a list of an attribute's or type's parameters. These can be wrapper +/// objects around `AttrOrTypeParameter` or string inits. +static auto getParameters(const AttrOrTypeDef &def) { + SmallVector params; + def.getParameters(params); + return params; +} + +//===----------------------------------------------------------------------===// +// AttrOrTypeFormat +//===----------------------------------------------------------------------===// + +namespace { +class AttrOrTypeFormat { +public: + AttrOrTypeFormat(const AttrOrTypeDef &def, + std::vector> &&elements) + : def(def), elements(std::move(elements)) {} + + /// Generate the attribute or type parser. + void genParser(raw_ostream &os); + /// Generate the attribute or type printer. + void genPrinter(raw_ostream &os); + +private: + /// Generate the parser code for a specific format element. + void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os); + /// Generate the parser code for a literal. + void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os, + unsigned indent = 0); + /// Generate the parser code for a variable. + void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx, + raw_ostream &os, unsigned indent = 0); + /// Generate the parser code for a `params` directive. + void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os); + /// Generate the parser code for a `struct` directive. + void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os); + + /// Generate the printer code for a specific format element.
+ void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os); + /// Generate the printer code for a literal. + void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os); + /// Generate the printer code for a variable. + void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx, + raw_ostream &os); + /// Generate the printer code for a `params` directive. + void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os); + /// Generate the printer code for a `struct` directive. + void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os); + + /// The ODS definition of the attribute or type whose format is being used to + /// generate a parser and printer. + const AttrOrTypeDef &def; + /// The list of top-level format elements returned by the assembly format + /// parser. + std::vector> elements; + + /// Flags for printing spaces. + bool shouldEmitSpace; + bool lastWasPunctuation; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ParserGen +//===----------------------------------------------------------------------===// + +void AttrOrTypeFormat::genParser(raw_ostream &os) { + FmtContext ctx; + ctx.addSubst("_parser", "parser"); + + /// Generate the definition. + if (isa(def)) { + ctx.addSubst("_type", "attrType"); + os << tgfmt(attrParserDefn, &ctx, def.getCppClassName()); + } else { + os << tgfmt(typeParserDefn, &ctx, def.getCppClassName()); + } + + /// Declare variables to store all of the parameters. Allocated parameters + /// such as `ArrayRef` and `StringRef` must provide a `cppStorageType`. Store + /// FailureOr to defer type construction for parameters that are parsed in + /// a loop (parsers return FailureOr anyway).
+ SmallVector params = getParameters(def); + for (const AttrOrTypeParameter &param : params) { + os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n", + param.getCppStorageType(), param.getName()); + } + + /// Store the initial location of the parser. + ctx.addSubst("_loc", "loc"); + os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" + " (void) $_loc;\n", + &ctx); + + /// Generate call to each parameter parser. + for (auto &el : elements) + genElementParser(el.get(), ctx, os); + + /// Generate call to the attribute or type builder. Use the checked getter + /// if one was generated. + if (def.genVerifyDecl()) { + os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()", + &ctx, def.getCppClassName()); + } else { + os << tgfmt(" return $0::get($_parser.getContext()", &ctx, + def.getCppClassName()); + } + for (const AttrOrTypeParameter &param : params) + os << formatv(",\n _result_{0}.getValue()", param.getName()); + os << ");\n}\n\n"; +} + +void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx, + raw_ostream &os) { + if (auto *literal = dyn_cast(el)) + return genLiteralParser(literal->getSpelling(), ctx, os); + if (auto *var = dyn_cast(el)) + return genVariableParser(var->getParam(), ctx, os); + if (auto *params = dyn_cast(el)) + return genParamsParser(params, ctx, os); + if (auto *strct = dyn_cast(el)) + return genStructParser(strct, ctx, os); + + llvm_unreachable("unknown format element"); +} + +void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx, + raw_ostream &os, unsigned indent) { + os.indent(indent) << " // Parse literal '" << value << "'\n"; + os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx); + if (value.front() == '_' || isalpha(value.front())) { + os << "Keyword(\"" << value << "\")"; + } else { + os << StringSwitch(value) + .Case("->", "Arrow") + .Case(":", "Colon") + .Case(",", "Comma") + .Case("=", "Equal") + .Case("<", "Less") + .Case(">", "Greater") + .Case("{", "LBrace") +
.Case("}", "RBrace") + .Case("(", "LParen") + .Case(")", "RParen") + .Case("[", "LSquare") + .Case("]", "RSquare") + .Case("?", "Question") + .Case("+", "Plus") + .Case("*", "Star") + << "()"; + } + os << ")\n"; + // Parser will emit an error + os.indent(indent) << " return {};\n"; +} + +void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param, + FmtContext &ctx, raw_ostream &os, + unsigned indent) { + /// Check for a custom parser. Use the default parameter parser otherwise. + auto customParser = param.getParser(); + auto parser = + customParser ? *customParser : StringRef(defaultParameterParser); + for (const char *line : variableParser) { + os.indent(indent) << formatv(line, param.getName(), + tgfmt(parser, &ctx, param.getCppStorageType()), + tgfmt(parseErrorStr, &ctx), def.getName(), + param.getCppType()) + << "\n"; + } +} + +void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, + raw_ostream &os) { + os << " // Parse parameter list\n"; + llvm::interleave( + el->getParams(), [&](auto param) { genVariableParser(param, ctx, os); }, + [&]() { genLiteralParser(",", ctx, os); }); +} + +void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx, + raw_ostream &os) { + os << " // Parse parameter struct\n"; + + /// Declare a "seen" variable for each key. + for (const AttrOrTypeParameter &param : el->getParams()) + os << formatv(" bool _seen_{0} = false;\n", param.getName()); + + /// Generate the parsing loop. + os << tgfmt(structParseLoopStart, &ctx, el->getNumParams()); + genLiteralParser("=", ctx, os, 2); + os << " "; + for (const AttrOrTypeParameter &param : el->getParams()) { + os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" + " _seen_{0} = true;\n", + param.getName()); + genVariableParser(param, ctx, os, 4); + os << " } else "; + } + + /// Duplicate or unknown parameter.
+ os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx), + el->getNumParams()); + + /// Because the loop runs N times and each non-failing iteration sets 1 of + /// N flags, successfully exiting the loop means that all parameters have been + /// seen. `parseOptionalComma` would cause issues with any formats that use + /// "struct(...) `,`" because structs aren't surrounded by braces. +} + +//===----------------------------------------------------------------------===// +// PrinterGen +//===----------------------------------------------------------------------===// + +void AttrOrTypeFormat::genPrinter(raw_ostream &os) { + FmtContext ctx; + ctx.addSubst("_printer", "printer"); + + /// Generate the definition. + os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(), + *def.getMnemonic()); + + /// Generate printers. + shouldEmitSpace = true; + lastWasPunctuation = false; + for (auto &el : elements) + genElementPrinter(el.get(), ctx, os); + + os << "}\n\n"; +} + +void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx, + raw_ostream &os) { + if (auto *literal = dyn_cast(el)) + return genLiteralPrinter(literal->getSpelling(), ctx, os); + if (auto *params = dyn_cast(el)) + return genParamsPrinter(params, ctx, os); + if (auto *strct = dyn_cast(el)) + return genStructPrinter(strct, ctx, os); + if (auto *var = dyn_cast(el)) + return genVariablePrinter(var->getParam(), ctx, os); + + llvm_unreachable("unknown format element"); +} + +void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, + raw_ostream &os) { + /// Don't insert a space before certain punctuation. + bool needSpace = + shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); + os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "", + value); + + /// Update the flags.
+ shouldEmitSpace = + value.size() != 1 || !StringRef("<({[").contains(value.front()); + lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); +} + +void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param, + FmtContext &ctx, raw_ostream &os) { + /// Insert a space before the next parameter, if necessary. + if (shouldEmitSpace || !lastWasPunctuation) + os << tgfmt(" $_printer << ' ';\n", &ctx); + shouldEmitSpace = true; + lastWasPunctuation = false; + + ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); + os << " "; + if (auto printer = param.getPrinter()) + os << tgfmt(*printer, &ctx) << ";\n"; + else + os << tgfmt(defaultParameterPrinter, &ctx) << ";\n"; +} + +void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, + raw_ostream &os) { + llvm::interleave( + el->getParams(), [&](auto param) { genVariablePrinter(param, ctx, os); }, + [&]() { genLiteralPrinter(",", ctx, os); }); +} + +void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, + raw_ostream &os) { + llvm::interleave( + el->getParams(), + [&](auto param) { + genLiteralPrinter(param.getName(), ctx, os); + genLiteralPrinter("=", ctx, os); + os << tgfmt(" $_printer << ' ';\n", &ctx); + genVariablePrinter(param, ctx, os); + }, + [&]() { genLiteralPrinter(",", ctx, os); }); +} + +//===----------------------------------------------------------------------===// +// FormatParser +//===----------------------------------------------------------------------===// + +namespace { +class FormatParser { +public: + FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def) + : lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def), + seenParams(def.getNumParameters()) {} + + /// Parse the attribute or type format and create the format elements. + FailureOr parse(); + +private: + /// The current context of the parser when parsing an element.
+ enum ParserContext { + /// The element is being parsed in the default context - at the top of the + /// format + TopLevelContext, + /// The element is being parsed as a child to a `struct` directive. + StructDirective, + }; + + /// Emit an error. + LogicalResult emitError(const Twine &msg) { + lexer.emitError(curToken.getLoc(), msg); + return failure(); + } + + /// Parse an expected token. + LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) { + if (curToken.getKind() != kind) + return emitError(msg); + consumeToken(); + return success(); + } + + /// Advance the lexer to the next token. + void consumeToken() { + assert(curToken.getKind() != FormatToken::eof && + curToken.getKind() != FormatToken::error && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + + /// Parse any element. + FailureOr> parseElement(ParserContext ctx); + /// Parse a literal element. + FailureOr> parseLiteral(ParserContext ctx); + /// Parse a variable element. + FailureOr> parseVariable(ParserContext ctx); + /// Parse a directive. + FailureOr> parseDirective(ParserContext ctx); + /// Parse a `params` directive. + FailureOr> parseParamsDirective(); + /// Parse a `struct` directive. + FailureOr> parseStructDirective(); + + /// The current format lexer. + FormatLexer lexer; + /// The current token in the stream. + FormatToken curToken; + /// Attribute or type tablegen def. + const AttrOrTypeDef &def; + + /// Seen attribute or type parameters. + llvm::BitVector seenParams; +}; +} // end anonymous namespace + +FailureOr FormatParser::parse() { + std::vector> elements; + elements.reserve(16); + + /// Parse the format elements. + while (curToken.getKind() != FormatToken::eof) { + auto element = parseElement(TopLevelContext); + if (failed(element)) + return failure(); + + /// Add the format element and continue. + elements.push_back(std::move(*element)); + } + + /// Check that all parameters have been seen. 
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  for (auto it : llvm::enumerate(params)) {
+    if (!seenParams.test(it.index())) {
+      return emitError("format is missing reference to parameter: " +
+                       it.value().getName());
+    }
+  }
+
+  return AttrOrTypeFormat(def, std::move(elements));
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseElement(ParserContext ctx) {
+  if (curToken.getKind() == FormatToken::literal)
+    return parseLiteral(ctx);
+  if (curToken.getKind() == FormatToken::variable)
+    return parseVariable(ctx);
+  if (curToken.isKeyword())
+    return parseDirective(ctx);
+
+  return emitError("expected literal, directive, or variable");
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseLiteral(ParserContext ctx) {
+  if (ctx != TopLevelContext) {
+    return emitError(
+        "literals may only be used in the top-level section of the format");
+  }
+
+  /// Get the literal spelling without the surrounding "`".
+  auto value = curToken.getSpelling().drop_front().drop_back();
+  if (!isValidLiteral(value))
+    return emitError("literal '" + value + "' is not valid");
+
+  consumeToken();
+  return {std::make_unique<LiteralElement>(value)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseVariable(ParserContext ctx) {
+  /// Get the parameter name without the preceding "$".
+  auto name = curToken.getSpelling().drop_front();
+
+  /// Lookup the parameter.
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  auto *it = llvm::find_if(
+      params, [&](auto &param) { return param.getName() == name; });
+
+  /// Check that the parameter reference is valid.
+  if (it == params.end())
+    return emitError(def.getName() + " has no parameter named '" + name + "'");
+  auto idx = std::distance(params.begin(), it);
+  if (seenParams.test(idx))
+    return emitError("duplicate parameter '" + name + "'");
+  seenParams.set(idx);
+
+  consumeToken();
+  return {std::make_unique<VariableElement>(*it)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseDirective(ParserContext ctx) {
+
+  switch (curToken.getKind()) {
+  case FormatToken::kw_params:
+    return parseParamsDirective();
+  case FormatToken::kw_struct:
+    if (ctx != TopLevelContext) {
+      return emitError(
+          "`struct` may only be used in the top-level section of the format");
+    }
+    return parseStructDirective();
+  default:
+    return emitError("unknown directive in format: " + curToken.getSpelling());
+  }
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
+  consumeToken();
+  /// Collect all of the attribute's or type's parameters.
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  SmallVector<std::unique_ptr<Element>> vars;
+  /// Ensure that none of the parameters have already been captured.
+  for (auto it : llvm::enumerate(params)) {
+    if (seenParams.test(it.index())) {
+      return emitError("`params` captures duplicate parameter: " +
+                       it.value().getName());
+    }
+    seenParams.set(it.index());
+    vars.push_back(std::make_unique<VariableElement>(it.value()));
+  }
+  return {std::make_unique<ParamsDirective>(std::move(vars))};
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
+  consumeToken();
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before `struct` argument list")))
+    return failure();
+
+  /// Parse variables captured by `struct`.
+  SmallVector<std::unique_ptr<Element>> vars;
+
+  /// Parse first captured parameter or a `params` directive.
+  FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
+  if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
+    return emitError("`struct` argument list expected a variable or directive");
+  if (isa<VariableElement>(*var)) {
+    /// Parse any other parameters.
+    vars.push_back(std::move(*var));
+    while (curToken.getKind() == FormatToken::comma) {
+      consumeToken();
+      var = parseElement(StructDirective);
+      if (failed(var) || !isa<VariableElement>(*var))
+        return emitError("expected a variable in `struct` argument list");
+      vars.push_back(std::move(*var));
+    }
+  } else {
+    /// `struct(params)` captures all parameters in the attribute or type.
+    vars = cast<ParamsDirective>(var->get())->takeParams();
+  }
+
+  if (curToken.getKind() != FormatToken::r_paren)
+    return emitError("expected ')' at the end of an argument list");
+
+  consumeToken();
+  return {std::make_unique<::StructDirective>(std::move(vars))};
+}
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
+                                            raw_ostream &os) {
+  llvm::SourceMgr mgr;
+  mgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
+      llvm::SMLoc());
+
+  /// Parse the custom assembly format.
+  FormatParser parser(mgr, def);
+  FailureOr<AttrOrTypeFormat> format = parser.parse();
+  if (failed(format)) {
+    if (formatErrorIsFatal)
+      PrintFatalError(def.getLoc(), "failed to parse assembly format");
+    return;
+  }
+
+  /// Generate the parser and printer.
+  format->genParser(os);
+  format->genPrinter(os);
+}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,10 +6,12 @@
 
 add_tablegen(mlir-tblgen MLIR
   AttrOrTypeDefGen.cpp
+  AttrOrTypeFormatGen.cpp
   CodeGenHelpers.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
+  FormatGen.cpp
   LLVMIRConversionGen.cpp
   LLVMIRIntrinsicGen.cpp
   mlir-tblgen.cpp
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -0,0 +1,161 @@
+//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common classes for building custom assembly format parsers
+// and generators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace llvm {
+class SourceMgr;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+/// This class represents a specific token in the input format.
+class FormatToken {
+public:
+  /// Basic token kinds.
+  enum Kind {
+    // Markers.
+    eof,
+    error,
+
+    // Tokens with no info.
+    l_paren,
+    r_paren,
+    caret,
+    colon,
+    comma,
+    equal,
+    less,
+    greater,
+    question,
+    star,
+
+    // Keywords.
+    keyword_start,
+    kw_attr_dict,
+    kw_attr_dict_w_keyword,
+    kw_custom,
+    kw_functional_type,
+    kw_operands,
+    kw_params,
+    kw_ref,
+    kw_regions,
+    kw_results,
+    kw_struct,
+    kw_successors,
+    kw_type,
+    keyword_end,
+
+    // String valued tokens.
+    identifier,
+    literal,
+    variable,
+  };
+
+  FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+  /// Return the bytes that make up this token.
+  StringRef getSpelling() const { return spelling; }
+
+  /// Return the kind of this token.
+  Kind getKind() const { return kind; }
+
+  /// Return a location for this token.
+  llvm::SMLoc getLoc() const;
+
+  /// Return if this token is a keyword.
+  bool isKeyword() const {
+    return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
+  }
+
+private:
+  /// Discriminator that indicates the kind of token this is.
+  Kind kind;
+
+  /// A reference to the entire token contents; this is always a pointer into
+  /// a memory buffer owned by the source manager.
+  StringRef spelling;
+};
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+/// This class implements a simple lexer for custom assembly format strings.
+class FormatLexer {
+public:
+  FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
+
+  /// Lex the next token and return it.
+  FormatToken lexToken();
+
+  /// Emit an error to the lexer with the given location and message.
+  FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
+  FormatToken emitError(const char *loc, const Twine &msg);
+
+  FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                               const Twine &note);
+
+private:
+  /// Return the next character in the stream.
+  int getNextChar();
+
+  /// Lex an identifier, literal, or variable.
+  FormatToken lexIdentifier(const char *tokStart);
+  FormatToken lexLiteral(const char *tokStart);
+  FormatToken lexVariable(const char *tokStart);
+
+  /// Create a token with the current pointer and a start pointer.
+  FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
+    return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
+  }
+
+  /// The source manager containing the format string.
+  llvm::SourceMgr &mgr;
+  /// Location of the format string.
+  llvm::SMLoc loc;
+  /// Buffer containing the format string.
+  StringRef curBuffer;
+  /// Current pointer in the buffer.
+  const char *curPtr;
+};
+
+/// Whether a space needs to be emitted before a literal. E.g., two keywords
+/// back-to-back require a space separator, but a keyword followed by '<' does
+/// not require a space.
+bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
+
+/// Returns true if the given string can be formatted as a keyword.
+bool canFormatStringAsKeyword(StringRef value);
+
+/// Returns true if the given string is a valid format literal element.
+bool isValidLiteral(StringRef value);
+
+/// Whether a failure in parsing the assembly format should be a fatal error.
+extern llvm::cl::opt<bool> formatErrorIsFatal;
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -0,0 +1,225 @@
+//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FormatGen.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+llvm::SMLoc FormatToken::getLoc() const {
+  return llvm::SMLoc::getFromPointer(spelling.data());
+}
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+    : mgr(mgr), loc(loc),
+      curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
+      curPtr(curBuffer.begin()) {}
+
+FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
+  return emitError(llvm::SMLoc::getFromPointer(loc), msg);
+}
+
+FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                                          const Twine &note) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+int FormatLexer::getNextChar() {
+  char curChar = *curPtr++;
+  switch (curChar) {
+  default:
+    return (unsigned char)curChar;
+  case 0: {
+    // A nul character in the stream is either the end of the current buffer or
+    // a random nul in the file. Disambiguate that here.
+    if (curPtr - 1 != curBuffer.end())
+      return 0;
+
+    // Otherwise, return end of file.
+    --curPtr;
+    return EOF;
+  }
+  case '\n':
+  case '\r':
+    // Handle the newline character by ignoring it and incrementing the line
+    // count. However, be careful about 'dos style' files with \n\r in them.
+    // Only treat a \n\r or \r\n as a single line.
+    if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
+      ++curPtr;
+    return '\n';
+  }
+}
+
+FormatToken FormatLexer::lexToken() {
+  const char *tokStart = curPtr;
+
+  // This always consumes at least one character.
+  int curChar = getNextChar();
+  switch (curChar) {
+  default:
+    // Handle identifiers: [a-zA-Z_]
+    if (isalpha(curChar) || curChar == '_')
+      return lexIdentifier(tokStart);
+
+    // Unknown character, emit an error.
+    return emitError(tokStart, "unexpected character");
+  case EOF:
+    // Return EOF denoting the end of lexing.
+    return formToken(FormatToken::eof, tokStart);
+
+  // Lex punctuation.
+  case '^':
+    return formToken(FormatToken::caret, tokStart);
+  case ':':
+    return formToken(FormatToken::colon, tokStart);
+  case ',':
+    return formToken(FormatToken::comma, tokStart);
+  case '=':
+    return formToken(FormatToken::equal, tokStart);
+  case '<':
+    return formToken(FormatToken::less, tokStart);
+  case '>':
+    return formToken(FormatToken::greater, tokStart);
+  case '?':
+    return formToken(FormatToken::question, tokStart);
+  case '(':
+    return formToken(FormatToken::l_paren, tokStart);
+  case ')':
+    return formToken(FormatToken::r_paren, tokStart);
+  case '*':
+    return formToken(FormatToken::star, tokStart);
+
+  // Ignore whitespace characters.
+  case 0:
+  case ' ':
+  case '\t':
+  case '\n':
+    return lexToken();
+
+  case '`':
+    return lexLiteral(tokStart);
+  case '$':
+    return lexVariable(tokStart);
+  }
+}
+
+FormatToken FormatLexer::lexLiteral(const char *tokStart) {
+  assert(curPtr[-1] == '`');
+
+  // Lex a literal surrounded by ``.
+  while (const char curChar = *curPtr++) {
+    if (curChar == '`')
+      return formToken(FormatToken::literal, tokStart);
+  }
+  return emitError(curPtr - 1, "unexpected end of file in literal");
+}
+
+FormatToken FormatLexer::lexVariable(const char *tokStart) {
+  if (!isalpha(curPtr[0]) && curPtr[0] != '_')
+    return emitError(curPtr - 1, "expected variable name");
+
+  // Otherwise, consume the rest of the characters.
+  while (isalnum(*curPtr) || *curPtr == '_')
+    ++curPtr;
+  return formToken(FormatToken::variable, tokStart);
+}
+
+FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
+  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
+  while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
+    ++curPtr;
+
+  // Check to see if this identifier is a keyword.
+  StringRef str(tokStart, curPtr - tokStart);
+  auto kind =
+      StringSwitch<FormatToken::Kind>(str)
+          .Case("attr-dict", FormatToken::kw_attr_dict)
+          .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+          .Case("custom", FormatToken::kw_custom)
+          .Case("functional-type", FormatToken::kw_functional_type)
+          .Case("operands", FormatToken::kw_operands)
+          .Case("params", FormatToken::kw_params)
+          .Case("ref", FormatToken::kw_ref)
+          .Case("regions", FormatToken::kw_regions)
+          .Case("results", FormatToken::kw_results)
+          .Case("struct", FormatToken::kw_struct)
+          .Case("successors", FormatToken::kw_successors)
+          .Case("type", FormatToken::kw_type)
+          .Default(FormatToken::identifier);
+  return FormatToken(kind, str);
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
+                                         bool lastWasPunctuation) {
+  if (value.size() != 1 && value != "->")
+    return true;
+  if (lastWasPunctuation)
+    return !StringRef(">)}],").contains(value.front());
+  return !StringRef("<>(){}[],").contains(value.front());
+}
+
+bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
+  if (!isalpha(value.front()) && value.front() != '_')
+    return false;
+  return llvm::all_of(value.drop_front(), [](char c) {
+    return isalnum(c) || c == '_' || c == '$' || c == '.';
+  });
+}
+
+bool mlir::tblgen::isValidLiteral(StringRef value) {
+  if (value.empty())
+    return false;
+  char front = value.front();
+
+  // If there is only one character, this must either be punctuation or a
+  // single character bare identifier.
+  if (value.size() == 1)
+    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
+
+  // Check the punctuation that are larger than a single character.
+  if (value == "->")
+    return true;
+
+  // Otherwise, this must be an identifier.
+  return canFormatStringAsKeyword(value);
+}
+
+//===----------------------------------------------------------------------===//
+// Commandline Options
+//===----------------------------------------------------------------------===//
+
+llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
+    "asmformat-error-is-fatal",
+    llvm::cl::desc("Emit a fatal error if format parsing fails"),
+    llvm::cl::init(true));
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpFormatGen.h"
+#include "FormatGen.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -20,7 +21,6 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -30,20 +30,6 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
-static llvm::cl::opt<bool> formatErrorIsFatal(
-    "asmformat-error-is-fatal",
-    llvm::cl::desc("Emit a fatal error if format parsing fails"),
-    llvm::cl::init(true));
-
-/// Returns true if the given string can be formatted as a keyword.
-static bool canFormatStringAsKeyword(StringRef value) {
-  if (!isalpha(value.front()) && value.front() != '_')
-    return false;
-  return llvm::all_of(value.drop_front(), [](char c) {
-    return isalnum(c) || c == '_' || c == '$' || c == '.';
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // Element
 //===----------------------------------------------------------------------===//
@@ -273,33 +259,12 @@
   /// Return the literal for this element.
   StringRef getLiteral() const { return literal; }
 
-  /// Returns true if the given string is a valid literal.
-  static bool isValidLiteral(StringRef value);
-
 private:
   /// The spelling of the literal for this element.
   StringRef literal;
 };
 } // end anonymous namespace
 
-bool LiteralElement::isValidLiteral(StringRef value) {
-  if (value.empty())
-    return false;
-  char front = value.front();
-
-  // If there is only one character, this must either be punctuation or a
-  // single character bare identifier.
-  if (value.size() == 1)
-    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
-
-  // Check the punctuation that are larger than a single character.
-  if (value == "->")
-    return true;
-
-  // Otherwise, this must be an identifier.
-  return canFormatStringAsKeyword(value);
-}
-
 //===----------------------------------------------------------------------===//
 // WhitespaceElement
 
@@ -1703,14 +1668,7 @@
   body << " p";
 
   // Don't insert a space for certain punctuation.
-  auto shouldPrintSpaceBeforeLiteral = [&] {
-    if (value.size() != 1 && value != "->")
-      return true;
-    if (lastWasPunctuation)
-      return !StringRef(">)}],").contains(value.front());
-    return !StringRef("<>(){}[],").contains(value.front());
-  };
-  if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
+  if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
     body << " << ' '";
   body << " << \"" << value << "\";\n";
 
@@ -2080,253 +2038,6 @@
                       lastWasPunctuation);
 }
 
-//===----------------------------------------------------------------------===//
-// FormatLexer
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class represents a specific token in the input format.
-class Token {
-public:
-  enum Kind {
-    // Markers.
-    eof,
-    error,
-
-    // Tokens with no info.
-    l_paren,
-    r_paren,
-    caret,
-    colon,
-    comma,
-    equal,
-    less,
-    greater,
-    question,
-
-    // Keywords.
-    keyword_start,
-    kw_attr_dict,
-    kw_attr_dict_w_keyword,
-    kw_custom,
-    kw_functional_type,
-    kw_operands,
-    kw_ref,
-    kw_regions,
-    kw_results,
-    kw_successors,
-    kw_type,
-    keyword_end,
-
-    // String valued tokens.
-    identifier,
-    literal,
-    variable,
-  };
-  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
-
-  /// Return the bytes that make up this token.
-  StringRef getSpelling() const { return spelling; }
-
-  /// Return the kind of this token.
-  Kind getKind() const { return kind; }
-
-  /// Return a location for this token.
-  llvm::SMLoc getLoc() const {
-    return llvm::SMLoc::getFromPointer(spelling.data());
-  }
-
-  /// Return if this token is a keyword.
-  bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
-
-private:
-  /// Discriminator that indicates the kind of token this is.
-  Kind kind;
-
-  /// A reference to the entire token contents; this is always a pointer into
-  /// a memory buffer owned by the source manager.
-  StringRef spelling;
-};
-
-/// This class implements a simple lexer for operation assembly format strings.
-class FormatLexer {
-public:
-  FormatLexer(llvm::SourceMgr &mgr, Operator &op);
-
-  /// Lex the next token and return it.
-  Token lexToken();
-
-  /// Emit an error to the lexer with the given location and message.
-  Token emitError(llvm::SMLoc loc, const Twine &msg);
-  Token emitError(const char *loc, const Twine &msg);
-
-  Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);
-
-private:
-  Token formToken(Token::Kind kind, const char *tokStart) {
-    return Token(kind, StringRef(tokStart, curPtr - tokStart));
-  }
-
-  /// Return the next character in the stream.
-  int getNextChar();
-
-  /// Lex an identifier, literal, or variable.
-  Token lexIdentifier(const char *tokStart);
-  Token lexLiteral(const char *tokStart);
-  Token lexVariable(const char *tokStart);
-
-  llvm::SourceMgr &srcMgr;
-  Operator &op;
-  StringRef curBuffer;
-  const char *curPtr;
-};
-} // end anonymous namespace
-
-FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
-    : srcMgr(mgr), op(op) {
-  curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
-  curPtr = curBuffer.begin();
-}
-
-Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
-  llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
-                            "in custom assembly format for this operation");
-  return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
-                                    const Twine &note) {
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
-  llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
-                            "in custom assembly format for this operation");
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
-  return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitError(const char *loc, const Twine &msg) {
-  return emitError(llvm::SMLoc::getFromPointer(loc), msg);
-}
-
-int FormatLexer::getNextChar() {
-  char curChar = *curPtr++;
-  switch (curChar) {
-  default:
-    return (unsigned char)curChar;
-  case 0: {
-    // A nul character in the stream is either the end of the current buffer or
-    // a random nul in the file. Disambiguate that here.
-    if (curPtr - 1 != curBuffer.end())
-      return 0;
-
-    // Otherwise, return end of file.
-    --curPtr;
-    return EOF;
-  }
-  case '\n':
-  case '\r':
-    // Handle the newline character by ignoring it and incrementing the line
-    // count. However, be careful about 'dos style' files with \n\r in them.
-    // Only treat a \n\r or \r\n as a single line.
-    if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
-      ++curPtr;
-    return '\n';
-  }
-}
-
-Token FormatLexer::lexToken() {
-  const char *tokStart = curPtr;
-
-  // This always consumes at least one character.
-  int curChar = getNextChar();
-  switch (curChar) {
-  default:
-    // Handle identifiers: [a-zA-Z_]
-    if (isalpha(curChar) || curChar == '_')
-      return lexIdentifier(tokStart);
-
-    // Unknown character, emit an error.
-    return emitError(tokStart, "unexpected character");
-  case EOF:
-    // Return EOF denoting the end of lexing.
-    return formToken(Token::eof, tokStart);
-
-  // Lex punctuation.
-  case '^':
-    return formToken(Token::caret, tokStart);
-  case ':':
-    return formToken(Token::colon, tokStart);
-  case ',':
-    return formToken(Token::comma, tokStart);
-  case '=':
-    return formToken(Token::equal, tokStart);
-  case '<':
-    return formToken(Token::less, tokStart);
-  case '>':
-    return formToken(Token::greater, tokStart);
-  case '?':
-    return formToken(Token::question, tokStart);
-  case '(':
-    return formToken(Token::l_paren, tokStart);
-  case ')':
-    return formToken(Token::r_paren, tokStart);
-
-  // Ignore whitespace characters.
-  case 0:
-  case ' ':
-  case '\t':
-  case '\n':
-    return lexToken();
-
-  case '`':
-    return lexLiteral(tokStart);
-  case '$':
-    return lexVariable(tokStart);
-  }
-}
-
-Token FormatLexer::lexLiteral(const char *tokStart) {
-  assert(curPtr[-1] == '`');
-
-  // Lex a literal surrounded by ``.
-  while (const char curChar = *curPtr++) {
-    if (curChar == '`')
-      return formToken(Token::literal, tokStart);
-  }
-  return emitError(curPtr - 1, "unexpected end of file in literal");
-}
-
-Token FormatLexer::lexVariable(const char *tokStart) {
-  if (!isalpha(curPtr[0]) && curPtr[0] != '_')
-    return emitError(curPtr - 1, "expected variable name");
-
-  // Otherwise, consume the rest of the characters.
-  while (isalnum(*curPtr) || *curPtr == '_')
-    ++curPtr;
-  return formToken(Token::variable, tokStart);
-}
-
-Token FormatLexer::lexIdentifier(const char *tokStart) {
-  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
-  while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
-    ++curPtr;
-
-  // Check to see if this identifier is a keyword.
-  StringRef str(tokStart, curPtr - tokStart);
-  Token::Kind kind =
-      StringSwitch<Token::Kind>(str)
-          .Case("attr-dict", Token::kw_attr_dict)
-          .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
-          .Case("custom", Token::kw_custom)
-          .Case("functional-type", Token::kw_functional_type)
-          .Case("operands", Token::kw_operands)
-          .Case("ref", Token::kw_ref)
-          .Case("regions", Token::kw_regions)
-          .Case("results", Token::kw_results)
-          .Case("successors", Token::kw_successors)
-          .Case("type", Token::kw_type)
-          .Default(Token::identifier);
-  return Token(kind, str);
-}
-
 //===----------------------------------------------------------------------===//
 // FormatParser
 //===----------------------------------------------------------------------===//
@@ -2345,8 +2056,8 @@
 class FormatParser {
 public:
   FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
-      : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
-        seenOperandTypes(op.getNumOperands()),
+      : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
+        op(op), seenOperandTypes(op.getNumOperands()),
         seenResultTypes(op.getNumResults()) {}
 
   /// Parse the operation assembly format.
@@ -2448,7 +2159,8 @@
   LogicalResult parseCustomDirectiveParameter(
       std::vector<std::unique_ptr<Element>> &parameters);
   LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
-                                             Token tok, ParserContext context);
+                                             FormatToken tok,
+                                             ParserContext context);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, ParserContext context);
   LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
@@ -2460,8 +2172,8 @@
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                          llvm::SMLoc loc,
                                          ParserContext context);
-  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                   ParserContext context);
+  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
+                                   FormatToken tok, ParserContext context);
   LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
                                           bool isRefChild = false);
 
@@ -2471,12 +2183,12 @@
 
   /// Advance the current lexer onto the next token.
   void consumeToken() {
-    assert(curToken.getKind() != Token::eof &&
-           curToken.getKind() != Token::error &&
+    assert(curToken.getKind() != FormatToken::eof &&
+           curToken.getKind() != FormatToken::error &&
            "shouldn't advance past EOF or errors");
     curToken = lexer.lexToken();
   }
-  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
+  LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
     if (curToken.getKind() != kind)
       return emitError(curToken.getLoc(), msg);
     consumeToken();
@@ -2497,7 +2209,7 @@
   //===--------------------------------------------------------------------===//
 
   FormatLexer lexer;
-  Token curToken;
+  FormatToken curToken;
   OperationFormat &fmt;
   Operator &op;
 
@@ -2518,7 +2230,7 @@
   llvm::SMLoc loc = curToken.getLoc();
 
   // Parse each of the format elements into the main format.
-  while (curToken.getKind() != Token::eof) {
+  while (curToken.getKind() != FormatToken::eof) {
     std::unique_ptr<Element> element;
     if (failed(parseElement(element, TopLevelContext)))
       return ::mlir::failure();
@@ -2843,13 +2555,13 @@
   if (curToken.isKeyword())
     return parseDirective(element, context);
   // Literals.
-  if (curToken.getKind() == Token::literal)
+  if (curToken.getKind() == FormatToken::literal)
     return parseLiteral(element, context);
   // Optionals.
-  if (curToken.getKind() == Token::l_paren)
+  if (curToken.getKind() == FormatToken::l_paren)
     return parseOptional(element, context);
   // Variables.
-  if (curToken.getKind() == Token::variable)
+  if (curToken.getKind() == FormatToken::variable)
     return parseVariable(element, context);
   return emitError(curToken.getLoc(),
                    "expected directive, literal, variable, or optional group");
@@ -2857,7 +2569,7 @@
 
 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
                                           ParserContext context) {
-  Token varTok = curToken;
+  FormatToken varTok = curToken;
   consumeToken();
 
   StringRef name = varTok.getSpelling().drop_front();
@@ -2937,31 +2649,31 @@
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
                                            ParserContext context) {
-  Token dirTok = curToken;
+  FormatToken dirTok = curToken;
   consumeToken();
 
   switch (dirTok.getKind()) {
-  case Token::kw_attr_dict:
+  case FormatToken::kw_attr_dict:
     return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/false);
-  case Token::kw_attr_dict_w_keyword:
+  case FormatToken::kw_attr_dict_w_keyword:
     return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/true);
-  case Token::kw_custom:
+  case FormatToken::kw_custom:
     return parseCustomDirective(element, dirTok.getLoc(), context);
-  case Token::kw_functional_type:
+  case FormatToken::kw_functional_type:
     return parseFunctionalTypeDirective(element, dirTok, context);
-  case Token::kw_operands:
+  case FormatToken::kw_operands:
     return parseOperandsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_regions:
+  case FormatToken::kw_regions:
     return parseRegionsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_results:
+  case FormatToken::kw_results:
     return parseResultsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_successors:
+  case FormatToken::kw_successors:
     return parseSuccessorsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_ref:
+  case FormatToken::kw_ref:
     return parseReferenceDirective(element, dirTok.getLoc(), context);
-  case Token::kw_type:
+  case FormatToken::kw_type:
     return parseTypeDirective(element, dirTok, context);
 
   default:
@@ -2971,7 +2683,7 @@
 
 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
                                          ParserContext context) {
-  Token literalTok = curToken;
+  FormatToken literalTok = curToken;
   if (context != TopLevelContext) {
     return emitError(
         literalTok.getLoc(),
@@ -2993,7 +2705,7 @@
   }
 
   // Check that the parsed literal is valid.
-  if (!LiteralElement::isValidLiteral(value))
+  if (!isValidLiteral(value))
     return emitError(literalTok.getLoc(), "expected valid literal");
 
   element = std::make_unique<LiteralElement>(value);
@@ -3014,14 +2726,15 @@
   do {
     if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
       return ::mlir::failure();
-  } while (curToken.getKind() != Token::r_paren);
+  } while (curToken.getKind() != FormatToken::r_paren);
   consumeToken();
 
   // Parse the `else` elements of this optional group.
-  if (curToken.getKind() == Token::colon) {
+  if (curToken.getKind() == FormatToken::colon) {
     consumeToken();
-    if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
-                                          "of optional group")))
+    if (failed(parseToken(FormatToken::l_paren,
+                          "expected '(' to start else branch "
+                          "of optional group")))
       return failure();
     do {
       llvm::SMLoc childLoc = curToken.getLoc();
@@ -3030,11 +2743,12 @@
           failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
                                             /*isAnchor=*/false)))
         return failure();
-    } while (curToken.getKind() != Token::r_paren);
+    } while (curToken.getKind() != FormatToken::r_paren);
     consumeToken();
   }
 
-  if (failed(parseToken(Token::question, "expected '?' after optional group")))
+  if (failed(parseToken(FormatToken::question,
+                        "expected '?' after optional group")))
     return ::mlir::failure();
 
   // The optional group is required to have an anchor.
@@ -3069,7 +2783,7 @@
     return ::mlir::failure();
 
   // Check to see if this element is the anchor of the optional group.
-  bool isAnchor = curToken.getKind() == Token::caret;
+  bool isAnchor = curToken.getKind() == FormatToken::caret;
   if (isAnchor) {
     if (anchorIdx)
       return emitError(childLoc, "only one element can be marked as the anchor "
@@ -3173,16 +2887,16 @@
     return emitError(loc, "'custom' is only valid as a top-level directive");
 
   // Parse the custom directive name.
-  if (failed(
-          parseToken(Token::less, "expected '<' before custom directive name")))
+  if (failed(parseToken(FormatToken::less,
+                        "expected '<' before custom directive name")))
     return ::mlir::failure();
 
-  Token nameTok = curToken;
-  if (failed(parseToken(Token::identifier,
+  FormatToken nameTok = curToken;
+  if (failed(parseToken(FormatToken::identifier,
                         "expected custom directive name identifier")) ||
-      failed(parseToken(Token::greater,
+      failed(parseToken(FormatToken::greater,
                         "expected '>' after custom directive name")) ||
-      failed(parseToken(Token::l_paren,
+      failed(parseToken(FormatToken::l_paren,
                         "expected '(' before custom directive parameters")))
     return ::mlir::failure();
 
@@ -3191,12 +2905,12 @@
   do {
     if (failed(parseCustomDirectiveParameter(elements)))
       return ::mlir::failure();
-    if (curToken.getKind() != Token::comma)
+    if (curToken.getKind() != FormatToken::comma)
       break;
     consumeToken();
   } while (true);
 
-  if (failed(parseToken(Token::r_paren,
+  if (failed(parseToken(FormatToken::r_paren,
                         "expected ')' after custom directive parameters")))
     return ::mlir::failure();
 
@@ -3233,9 +2947,8 @@
   return ::mlir::success();
 }
 
-LogicalResult
-FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element,
-                                           Token tok, ParserContext context) {
+LogicalResult FormatParser::parseFunctionalTypeDirective(
+    std::unique_ptr &element, FormatToken tok, ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
   if (context != TopLevelContext)
     return emitError(
@@ -3243,11 +2956,14 @@
 
   // Parse the main operand.
   std::unique_ptr inputs, results;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseTypeDirectiveOperand(inputs)) ||
-      failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
+      failed(parseToken(FormatToken::comma,
+                        "expected ',' after inputs argument")) ||
       failed(parseTypeDirectiveOperand(results)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
   element = std::make_unique(std::move(inputs),
                              std::move(results));
@@ -3278,9 +2994,11 @@
     return emitError(loc, "'ref' is only valid within a `custom` directive");
 
   std::unique_ptr operand;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseElement(operand, RefDirectiveContext)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
 
   element = std::make_unique(std::move(operand));
@@ -3339,17 +3057,19 @@
 }
 
 LogicalResult
-FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok,
-                                 ParserContext context) {
+FormatParser::parseTypeDirective(std::unique_ptr &element,
+                                 FormatToken tok, ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
   if (context == TypeDirectiveContext)
     return emitError(loc, "'type' cannot be used as a child of another `type`");
 
   bool isRefChild = context == RefDirectiveContext;
   std::unique_ptr operand;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
 
   element = std::make_unique(std::move(operand));
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -220,6 +220,7 @@
         "//mlir:SideEffects",
         "//mlir:StandardOps",
         "//mlir:StandardOpsTransforms",
+        "//mlir:Support",
         "//mlir:TensorDialect",
         "//mlir:TransformUtils",
         "//mlir:Transforms",