diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -47,6 +47,74 @@
   virtual StringRef getFullSymbolSpec() const = 0;
 };

+//===----------------------------------------------------------------------===//
+// Parse Fields
+//===----------------------------------------------------------------------===//
+
+/// Provide a template class that can be specialized by users to dispatch to
+/// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`,
+/// where `T` is the parameter storage type, to parse custom types.
+template <typename T, typename = T>
+struct FieldParser;
+
+/// Parse an attribute.
+template <typename AttributeT>
+struct FieldParser<
+    AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
+                                 AttributeT>> {
+  static FailureOr<AttributeT> parse(DialectAsmParser &parser) {
+    AttributeT value;
+    if (parser.parseAttribute(value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse any integer.
+template <typename IntT>
+struct FieldParser<IntT,
+                   std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
+  static FailureOr<IntT> parse(DialectAsmParser &parser) {
+    IntT value;
+    if (parser.parseInteger(value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse a string.
+template <>
+struct FieldParser<std::string> {
+  static FailureOr<std::string> parse(DialectAsmParser &parser) {
+    std::string value;
+    if (parser.parseString(&value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse any container that supports back insertion as a list.
+template <typename ContainerT>
+struct FieldParser<
+    ContainerT, std::enable_if_t<std::is_member_function_pointer<
+                                     decltype(&ContainerT::push_back)>::value,
+                                 ContainerT>> {
+  using ElementT = typename ContainerT::value_type;
+  static FailureOr<ContainerT> parse(DialectAsmParser &parser) {
+    ContainerT elements;
+    auto elementParser = [&]() {
+      auto element = FieldParser<ElementT>::parse(parser);
+      if (failed(element))
+        return failure();
+      elements.push_back(element.getValue());
+      return success();
+    };
+    if (parser.parseCommaSeparatedList(elementParser))
+      return failure();
+    return elements;
+  }
+};
+
 } // end namespace mlir

-#endif
+#endif // MLIR_IR_DIALECTIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2838,6 +2838,11 @@
   code printer = ?;
   code parser = ?;

+  // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
+  // specified at the same time as either 'printer' or 'parser'. The generated
+  // printer requires 'genAccessors' to be true.
+  string assemblyFormat = ?;
+
   // If set, generate accessors for each parameter.
   bit genAccessors = 1;

@@ -2916,10 +2921,22 @@
   string cppType = type;
   // The C++ type of the accessor for this parameter.
   string cppAccessorType = !if(!empty(accessorType), type, accessorType);
+  // The C++ storage type of this parameter if it is a reference, e.g.
+  // `std::string` for `StringRef` or `SmallVector<T>` for `ArrayRef<T>`.
+  string cppStorageType = ?;
   // One-line human-readable description of the argument.
   string summary = desc;
   // The format string for the asm syntax (documentation only).
   string syntax = ?;
+  // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
+  // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
+  // for your type. Or you can provide a custom parser. For attributes, "$_type"
+  // will be replaced with the required attribute type.
+  string parser = ?;
+  // The default parameter printer is `$_printer << $_self`. Overload the stream
+  // operator of `DialectAsmPrinter` as necessary to print your type. Or you can
+  // provide a custom printer.
+  string printer = ?;
 }
 class AttrParameter<string type, string desc, string accessorType = "">
     : AttrOrTypeParameter<type, desc, accessorType>;
@@ -2930,6 +2947,8 @@
 class StringRefParameter<string desc = "">
     : AttrOrTypeParameter<"::llvm::StringRef", desc> {
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+  let printer = [{$_printer << '"' << $_self << '"';}];
+  let cppStorageType = "std::string";
 }

 // For APFloats, which require comparison.
@@ -2942,6 +2961,7 @@
 class ArrayRefParameter<string arrayOf, string desc = "">
     : AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+  let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">";
 }

 // For classes which require allocation and have their own allocateInto method.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -178,10 +178,10 @@
   llvm::interleaveComma(types, p);
   return p;
 }
-template <typename AsmPrinterT>
+template <typename AsmPrinterT, typename T>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
-operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
+operator<<(AsmPrinterT &p, ArrayRef<T> types) {
   llvm::interleaveComma(types, p);
   return p;
 }
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -101,6 +101,9 @@
   // None. Otherwise, returns the contents of that code block.
   Optional<StringRef> getParserCode() const;

+  // Returns the custom assembly format, if one was specified.
+  Optional<StringRef> getAssemblyFormat() const;
+
   // Returns true if the accessors based on the parameters should be generated.
   bool genAccessors() const;

@@ -199,6 +202,15 @@
   // Get the C++ accessor type of this parameter.
   StringRef getCppAccessorType() const;

+  // Get the C++ storage type of this parameter.
+  StringRef getCppStorageType() const;
+
+  // Get an optional C++ parameter parser.
+  Optional<StringRef> getParser() const;
+
+  // Get an optional C++ parameter printer.
+  Optional<StringRef> getPrinter() const;
+
   // Get a description of this parameter for documentation purposes.
   Optional<StringRef> getSummary() const;

diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -1,3 +1,4 @@
+//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -132,6 +132,10 @@
   return def->getValueAsOptionalString("parser");
 }

+Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
+  return def->getValueAsOptionalString("assemblyFormat");
+}
+
 bool AttrOrTypeDef::genAccessors() const {
   return def->getValueAsBit("genAccessors");
 }
@@ -219,6 +223,32 @@
   return getCppType();
 }

+StringRef AttrOrTypeParameter::getCppStorageType() const {
+  if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
+    if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
+      return *type;
+  }
+  return getCppType();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getParser() const {
+  auto *parameterType = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
+      return *parser;
+  }
+  return {};
+}
+
+Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
+  auto *parameterType = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
+      return *printer;
+  }
+  return {};
+}
+
 Optional<StringRef> AttrOrTypeParameter::getSummary() const {
   auto *parameterType = def->getArg(index);
   if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -101,4 +101,37 @@
   let genVerifyDecl = 1;
 }

+def TestParamOne : AttrParameter<"int64_t", ""> {}
+
+def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
+  let printer = "$_printer << '\"' << $_self << '\"'";
+}
+
+def TestParamFour : ArrayRefParameter<"int", ""> {
+  let cppStorageType = "llvm::SmallVector<int>";
+  let parser = "::parseIntArray($_parser)";
+  let printer = "::printIntArray($_printer, $_self)";
+}
+
+def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
+  let parameters = (
+    ins
+    TestParamOne:$one,
+    TestParamTwo:$two,
+    "::mlir::IntegerAttr":$three,
+    TestParamFour:$four
+  );
+
+  let mnemonic = "attr_with_format";
+  let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`";
+  let genVerifyDecl = 1;
+}
+
+def TestAttrUgly : Test_Attr<"TestAttrUgly"> {
+  let parameters = (ins "::mlir::Attribute":$attr);
+
+  let mnemonic = "attr_ugly";
+  let assemblyFormat = "`begin` $attr `end`";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -16,9 +16,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"

 using namespace mlir;
 using namespace test;
@@ -127,6 +129,36 @@
   return success();
 }

+LogicalResult
+TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                               int64_t one, std::string two, IntegerAttr three,
+                               ArrayRef<int> four) {
+  if (four.size() != static_cast<size_t>(one))
+    return emitError() << "expected 'one' to equal 'four.size()'";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions for Generated Attributes
+//===----------------------------------------------------------------------===//
+
+static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
+  SmallVector<int> ints;
+  if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
+        ints.push_back(0);
+        return parser.parseInteger(ints.back());
+      }) ||
+      parser.parseRSquare())
+    return failure();
+  return ints;
+}
+
+static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
+  printer << '[';
+  llvm::interleaveComma(ints, printer);
+  printer << ']';
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,6 +15,7 @@

 // To get the test dialect def.
 include "TestOps.td"
+include "TestAttrDefs.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"

@@ -189,4 +190,44 @@
   let mnemonic = "test_type_with_trait";
 }

+// Type with assembly format.
+def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> {
+  let parameters = (
+    ins
+    TestParamOne:$one,
+    TestParamTwo:$two,
+    "::mlir::Attribute":$three
+  );
+
+  let mnemonic = "type_with_format";
+  let assemblyFormat = "`<` $one `,` struct($three, $two) `>`";
+}
+
+// Test dispatch to FieldParser
+def TestTypeNoParser : Test_Type<"TestTypeNoParser"> {
+  let parameters = (
+    ins
+    "uint32_t":$one,
+    ArrayRefParameter<"int64_t">:$two,
+    StringRefParameter<>:$three,
+    "::test::CustomParam":$four
+  );
+
+  let mnemonic = "no_parser";
+  let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`";
+}
+
+def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
+  let parameters = (
+    ins
+    "int":$v0,
+    "int":$v1,
+    "int":$v2,
+    "int":$v3
+  );
+
+  let mnemonic = "struct_capture_all";
+  let assemblyFormat = "`<` struct(*) `>`";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -38,8 +38,38 @@
   }
 };

+/// A custom type for a test type parameter.
+struct CustomParam {
+  int value;
+
+  bool operator==(const CustomParam &other) const {
+    return other.value == value;
+  }
+};
+
+inline llvm::hash_code hash_value(const test::CustomParam &param) {
+  return llvm::hash_value(param.value);
+}
+
 } // namespace test

+namespace mlir {
+template <>
+struct FieldParser<test::CustomParam> {
+  static FailureOr<test::CustomParam> parse(DialectAsmParser &parser) {
+    auto value = FieldParser<int>::parse(parser);
+    if (failed(value))
+      return failure();
+    return test::CustomParam{value.getValue()};
+  }
+};
+} // end namespace mlir
+
+inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer,
+                                           const test::CustomParam &param) {
+  return printer << param.value;
+}
+
 #include "TestTypeInterfaces.h.inc"

 #define GET_TYPEDEF_CLASSES
@@ -52,17 +82,19 @@
 struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
   using KeyTy = ::llvm::StringRef;

-  explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {}
+  explicit TestRecursiveTypeStorage(::llvm::StringRef key)
+      : name(key), body(::mlir::Type()) {}

   bool operator==(const KeyTy &other) const { return name == other; }

-  static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
-                                              const KeyTy &key) {
+  static TestRecursiveTypeStorage *
+  construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
     return new (allocator.allocate<TestRecursiveTypeStorage>())
         TestRecursiveTypeStorage(allocator.copyInto(key));
   }

-  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) {
+  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
+                               ::mlir::Type newBody) {
     // Cannot set a different body than before.
     if (body && body != newBody)
       return ::mlir::failure();
@@ -79,11 +111,13 @@
 /// type, potentially itself. This requires the body to be mutated separately
 /// from type creation.
 class TestRecursiveType
-    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type, TestRecursiveTypeStorage> {
+    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
+                                    TestRecursiveTypeStorage> {
 public:
   using Base::Base;

-  static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) {
+  static TestRecursiveType get(::mlir::MLIRContext *ctx,
+                               ::llvm::StringRef name) {
     return Base::get(ctx, name);
   }

diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -0,0 +1,69 @@
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "TestDialect";
+  let cppNamespace = "::test";
+}
+
+class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
+  let mnemonic = asm;
+}
+
+/// Test format is missing a parameter capture.
+def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: format is missing reference to parameter: v1
+  let assemblyFormat = "`<` $v0 `>`";
+}
+
+/// Test format has duplicate parameter captures.
+def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: duplicate parameter 'v0'
+  let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
+}
+
+/// Test format has invalid syntax.
+def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: expected literal, directive, or variable
+  let assemblyFormat = "`<` $v0, $v1 `>`";
+}
+
+/// Test struct directive has invalid syntax.
+def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: literals may only be used in the top-level section of the format
+  // CHECK: expected a variable in `struct` argument list
+  let assemblyFormat = "`<` struct($v0, `,`) `>`";
+}
+
+/// Test struct directive cannot capture zero parameters.
+def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: directive `struct` expects at least one variable
+  let assemblyFormat = "`<` struct() $v0 `>`";
+}
+
+/// Test capture parameter that does not exist.
+def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: InvalidTypeF has no parameter named 'v1'
+  let assemblyFormat = "`<` $v0 $v1 `>`";
+}
+
+/// Test duplicate capture of parameter in capture-all struct.
+def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
+  let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+  // CHECK: duplicate parameter 'v0'
+  let assemblyFormat = "`<` struct(*) $v0 `>`";
+}
+
+/// Test capture-all struct duplicate capture.
+def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
+  let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+  // CHECK: struct(*) captures duplicate parameter: v0
+  let assemblyFormat = "`<` $v0 struct(*) `>`";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @test_roundtrip_parameter_parsers
+// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">
+// CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
+func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
+attributes {
+  // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>
+  attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>,
+  // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>
+  attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>,
+  // CHECK: #test<"attr_ugly begin 5 : index end">
+  attr2 = #test<"attr_ugly begin 5 : index end">
+}
+
+// CHECK-LABEL: @test_roundtrip_default_parsers_struct
+// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
+// CHECK: !test.struct_capture_all
+func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.mlir b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt --split-input-file %s --verify-diagnostics
+
+func private @test_ugly_attr_cannot_be_pretty() -> () attributes {
+  attr = #test.attr_ugly // expected-error {{expected 'begin'}}
+}
+
+// -----
+
+func private @test_ugly_attr_no_mnemonic() -> () attributes {
+  attr = #test<""> // expected-error {{expected valid keyword}}
+}
+
+// -----
+
+func private @test_ugly_attr_parser_dispatch() -> () attributes {
+  attr = #test<"attr_ugly"> // expected-error {{expected 'begin'}}
+}
+
+// -----
+
+func private @test_ugly_attr_missing_parameter() -> () attributes {
+  // expected-error@+1 {{failed to parse TestAttrUgly parameter 'attr'}}
+  attr = #test<"attr_ugly begin"> // expected-error {{expected non-function type}}
+}
+
+// -----
+
+func private @test_ugly_attr_missing_literal() -> () attributes {
+  attr = #test<"attr_ugly begin \"string_attr\""> // expected-error {{expected 'end'}}
+}
+
+// -----
+
+func private @test_pretty_attr_expects_less() -> () attributes {
+  attr = #test.attr_with_format // expected-error {{expected '<'}}
+}
+
+// -----
+
+func private @test_pretty_attr_missing_param() -> () attributes {
+  // expected-error@+1 {{expected integer value}}
+  attr = #test.attr_with_format<> // expected-error {{failed to parse TestAttrWithFormat parameter 'one'}}
+}
+
+// -----
+
+func private @test_parse_invalid_param() -> () attributes {
+  // Test parameter parser failure is propagated
+  // expected-error@+1 {{expected integer value}}
+  attr = #test.attr_with_format<"hi"> // expected-error {{failed to parse TestAttrWithFormat parameter 'one'}}
+}
+
+// -----
+
+func private @test_pretty_attr_invalid_syntax() -> () attributes {
+  attr = #test.attr_with_format<42> // expected-error {{expected ':'}}
+}
+
+// -----
+
+func private @test_struct_missing_key() -> () attributes {
+  // expected-error@+1 {{expected valid keyword}}
+  attr = #test.attr_with_format<42 :> // expected-error {{expected a parameter name in struct}}
+}
+
+// -----
+
+func private @test_struct_unknown_key() -> () attributes {
+  attr = #test.attr_with_format<42 : nine = "foo"> // expected-error {{duplicate or unknown struct parameter}}
+}
+
+// -----
+
+func private @test_struct_duplicate_key() -> () attributes {
+  attr = #test.attr_with_format<42 : two = "foo", two = "bar"> // expected-error {{duplicate or unknown struct parameter}}
+}
+
+// -----
+
+func private @test_struct_not_enough_values() -> () attributes {
+  attr = #test.attr_with_format<42 : two = "foo"> // expected-error {{expected ','}}
+}
+
+// -----
+
+func private @test_parse_param_after_struct() -> () attributes {
+  // expected-error@+1 {{expected non-function type}}
+  attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : > // expected-error {{failed to parse TestAttrWithFormat parameter 'three'}}
+}
+
+// -----
+
+func private @test_invalid_type() -> !test.type_with_format // expected-error {{expected '<'}}
+
+// -----
+
+// expected-error@+1 {{expected integer value}}
+func private @test_pretty_type_invalid_param() -> !test.type_with_format<> // expected-error {{failed to parse TestTypeWithFormat parameter 'one'}}
+
+// -----
+
+// expected-error@+1 {{expected ':'}}
+func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>> // expected-error {{failed to parse TestTypeWithFormat parameter 'three'}}
+
+// -----
+func private @test_verifier_fails() -> () attributes {
+  attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64> // expected-error {{expected 'one' to equal 'four.size()'}}
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -0,0 +1,250 @@
+// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE
+
+include "mlir/IR/OpBase.td"
+
+/// Test that attribute and type printers and parsers are correctly generated.
+def Test_Dialect : Dialect {
+  let name = "TestDialect";
+  let cppNamespace = "::test";
+}
+
+class TestAttr<string name> : AttrDef<Test_Dialect, name>;
+class TestType<string name> : TypeDef<Test_Dialect, name>;
+
+def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> {
+  let parser = "::parseAttrParamA($_parser, $_type)";
+  let printer = "::printAttrParamA($_printer, $_self)";
+}
+
+def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> {
+  let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)";
+  let printer = "::printAttrB($_printer, $_self)";
+}
+
+def TypeParamA : TypeParameter<"TestParamC", "a type param C"> {
+  let parser = "::parseTypeParamC($_parser)";
+  let printer = "$_printer << $_self";
+}
+
+def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
+  let parser = "someFcnCall()";
+  let printer = "myPrinter($_self)";
+}
+
+// ATTR: ::mlir::Attribute TestAAttr::parse
+// ATTR: ::mlir::Type
+// ATTR: FailureOr<IntegerAttr>
+// ATTR: FailureOr<TestParamA>
+// ATTR: parseKeyword("hello")
+// ATTR: parseEqual()
+// ATTR: value = ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// ATTR: parseComma()
+// ATTR: complex = ::parseAttrParamA(parser, attrType)
+// ATTR: parseRParen()
+// ATTR: TestAAttr::get(
+// ATTR: value
+// ATTR: complex
+
+// ATTR: void TestAAttr::print
+// ATTR: << ' ' << "hello"
+// ATTR: << ' ' << "="
+// ATTR: << ' '
+// ATTR: << getValue()
+// ATTR: << ","
+// ATTR: << ' '
+// ATTR: ::printAttrParamA(printer, getComplex())
+// ATTR: << ")"
+def AttrA : TestAttr<"TestA"> {
+  let parameters = (ins
+    "IntegerAttr":$value,
+    AttrParamA:$complex
+  );
+
+  let mnemonic = "attr_a";
+  let assemblyFormat = "`hello` `=` $value `,` $complex `)`";
+}
+
+// ATTR: ::mlir::Attribute TestBAttr::parse
+// ATTR: ::mlir::Type
+// ATTR: parseLBrace()
+// ATTR: for (unsigned _index = 0; _index < 4; ++_index)
+// ATTR: parseKeyword(&_paramKey)
+// ATTR: parseEqual()
+// ATTR: v0 = ::parseAttrParamA(parser, attrType)
+// ATTR: v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser)
+// ATTR: v2 = ::parseAttrParamA(parser, attrType)
+// ATTR: v3 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser)
+// ATTR: parseComma()
+// ATTR: parseRBrace()
+// ATTR: TestBAttr::get
+// ATTR: v0.getValue()
+// ATTR: v1.getValue()
+// ATTR: v2.getValue()
+// ATTR: v3.getValue()
+
+// ATTR: void TestBAttr::print
+// ATTR: << "{";
+// ATTR: << "v0";
+// ATTR: << ' ' << "=";
+// ATTR: << ' ';
+// ATTR: ::printAttrParamA(printer, getV0());
+// ATTR: << ",";
+// ATTR: << ' ' << "v1";
+// ATTR: << ' ' << "=";
+// ATTR: << ' ';
+// ATTR: ::printAttrB(printer, getV1());
+// ATTR: << ",";
+// ATTR: << ' ' << "v2";
+// ATTR: << ' ' << "=";
+// ATTR: << ' ';
+// ATTR: ::printAttrParamA(printer, getV2());
+// ATTR: << ",";
+// ATTR: << ' ' << "v3";
+// ATTR: << ' ' << "=";
+// ATTR: << ' ';
+// ATTR: ::printAttrB(printer, getV3());
+// ATTR: << "}";
+def AttrB : TestAttr<"TestB"> {
+  let parameters = (ins
+    AttrParamA:$v0,
+    AttrParamB:$v1,
+    AttrParamA:$v2,
+    AttrParamB:$v3
+  );
+
+  let mnemonic = "attr_b";
+  let assemblyFormat = "`{` struct($v0, $v1, $v2, $v3) `}`";
+}
+
+// TYPE: ::mlir::Type TestCType::parse
+// TYPE: parseKeyword("foo"))
+// TYPE: parseComma())
+// TYPE: parseColon())
+// TYPE: parseKeyword("bob"))
+// TYPE: parseKeyword("bar"))
+// TYPE: ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// TYPE: _index < 1
+// TYPE: ::parseTypeParamC
+// TYPE: parseComma()
+// TYPE: parseRParen()
+
+// TYPE: void TestCType::print
+// TYPE: << ' ' << "foo"
+// TYPE: << ","
+// TYPE: << ' ' << ":"
+// TYPE: << ' ' << "bob"
+// TYPE: << ' ' << "bar"
+// TYPE: << ' '
+// TYPE: << getValue()
+// TYPE: << ' '
+// TYPE: << "complex"
+// TYPE: << ' ' << "="
+// TYPE: << ' '
+// TYPE: << getComplex()
+// TYPE: << ")"
+def TypeA : TestType<"TestC"> {
+  let parameters = (ins
+    "IntegerAttr":$value,
+    TypeParamA:$complex
+  );
+
+  let mnemonic = "type_c";
+  let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`";
+}
+
+// TYPE: ::mlir::Type TestDType::parse
+// TYPE: parseLess()
+// TYPE: parseKeyword("foo")
+// TYPE: parseColon()
+// TYPE: ::parseTypeParamC(parser)
+// TYPE: parseComma()
+// TYPE: _index < 2
+// TYPE: someFcnCall()
+// TYPE: ::parseTypeParamC(parser)
+// TYPE: parseComma()
+// TYPE: parseComma()
+// TYPE: someFcnCall()
+// TYPE: parseGreater()
+
+// TYPE: TestDType::print
+// TYPE: << "<"
+// TYPE: << "foo"
+// TYPE: << ' ' << ":
+// TYPE: << ' '
+// TYPE: << getV0()
+// TYPE: << ","
+// TYPE: << ' '
+// TYPE: << "v1"
+// TYPE: << ' ' << "="
+// TYPE: << ' '
+// TYPE: myPrinter(getV1())
+// TYPE: << ","
+// TYPE: << ' ' << "v2"
+// TYPE: << ' ' << "="
+// TYPE: << ' '
+// TYPE: << getV2()
+// TYPE: << ","
+// TYPE: << ' '
+// TYPE: myPrinter(getV3())
+// TYPE: << ">"
+def TypeB : TestType<"TestD"> {
+  let parameters = (ins
+    TypeParamA:$v0,
+    TypeParamB:$v1,
+    TypeParamA:$v2,
+    TypeParamB:$v3
+  );
+
+  let mnemonic = "type_d";
+  let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`";
+}
+
+// TYPE: ::mlir::Type TestEType::parse
+// TYPE: parseLBrace()
+// TYPE: _index < 2
+// TYPE: parseKeyword
+// TYPE: parseEqual()
+// TYPE: ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// TYPE: ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// TYPE: parseComma()
+// TYPE: parseRBrace()
+// TYPE: parseLBrace()
+// TYPE: _index < 2
+// TYPE: parseKeyword
+// TYPE: parseEqual()
+// TYPE: ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// TYPE: ::mlir::FieldParser<IntegerAttr>::parse(parser)
+// TYPE: parseComma()
+// TYPE: parseRBrace()
+
+// TYPE: TestEType::print
+// TYPE: << "{"
+// TYPE: << "v0"
+// TYPE: << "="
+// TYPE: << getV0()
+// TYPE: << ","
+// TYPE: << "v2"
+// TYPE: << "="
+// TYPE: << getV2()
+// TYPE: << "}"
+// TYPE: << "{"
+// TYPE: << "v1"
+// TYPE: << "="
+// TYPE: << getV1()
+// TYPE: << ","
+// TYPE: << "v3"
+// TYPE: << "="
+// TYPE: << getV3()
+// TYPE: << "}"
+def TypeC : TestType<"TestE"> {
+  let parameters = (ins
+    "IntegerAttr":$v0,
+    "IntegerAttr":$v1,
+    "IntegerAttr":$v2,
+    "IntegerAttr":$v3
+  );
+
+  let mnemonic = "type_e";
+  let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "AttrOrTypeFormatGen.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
@@ -399,7 +400,8 @@
        << "  }\n";

   // If mnemonic specified, emit print/parse declarations.
-  if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
+  if (def.getParserCode() || def.getPrinterCode() ||
+      def.getAssemblyFormat() || !params.empty()) {
     os << llvm::formatv(defDeclParsePrintStr, valueType,
                         isAttrGenerator ? ", ::mlir::Type type" : "");
   }
@@ -410,10 +412,8 @@
     def.getParameters(parameters);

     for (AttrOrTypeParameter &parameter : parameters) {
-      SmallString<16> name = parameter.getName();
-      name[0] = llvm::toUpper(name[0]);
-      os << formatv("    {0} get{1}() const;\n", parameter.getCppAccessorType(),
-                    name);
+      os << formatv("    {0} {1}() const;\n", parameter.getCppAccessorType(),
+                    getParameterAccessorName(parameter.getName()));
     }
   }

@@ -700,8 +700,32 @@
 }

 void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
+  auto printerCode = def.getPrinterCode();
+  auto parserCode = def.getParserCode();
+  auto assemblyFormat = def.getAssemblyFormat();
+  if (assemblyFormat && (printerCode || parserCode)) {
+    // Custom assembly format cannot be specified at the same time as either
+    // custom printer or parser code.
+    PrintFatalError(def.getLoc(),
+                    def.getName() + ": assembly format cannot be specified at "
+                                    "the same time as printer or parser code");
+  }
+
+  // Generate a parser and printer based on the assembly format, if specified.
+  if (assemblyFormat) {
+    // A custom assembly format requires accessors to be generated for the
+    // generated printer.
+    if (!def.genAccessors()) {
+      PrintFatalError(def.getLoc(),
+                      def.getName() +
+                          ": the generated printer from 'assemblyFormat' "
+                          "requires 'genAccessors' to be true");
+    }
+    return generateAttrOrTypeFormat(def, os);
+  }
+
   // Emit the printer code, if specified.
-  if (Optional<StringRef> printerCode = def.getPrinterCode()) {
+  if (printerCode) {
     // Both the mnenomic and printerCode must be defined (for parity with
     // parserCode).
     os << "void " << def.getCppClassName()
@@ -717,7 +741,7 @@
   }

   // Emit the parser code, if specified.
-  if (Optional<StringRef> parserCode = def.getParserCode()) {
+  if (parserCode) {
     FmtContext fmtCtxt;
     fmtCtxt.addSubst("_parser", "parser")
         .addSubst("_ctxt", "parser.getContext()");
@@ -857,11 +881,10 @@
         paramStorageName = param.getName();
       }

-      SmallString<16> name = param.getName();
-      name[0] = llvm::toUpper(name[0]);
-      os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
-                    param.getCppAccessorType(), name, paramStorageName,
-                    def.getCppClassName());
+      os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
+                    param.getCppAccessorType(),
+                    getParameterAccessorName(param.getName()),
+                    paramStorageName, def.getCppClassName());
     }
   }
 }
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -0,0 +1,29 @@
+//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+class AttrOrTypeDef;
+
+/// Generate a parser and printer based on a custom assembly format for an
+/// attribute or type.
+void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
+/// Return the accessor name for a parameter, e.g. `foo` -> `getFoo`.
+std::string getParameterAccessorName(llvm::StringRef name);
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -0,0 +1,725 @@
+//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AttrOrTypeFormatGen.h"
+#include "FormatGen.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::formatv;
+
+//===----------------------------------------------------------------------===//
+// Element
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This class represents a single format element.
+class Element {
+public:
+  /// LLVM-style RTTI.
+  enum class Kind {
+    /// This element is a directive.
+    StructDirective,
+
+    /// This element is a literal.
+    Literal,
+
+    /// This element is a variable.
+    Variable,
+  };
+  explicit Element(Kind kind) : kind(kind) {}
+  virtual ~Element() = default;
+
+  /// Return the kind of this element.
+  Kind getKind() const { return kind; }
+
+private:
+  /// The kind of this element.
+  Kind kind;
+};
+
+/// This class represents an instance of a literal element.
+class LiteralElement : public Element {
+public:
+  explicit LiteralElement(StringRef literal)
+      : Element(Kind::Literal), literal(literal) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == Kind::Literal;
+  }
+
+  /// Get the literal spelling (without the surrounding backticks).
+  StringRef getSpelling() const { return literal; }
+
+private:
+  /// The spelling of the literal for this element.
+  StringRef literal;
+};
+
+/// This class represents an instance of a variable element. A variable refers
+/// to an attribute or type parameter.
+class VariableElement : public Element {
+public:
+  explicit VariableElement(AttrOrTypeParameter param)
+      : Element(Kind::Variable), param(param) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == Kind::Variable;
+  }
+
+  /// Get the parameter in the element.
+  const AttrOrTypeParameter &getParam() const { return param; }
+
+private:
+  AttrOrTypeParameter param;
+};
+
+/// This class represents a `struct` directive that generates a struct format
+/// of the form:
+///
+///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
+///
+class StructDirective : public Element {
+public:
+  explicit StructDirective(std::vector<std::unique_ptr<Element>> &&params)
+      : Element(Kind::StructDirective), params(std::move(params)) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == Kind::StructDirective;
+  }
+
+  /// Get the parameters contained in a struct (all are variable elements).
+  auto getParams() {
+    return llvm::map_range(params, [](auto &el) {
+      return cast<VariableElement>(el.get())->getParam();
+    });
+  }
+
+  /// Get the number of parameters.
+  unsigned getNumParams() const { return params.size(); }
+
+private:
+  std::vector<std::unique_ptr<Element>> params;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Format Strings
+//===----------------------------------------------------------------------===//
+
+/// Format for defining an attribute parser.
+///
+/// $0: The attribute C++ class name.
+static const char *const attrParserDefn = R"(
+::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser,
+                            ::mlir::Type $_type) {
+)";
+
+/// Format for defining a type parser.
+///
+/// $0: The type C++ class name.
+static const char *const typeParserDefn = R"(
+::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) {
+)";
+
+/// Default parser for attribute or type parameters.
+static const char *const defaultParameterParser =
+    "::mlir::FieldParser<$0>::parse($_parser)";
+
+/// Default printer for attribute or type parameters.
+static const char *const defaultParameterPrinter = "$_printer << $_self";
+
+/// Prefix of the code emitted to report a parse error at the current location.
+///
+/// The caller appends the quoted error message and the closing parenthesis.
+static const char *const parseErrorStr =
+    "$_parser.emitError($_parser.getCurrentLocation(), ";
+
+/// Format for defining an attribute or type printer.
+///
+/// $0: The attribute or type C++ class name.
+/// $1: The attribute or type mnemonic.
+static const char *const attrOrTypePrinterDefn = R"(
+void $0::print(::mlir::DialectAsmPrinter &$_printer) const {
+  $_printer << "$1";
+)";
+
+/// Loop declaration for struct parser.
+///
+/// $0: Number of expected parameters.
+static const char *const structParseLoopStart = R"(
+  for (unsigned _index = 0; _index < $0; ++_index) {
+    ::llvm::StringRef _paramKey;
+    if ($_parser.parseKeyword(&_paramKey)) {
+      $_parser.emitError($_parser.getCurrentLocation(),
+                         "expected a parameter name in struct");
+      return {};
+    }
+)";
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+static auto getParameters(const AttrOrTypeDef &def) {
+  SmallVector<AttrOrTypeParameter> params;
+  def.getParameters(params);
+  return params;
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeFormat
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class AttrOrTypeFormat {
+public:
+  AttrOrTypeFormat(const AttrOrTypeDef &def,
+                   std::vector<std::unique_ptr<Element>> &&elements)
+      : def(def), elements(std::move(elements)) {}
+
+  /// Generate the attribute or type parser.
+  void genParser(llvm::raw_ostream &os);
+  /// Generate the attribute or type printer.
+  void genPrinter(llvm::raw_ostream &os);
+
+private:
+  /// Generate the parser code for a specific format element.
+  void genElementParser(Element *el, FmtContext &ctx, llvm::raw_ostream &os);
+  /// Generate the parser code for a literal.
+  void genLiteralParser(StringRef value, FmtContext &ctx, llvm::raw_ostream &os,
+                        unsigned indent = 0);
+  /// Generate the parser code for a variable.
+  void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
+                         llvm::raw_ostream &os, unsigned indent = 0);
+  /// Generate the parser code for a `struct` directive.
+  void genStructParser(StructDirective *el, FmtContext &ctx,
+                       llvm::raw_ostream &os);
+
+  /// Generate the printer code for a specific format element.
+  void genElementPrinter(Element *el, FmtContext &ctx, llvm::raw_ostream &os);
+  /// Generate the printer code for a literal.
+  void genLiteralPrinter(StringRef value, FmtContext &ctx,
+                         llvm::raw_ostream &os);
+  /// Generate the printer code for a variable.
+  void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
+                          llvm::raw_ostream &os);
+  /// Generate the printer code for a `struct` directive.
+  void genStructPrinter(StructDirective *el, FmtContext &ctx,
+                        llvm::raw_ostream &os);
+
+  const AttrOrTypeDef &def;
+  std::vector<std::unique_ptr<Element>> elements;
+
+  /// Printer spacing state; `genPrinter` resets both before emitting elements.
+  bool shouldEmitSpace = true;
+  bool lastWasPunctuation = false;
+};
+
+//===----------------------------------------------------------------------===//
+// ParserGen
+//===----------------------------------------------------------------------===//
+
+/// Generate the `AttrOrType::parse` function definition.
+void AttrOrTypeFormat::genParser(llvm::raw_ostream &os) {
+  FmtContext ctx;
+  ctx.addSubst("_parser", "parser");
+
+  // Generate the definition.
+  if (isa<AttrDef>(def)) {
+    ctx.addSubst("_type", "attrType");
+    os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
+  } else {
+    os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
+  }
+
+  // Declare variables to store all of the parameters. Allocated parameters
+  // such as `ArrayRef` and `StringRef` must provide a `cppStorageType`. Store
+  // FailureOr<T> to defer type construction for parameters that are parsed in
+  // a loop (parsers return FailureOr anyways).
+  auto params = getParameters(def);
+  for (auto &param : params) {
+    os.indent(2) << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
+                            param.getCppStorageType(), param.getName());
+  }
+
+  // Store the initial location of the parser.
+  ctx.addSubst("_loc", "loc");
+  os.indent(2) << tgfmt(
+      "::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n", &ctx);
+  os.indent(2) << tgfmt("(void) $_loc;\n", &ctx);
+
+  // Generate call to each parameter parser.
+  for (auto &el : elements)
+    genElementParser(el.get(), ctx, os);
+
+  // Generate call to the attribute or type builder. Use the checked getter
+  // if one was generated.
+  if (def.genVerifyDecl()) {
+    os.indent(2) << tgfmt(
+        "return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx,
+        def.getCppClassName());
+  } else {
+    os.indent(2) << tgfmt("return $0::get($_parser.getContext()", &ctx,
+                          def.getCppClassName());
+  }
+  for (auto &param : params) {
+    os << ",\n";
+    os.indent(4) << formatv("_result_{0}.getValue()", param.getName());
+  }
+  os << ");\n";
+  os << "}\n\n";
+}
+
+/// Generate a parser for a format element.
+void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
+                                        llvm::raw_ostream &os) {
+  if (auto *literal = dyn_cast<LiteralElement>(el))
+    return genLiteralParser(literal->getSpelling(), ctx, os);
+  if (auto *var = dyn_cast<VariableElement>(el))
+    return genVariableParser(var->getParam(), ctx, os);
+  if (auto *strct = dyn_cast<StructDirective>(el))
+    return genStructParser(strct, ctx, os);
+
+  llvm_unreachable("unknown format element");
+}
+
+/// Generate a parser for a literal.
+void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+                                        llvm::raw_ostream &os,
+                                        unsigned indent) {
+  os.indent(indent + 2) << "// Parse literal '" << value << "'\n";
+  os.indent(indent + 2) << tgfmt("if ($_parser.parse", &ctx);
+  if (value.front() == '_' || llvm::isAlpha(value.front())) {
+    os << "Keyword(\"" << value << "\")";
+  } else {
+    os << StringSwitch<StringRef>(value)
+              .Case("->", "Arrow")
+              .Case(":", "Colon")
+              .Case(",", "Comma")
+              .Case("=", "Equal")
+              .Case("<", "Less")
+              .Case(">", "Greater")
+              .Case("{", "LBrace")
+              .Case("}", "RBrace")
+              .Case("(", "LParen")
+              .Case(")", "RParen")
+              .Case("[", "LSquare")
+              .Case("]", "RSquare")
+              .Case("?", "Question")
+              .Case("+", "Plus")
+              .Case("*", "Star")
+       << "()";
+  }
+  os << ")\n";
+  // The parser call reports the error itself; just propagate the failure.
+  os.indent(indent + 4) << "return {};\n";
+}
+
+/// Generate a parser for a variable.
+void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
+                                         FmtContext &ctx, llvm::raw_ostream &os,
+                                         unsigned indent) {
+  // Check for a custom parser. Use the default `FieldParser` otherwise.
+  auto name = param.getName();
+  os.indent(indent + 2) << "// Parse variable '" << name << "'\n";
+  os.indent(indent + 2) << formatv("_result_{0} = ", name);
+  auto parser = param.getParser();
+  os << tgfmt(parser ? *parser : StringRef(defaultParameterParser), &ctx,
+              param.getCppStorageType())
+     << ";\n";
+  // Emit a result check.
+  os.indent(indent + 2) << formatv("if (failed(_result_{0})) {{\n", name);
+  os.indent(indent + 4) << tgfmt(parseErrorStr, &ctx) << "\"failed to parse "
+                        << def.getName() << " parameter '" << name
+                        << "' which is to be a `" << param.getCppType()
+                        << "`\");\n";
+  os.indent(indent + 4) << "return {};\n";
+  os.indent(indent + 2) << "}\n";
+}
+
+/// Generate a struct parser for a list of parameters.
+void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+                                       llvm::raw_ostream &os) {
+  auto seen = [](const AttrOrTypeParameter &param) {
+    return formatv("_seen_{0}", param.getName());
+  };
+  os.indent(2) << "// Parse parameter struct\n";
+
+  // Declare a "seen" variable for each key.
+  for (auto param : el->getParams())
+    os.indent(2) << "bool " << seen(param) << " = false;\n";
+  os << "\n";
+
+  // Generate the parsing loop.
+  os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
+  genLiteralParser("=", ctx, os, 2);
+  os.indent(4);
+  for (auto param : el->getParams()) {
+    os << "if (!" << seen(param) << " && _paramKey == \"" << param.getName()
+       << "\") {\n";
+    os.indent(6) << seen(param) << " = true;\n";
+    genVariableParser(param, ctx, os, 4);
+    os.indent(4) << "} else ";
+  }
+
+  // Duplicate or unknown parameter ("} else " above already ends in a space).
+  os << "{\n";
+  os.indent(6)
+      << tgfmt(parseErrorStr, &ctx)
+      << "\"duplicate or unknown struct parameter name: \") << _paramKey;\n";
+  os.indent(6) << "return {};\n";
+  os.indent(4) << "}\n";
+
+  // Parse comma except on last element.
+  os.indent(4) << "if (_index != " << el->getNumParams() << " - 1) {\n";
+  genLiteralParser(",", ctx, os, 4);
+  os.indent(4) << "}\n";
+  os.indent(2) << "}\n";
+
+  // Because the loop loops N times and each non-failing iteration sets 1 of
+  // N flags, successfully exiting the loop means that all parameters have been
+  // seen. `parseOptionalComma` would cause issues with any formats that use
+  // "struct(...) `,`" because structs aren't surrounded by braces.
+}
+
+//===----------------------------------------------------------------------===//
+// PrinterGen
+//===----------------------------------------------------------------------===//
+
+void AttrOrTypeFormat::genPrinter(llvm::raw_ostream &os) {
+  FmtContext ctx;
+  ctx.addSubst("_printer", "printer");
+
+  // Generate the definition.
+  os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(),
+              *def.getMnemonic());
+
+  // Generate printers.
+  shouldEmitSpace = true;
+  lastWasPunctuation = false;
+  for (auto &el : elements)
+    genElementPrinter(el.get(), ctx, os);
+
+  os << "}\n\n";
+}
+
+void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
+                                         llvm::raw_ostream &os) {
+  if (auto *literal = dyn_cast<LiteralElement>(el))
+    return genLiteralPrinter(literal->getSpelling(), ctx, os);
+  if (auto *strct = dyn_cast<StructDirective>(el))
+    return genStructPrinter(strct, ctx, os);
+
+  // Insert a space before the next element, if necessary.
+  if (shouldEmitSpace || !lastWasPunctuation)
+    os.indent(2) << tgfmt("$_printer << ' ';\n", &ctx);
+  shouldEmitSpace = true;
+  lastWasPunctuation = false;
+
+  if (auto *var = dyn_cast<VariableElement>(el))
+    return genVariablePrinter(var->getParam(), ctx, os);
+
+  llvm_unreachable("unknown format element");
+}
+
+void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+                                         llvm::raw_ostream &os) {
+  os.indent(2) << tgfmt("$_printer", &ctx);
+
+  // Don't insert a space before certain punctuation.
+  if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
+    os << " << ' '";
+  os << " << \"" << value << "\";\n";
+
+  // Update the flags.
+  shouldEmitSpace =
+      value.size() != 1 || !StringRef("<({[").contains(value.front());
+  lastWasPunctuation = !(value.front() == '_' || llvm::isAlpha(value.front()));
+}
+
+void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
+                                          FmtContext &ctx,
+                                          llvm::raw_ostream &os) {
+  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+  if (auto printer = param.getPrinter()) {
+    os.indent(2) << tgfmt(*printer, &ctx) << ";\n";
+  } else {
+    os.indent(2) << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
+  }
+}
+
+void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+                                        llvm::raw_ostream &os) {
+  llvm::interleave(
+      el->getParams(),
+      [&](auto param) {
+        genLiteralPrinter(param.getName(), ctx, os);
+        genLiteralPrinter("=", ctx, os);
+        os.indent(2) << tgfmt("$_printer << ' ';\n", &ctx);
+        genVariablePrinter(param, ctx, os);
+      },
+      [&]() { genLiteralPrinter(",", ctx, os); });
+}
+
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
+class FormatParser {
+public:
+  FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
+      : lexer(mgr, def.getLoc()[0]), def(def),
+        seenParams(def.getNumParameters()) {}
+
+  /// Parse the attribute or type format and create the format elements.
+  FailureOr<AttrOrTypeFormat> parse();
+
+private:
+  /// The current context of the parser when parsing an element.
+  enum ParserContext {
+    /// The element is being parsed in the default context - at the top of the
+    /// format.
+    TopLevelContext,
+    /// The element is being parsed as a child to a `struct` directive.
+    StructDirectiveContext,
+  };
+
+  /// Emit an error.
+  LogicalResult emitError(FormatToken tok, const Twine &msg) {
+    lexer.emitError(tok.getLoc(), msg);
+    return failure();
+  }
+
+  /// Parse an expected token.
+  FailureOr<FormatToken> parseToken(FormatToken::Kind kind, const Twine &msg) {
+    auto tok = lexer.lexToken();
+    if (tok.getKind() != kind)
+      return emitError(tok, msg);
+    return tok;
+  }
+
+  using ParseResult = std::pair<std::unique_ptr<Element>, FormatToken>;
+
+  /// Parse any element.
+  FailureOr<ParseResult> parseElement(FormatToken tok, ParserContext ctx);
+  /// Parse a literal element.
+  FailureOr<ParseResult> parseLiteral(FormatToken tok, ParserContext ctx);
+  /// Parse a variable element.
+  FailureOr<ParseResult> parseVariable(FormatToken tok, ParserContext ctx);
+  /// Parse a directive.
+  FailureOr<ParseResult> parseDirective(FormatToken tok, ParserContext ctx);
+  /// Parse a `struct` directive.
+  FailureOr<ParseResult> parseStructDirective(FormatToken tok);
+
+  /// The current format lexer.
+  FormatLexer lexer;
+  /// Attribute or type tablegen def.
+  const AttrOrTypeDef &def;
+
+  /// Seen attribute or type parameters.
+  llvm::BitVector seenParams;
+};
+
+FailureOr<AttrOrTypeFormat> FormatParser::parse() {
+  std::vector<std::unique_ptr<Element>> elements;
+  elements.reserve(16);
+
+  // Parse the format elements.
+  FormatToken tok = lexer.lexToken();
+  while (tok.getKind() != FormatToken::eof) {
+    auto element = parseElement(tok, TopLevelContext);
+    if (failed(element))
+      return failure();
+
+    // Add the format element and continue.
+    elements.emplace_back(std::move(element->first));
+    tok = element->second;
+
+    if (tok.getKind() == FormatToken::error)
+      return emitError(tok, "unexpected format lex error");
+  }
+
+  // Check that all parameters have been seen.
+  auto params = getParameters(def);
+  for (auto *it = params.begin(); it != params.end(); ++it) {
+    if (!seenParams.test(std::distance(params.begin(), it))) {
+      return emitError(tok, "format is missing reference to parameter: " +
+                                it->getName());
+    }
+  }
+
+  return AttrOrTypeFormat(def, std::move(elements));
+}
+
+FailureOr<FormatParser::ParseResult>
+FormatParser::parseElement(FormatToken tok, ParserContext ctx) {
+  if (tok.getKind() == FormatToken::literal)
+    return parseLiteral(tok, ctx);
+  if (tok.getKind() == FormatToken::variable)
+    return parseVariable(tok, ctx);
+  if (tok.isKeyword())
+    return parseDirective(tok, ctx);
+
+  return emitError(tok, "expected literal, directive, or variable");
+}
+
+FailureOr<FormatParser::ParseResult>
+FormatParser::parseLiteral(FormatToken tok, ParserContext ctx) {
+  if (ctx != TopLevelContext) {
+    return emitError(
+        tok,
+        "literals may only be used in the top-level section of the format");
+  }
+
+  // Get the literal spelling without the surrounding "`".
+  auto value = tok.getSpelling().drop_front().drop_back();
+  if (!isValidLiteral(value))
+    return emitError(tok, "literal '" + value + "' is not valid");
+
+  return ParseResult(std::make_unique<LiteralElement>(value), lexer.lexToken());
+}
+
+FailureOr<FormatParser::ParseResult>
+FormatParser::parseVariable(FormatToken tok, ParserContext ctx) {
+  // Get the parameter name without the preceding "$".
+  auto name = tok.getSpelling().drop_front();
+
+  // Lookup the parameter.
+  auto params = getParameters(def);
+  auto *it = llvm::find_if(
+      params, [&](auto &param) { return param.getName() == name; });
+
+  // Check that the parameter reference is valid.
+  if (it == params.end()) {
+    return emitError(tok,
+                     def.getName() + " has no parameter named '" + name + "'");
+  }
+  auto idx = std::distance(params.begin(), it);
+  if (seenParams.test(idx))
+    return emitError(tok, "duplicate parameter '" + name + "'");
+  seenParams.set(idx);
+
+  return ParseResult(std::make_unique<VariableElement>(*it), lexer.lexToken());
+}
+
+FailureOr<FormatParser::ParseResult>
+FormatParser::parseDirective(FormatToken tok, ParserContext ctx) {
+  if (ctx != TopLevelContext) {
+    return emitError(
+        tok,
+        "directives may only be used in the top-level section of the format");
+  }
+
+  switch (tok.getKind()) {
+  case FormatToken::kw_struct:
+    return parseStructDirective(tok);
+  default:
+    return emitError(tok, "unknown directive in format: " + tok.getSpelling());
+  }
+}
+
+FailureOr<FormatParser::ParseResult>
+FormatParser::parseStructDirective(FormatToken tok) {
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before `struct` argument list")))
+    return failure();
+
+  // Parse variables captured by `struct`.
+  std::vector<std::unique_ptr<Element>> vars;
+  vars.reserve(8);
+
+  tok = lexer.lexToken();
+  if (tok.getKind() == FormatToken::star) {
+    // `struct(*)` captures all parameters in the attribute or type.
+    auto params = getParameters(def);
+    for (auto *it = params.begin(); it != params.end(); ++it) {
+      auto idx = std::distance(params.begin(), it);
+      if (seenParams.test(idx)) {
+        return emitError(tok, "struct(*) captures duplicate parameter: " +
+                                  it->getName());
+      }
+      seenParams.set(idx);
+      vars.push_back(std::make_unique<VariableElement>(*it));
+    }
+    tok = lexer.lexToken();
+  } else {
+    // Parse first captured parameter.
+    auto var = parseElement(tok, StructDirectiveContext);
+    if (failed(var) || !isa<VariableElement>(var->first))
+      return emitError(tok, "directive `struct` expects at least one variable");
+
+    // Parse any other parameters.
+    vars.push_back(std::move(var->first));
+    for (tok = var->second; tok.getKind() == FormatToken::comma;
+         tok = var->second) {
+      var = parseElement(lexer.lexToken(), StructDirectiveContext);
+      if (failed(var) || !isa<VariableElement>(var->first))
+        return emitError(tok, "expected a variable in `struct` argument list");
+      vars.push_back(std::move(var->first));
+    }
+  }
+
+  if (tok.getKind() != FormatToken::r_paren)
+    return emitError(tok, "expected ')' at the end of an argument list");
+
+  return ParseResult(std::make_unique<StructDirective>(std::move(vars)),
+                     lexer.lexToken());
+}
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
+  assert(!name.empty() && "parameter has empty name");
+  auto ret = "get" + name.str();
+  ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
+  return ret;
+}
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
+                                            llvm::raw_ostream &os) {
+  llvm::SourceMgr mgr;
+  mgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
+      llvm::SMLoc());
+
+  // Parse the custom assembly format.
+  FormatParser parser(mgr, def);
+  auto format = parser.parse();
+  if (failed(format)) {
+    if (format::formatErrorIsFatal)
+      PrintFatalError(def.getLoc(), "failed to parse assembly format");
+    return;
+  }
+
+  // Generate the parser and printer.
+  format->genParser(os);
+  format->genPrinter(os);
+}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,10 +6,12 @@
 
 add_tablegen(mlir-tblgen MLIR
   AttrOrTypeDefGen.cpp
+  AttrOrTypeFormatGen.cpp
   CodeGenHelpers.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
+  FormatGen.cpp
   LLVMIRConversionGen.cpp
   LLVMIRIntrinsicGen.cpp
   mlir-tblgen.cpp
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -0,0 +1,162 @@
+//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common classes for building custom assembly format parsers
+// and generators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace llvm {
+class SourceMgr;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+/// This class represents a specific token in the input format.
+class FormatToken {
+public:
+  /// Basic token kinds.
+  enum Kind {
+    // Markers.
+    eof,
+    error,
+
+    // Tokens with no info.
+    l_paren,
+    r_paren,
+    caret,
+    colon,
+    comma,
+    equal,
+    less,
+    greater,
+    question,
+    star,
+
+    // Keywords.
+    keyword_start,
+    kw_attr_dict,
+    kw_attr_dict_w_keyword,
+    kw_custom,
+    kw_functional_type,
+    kw_operands,
+    kw_ref,
+    kw_regions,
+    kw_results,
+    kw_successors,
+    kw_type,
+    kw_struct,
+    keyword_end,
+
+    // String valued tokens.
+    identifier,
+    literal,
+    variable,
+  };
+
+  FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+  /// Return the bytes that make up this token.
+  StringRef getSpelling() const { return spelling; }
+
+  /// Return the kind of this token.
+  Kind getKind() const { return kind; }
+
+  /// Return a location for this token.
+  llvm::SMLoc getLoc() const;
+
+  /// Return if this token is a keyword.
+  bool isKeyword() const {
+    return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
+  }
+
+private:
+  /// Discriminator that indicates the kind of token this is.
+  Kind kind;
+
+  /// A reference to the entire token contents; this is always a pointer into
+  /// a memory buffer owned by the source manager.
+  StringRef spelling;
+};
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+/// This class implements a simple lexer for custom assembly format strings.
+class FormatLexer {
+public:
+  FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
+
+  /// Lex the next token and return it.
+  FormatToken lexToken();
+
+  /// Emit an error to the lexer with the given location and message.
+  FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
+  FormatToken emitError(const char *loc, const Twine &msg);
+  /// Emit an error at the given location, followed by a note at that location.
+  FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                               const Twine &note);
+
+private:
+  /// Return the next character in the stream.
+  int getNextChar();
+
+  /// Lex an identifier, literal, or variable.
+  FormatToken lexIdentifier(const char *tokStart);
+  FormatToken lexLiteral(const char *tokStart);
+  FormatToken lexVariable(const char *tokStart);
+
+  /// Create a token with the current pointer and a start pointer.
+  FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
+    return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
+  }
+
+  /// The source manager containing the format string.
+  llvm::SourceMgr &mgr;
+  /// Location of the format string.
+  llvm::SMLoc loc;
+  /// Buffer containing the format string.
+  StringRef curBuffer;
+  /// Current pointer in the buffer.
+  const char *curPtr;
+};
+
+/// Whether a space needs to be emitted before a literal. E.g., two keywords
+/// back-to-back require a space separator, but a keyword followed by '<' does
+/// not require a space.
+bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
+
+/// Returns true if the given string can be formatted as a keyword.
+bool canFormatStringAsKeyword(StringRef value);
+
+/// Returns true if the given string is a valid format literal element.
+bool isValidLiteral(StringRef value);
+
+/// Whether a failure in parsing the assembly format should be a fatal error.
+namespace format {
+extern llvm::cl::opt<bool> formatErrorIsFatal;
+} // end namespace format
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -0,0 +1,224 @@
+//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FormatGen.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+llvm::SMLoc FormatToken::getLoc() const {
+  return llvm::SMLoc::getFromPointer(spelling.data());
+}
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+    : mgr(mgr), loc(loc),
+      curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
+      curPtr(curBuffer.begin()) {}
+
+FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
+  return emitError(llvm::SMLoc::getFromPointer(loc), msg);
+}
+
+FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                                          const Twine &note) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+int FormatLexer::getNextChar() {
+  char curChar = *curPtr++;
+  switch (curChar) {
+  default:
+    return (unsigned char)curChar;
+  case 0: {
+    // A nul character in the stream is either the end of the current buffer or
+    // a random nul in the file. Disambiguate that here.
+    if (curPtr - 1 != curBuffer.end())
+      return 0;
+
+    // Otherwise, return end of file.
+    --curPtr;
+    return EOF;
+  }
+  case '\n':
+  case '\r':
+    // Normalize the newline character to '\n' and return it. However, be
+    // careful about 'dos style' files with \n\r in them.
+    // Only treat a \n\r or \r\n as a single line.
+    if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
+      ++curPtr;
+    return '\n';
+  }
+}
+
+FormatToken FormatLexer::lexToken() {
+  const char *tokStart = curPtr;
+
+  // This always consumes at least one character.
+  int curChar = getNextChar();
+  switch (curChar) {
+  default:
+    // Handle identifiers: [a-zA-Z_]
+    if (isalpha(curChar) || curChar == '_')
+      return lexIdentifier(tokStart);
+
+    // Unknown character, emit an error.
+    return emitError(tokStart, "unexpected character");
+  case EOF:
+    // Return EOF denoting the end of lexing.
+    return formToken(FormatToken::eof, tokStart);
+
+  // Lex punctuation.
+  case '^':
+    return formToken(FormatToken::caret, tokStart);
+  case ':':
+    return formToken(FormatToken::colon, tokStart);
+  case ',':
+    return formToken(FormatToken::comma, tokStart);
+  case '=':
+    return formToken(FormatToken::equal, tokStart);
+  case '<':
+    return formToken(FormatToken::less, tokStart);
+  case '>':
+    return formToken(FormatToken::greater, tokStart);
+  case '?':
+    return formToken(FormatToken::question, tokStart);
+  case '(':
+    return formToken(FormatToken::l_paren, tokStart);
+  case ')':
+    return formToken(FormatToken::r_paren, tokStart);
+  case '*':
+    return formToken(FormatToken::star, tokStart);
+
+  // Ignore whitespace characters.
+  case 0:
+  case ' ':
+  case '\t':
+  case '\n':
+    return lexToken();
+
+  case '`':
+    return lexLiteral(tokStart);
+  case '$':
+    return lexVariable(tokStart);
+  }
+}
+
+FormatToken FormatLexer::lexLiteral(const char *tokStart) {
+  assert(curPtr[-1] == '`');
+
+  // Lex a literal surrounded by ``.
+  while (const char curChar = *curPtr++) {
+    if (curChar == '`')
+      return formToken(FormatToken::literal, tokStart);
+  }
+  return emitError(curPtr - 1, "unexpected end of file in literal");
+}
+
+FormatToken FormatLexer::lexVariable(const char *tokStart) {
+  if (!isalpha(curPtr[0]) && curPtr[0] != '_')
+    return emitError(curPtr - 1, "expected variable name");
+
+  // Otherwise, consume the rest of the characters.
+  while (isalnum(*curPtr) || *curPtr == '_')
+    ++curPtr;
+  return formToken(FormatToken::variable, tokStart);
+}
+
+FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
+  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
+  while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
+    ++curPtr;
+
+  // Check to see if this identifier is a keyword.
+  StringRef str(tokStart, curPtr - tokStart);
+  auto kind =
+      StringSwitch<FormatToken::Kind>(str)
+          .Case("attr-dict", FormatToken::kw_attr_dict)
+          .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+          .Case("custom", FormatToken::kw_custom)
+          .Case("functional-type", FormatToken::kw_functional_type)
+          .Case("operands", FormatToken::kw_operands)
+          .Case("ref", FormatToken::kw_ref)
+          .Case("regions", FormatToken::kw_regions)
+          .Case("results", FormatToken::kw_results)
+          .Case("successors", FormatToken::kw_successors)
+          .Case("type", FormatToken::kw_type)
+          .Case("struct", FormatToken::kw_struct)
+          .Default(FormatToken::identifier);
+  return FormatToken(kind, str);
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
+                                         bool lastWasPunctuation) {
+  if (value.size() != 1 && value != "->")
+    return true;
+  if (lastWasPunctuation)
+    return !StringRef(">)}],").contains(value.front());
+  return !StringRef("<>(){}[],").contains(value.front());
+}
+
+bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
+  if (value.empty() || (!isalpha(value.front()) && value.front() != '_'))
+    return false;
+  return llvm::all_of(value.drop_front(), [](char c) {
+    return isalnum(c) || c == '_' || c == '$' || c == '.';
+  });
+}
+
+bool mlir::tblgen::isValidLiteral(StringRef value) {
+  if (value.empty())
+    return false;
+  char front = value.front();
+
+  // If there is only one character, this must either be punctuation or a
+  // single character bare identifier.
+  if (value.size() == 1)
+    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
+
+  // Check the punctuation that are larger than a single character.
+  if (value == "->")
+    return true;
+
+  // Otherwise, this must be an identifier.
+ return canFormatStringAsKeyword(value); +} + +//===----------------------------------------------------------------------===// +// Commandline Options +//===----------------------------------------------------------------------===// + +llvm::cl::opt mlir::tblgen::format::formatErrorIsFatal( + "asmformat-error-is-fatal", + llvm::cl::desc("Emit a fatal error if format parsing fails"), + llvm::cl::init(true)); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "OpFormatGen.h" +#include "FormatGen.h" #include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" @@ -20,7 +21,6 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -30,20 +30,6 @@ using namespace mlir; using namespace mlir::tblgen; -static llvm::cl::opt formatErrorIsFatal( - "asmformat-error-is-fatal", - llvm::cl::desc("Emit a fatal error if format parsing fails"), - llvm::cl::init(true)); - -/// Returns true if the given string can be formatted as a keyword. -static bool canFormatStringAsKeyword(StringRef value) { - if (!isalpha(value.front()) && value.front() != '_') - return false; - return llvm::all_of(value.drop_front(), [](char c) { - return isalnum(c) || c == '_' || c == '$' || c == '.'; - }); -} - //===----------------------------------------------------------------------===// // Element //===----------------------------------------------------------------------===// @@ -273,33 +259,12 @@ /// Return the literal for this element. 
StringRef getLiteral() const { return literal; } - /// Returns true if the given string is a valid literal. - static bool isValidLiteral(StringRef value); - private: /// The spelling of the literal for this element. StringRef literal; }; } // end anonymous namespace -bool LiteralElement::isValidLiteral(StringRef value) { - if (value.empty()) - return false; - char front = value.front(); - - // If there is only one character, this must either be punctuation or a - // single character bare identifier. - if (value.size() == 1) - return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front); - - // Check the punctuation that are larger than a single character. - if (value == "->") - return true; - - // Otherwise, this must be an identifier. - return canFormatStringAsKeyword(value); -} - //===----------------------------------------------------------------------===// // WhitespaceElement @@ -1676,14 +1641,7 @@ body << " p"; // Don't insert a space for certain punctuation. - auto shouldPrintSpaceBeforeLiteral = [&] { - if (value.size() != 1 && value != "->") - return true; - if (lastWasPunctuation) - return !StringRef(">)}],").contains(value.front()); - return !StringRef("<>(){}[],").contains(value.front()); - }; - if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral()) + if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation)) body << " << ' '"; body << " << \"" << value << "\";\n"; @@ -2053,253 +2011,6 @@ lastWasPunctuation); } -//===----------------------------------------------------------------------===// -// FormatLexer -//===----------------------------------------------------------------------===// - -namespace { -/// This class represents a specific token in the input format. -class Token { -public: - enum Kind { - // Markers. - eof, - error, - - // Tokens with no info. - l_paren, - r_paren, - caret, - colon, - comma, - equal, - less, - greater, - question, - - // Keywords. 
- keyword_start, - kw_attr_dict, - kw_attr_dict_w_keyword, - kw_custom, - kw_functional_type, - kw_operands, - kw_ref, - kw_regions, - kw_results, - kw_successors, - kw_type, - keyword_end, - - // String valued tokens. - identifier, - literal, - variable, - }; - Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} - - /// Return the bytes that make up this token. - StringRef getSpelling() const { return spelling; } - - /// Return the kind of this token. - Kind getKind() const { return kind; } - - /// Return a location for this token. - llvm::SMLoc getLoc() const { - return llvm::SMLoc::getFromPointer(spelling.data()); - } - - /// Return if this token is a keyword. - bool isKeyword() const { return kind > keyword_start && kind < keyword_end; } - -private: - /// Discriminator that indicates the kind of token this is. - Kind kind; - - /// A reference to the entire token contents; this is always a pointer into - /// a memory buffer owned by the source manager. - StringRef spelling; -}; - -/// This class implements a simple lexer for operation assembly format strings. -class FormatLexer { -public: - FormatLexer(llvm::SourceMgr &mgr, Operator &op); - - /// Lex the next token and return it. - Token lexToken(); - - /// Emit an error to the lexer with the given location and message. - Token emitError(llvm::SMLoc loc, const Twine &msg); - Token emitError(const char *loc, const Twine &msg); - - Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine ¬e); - -private: - Token formToken(Token::Kind kind, const char *tokStart) { - return Token(kind, StringRef(tokStart, curPtr - tokStart)); - } - - /// Return the next character in the stream. - int getNextChar(); - - /// Lex an identifier, literal, or variable. 
- Token lexIdentifier(const char *tokStart); - Token lexLiteral(const char *tokStart); - Token lexVariable(const char *tokStart); - - llvm::SourceMgr &srcMgr; - Operator &op; - StringRef curBuffer; - const char *curPtr; -}; -} // end anonymous namespace - -FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op) - : srcMgr(mgr), op(op) { - curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); - curPtr = curBuffer.begin(); -} - -Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) { - srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); - llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note, - "in custom assembly format for this operation"); - return formToken(Token::error, loc.getPointer()); -} -Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, - const Twine ¬e) { - srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); - llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note, - "in custom assembly format for this operation"); - srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note); - return formToken(Token::error, loc.getPointer()); -} -Token FormatLexer::emitError(const char *loc, const Twine &msg) { - return emitError(llvm::SMLoc::getFromPointer(loc), msg); -} - -int FormatLexer::getNextChar() { - char curChar = *curPtr++; - switch (curChar) { - default: - return (unsigned char)curChar; - case 0: { - // A nul character in the stream is either the end of the current buffer or - // a random nul in the file. Disambiguate that here. - if (curPtr - 1 != curBuffer.end()) - return 0; - - // Otherwise, return end of file. - --curPtr; - return EOF; - } - case '\n': - case '\r': - // Handle the newline character by ignoring it and incrementing the line - // count. However, be careful about 'dos style' files with \n\r in them. - // Only treat a \n\r or \r\n as a single line. 
- if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) - ++curPtr; - return '\n'; - } -} - -Token FormatLexer::lexToken() { - const char *tokStart = curPtr; - - // This always consumes at least one character. - int curChar = getNextChar(); - switch (curChar) { - default: - // Handle identifiers: [a-zA-Z_] - if (isalpha(curChar) || curChar == '_') - return lexIdentifier(tokStart); - - // Unknown character, emit an error. - return emitError(tokStart, "unexpected character"); - case EOF: - // Return EOF denoting the end of lexing. - return formToken(Token::eof, tokStart); - - // Lex punctuation. - case '^': - return formToken(Token::caret, tokStart); - case ':': - return formToken(Token::colon, tokStart); - case ',': - return formToken(Token::comma, tokStart); - case '=': - return formToken(Token::equal, tokStart); - case '<': - return formToken(Token::less, tokStart); - case '>': - return formToken(Token::greater, tokStart); - case '?': - return formToken(Token::question, tokStart); - case '(': - return formToken(Token::l_paren, tokStart); - case ')': - return formToken(Token::r_paren, tokStart); - - // Ignore whitespace characters. - case 0: - case ' ': - case '\t': - case '\n': - return lexToken(); - - case '`': - return lexLiteral(tokStart); - case '$': - return lexVariable(tokStart); - } -} - -Token FormatLexer::lexLiteral(const char *tokStart) { - assert(curPtr[-1] == '`'); - - // Lex a literal surrounded by ``. - while (const char curChar = *curPtr++) { - if (curChar == '`') - return formToken(Token::literal, tokStart); - } - return emitError(curPtr - 1, "unexpected end of file in literal"); -} - -Token FormatLexer::lexVariable(const char *tokStart) { - if (!isalpha(curPtr[0]) && curPtr[0] != '_') - return emitError(curPtr - 1, "expected variable name"); - - // Otherwise, consume the rest of the characters. 
- while (isalnum(*curPtr) || *curPtr == '_') - ++curPtr; - return formToken(Token::variable, tokStart); -} - -Token FormatLexer::lexIdentifier(const char *tokStart) { - // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* - while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') - ++curPtr; - - // Check to see if this identifier is a keyword. - StringRef str(tokStart, curPtr - tokStart); - Token::Kind kind = - StringSwitch(str) - .Case("attr-dict", Token::kw_attr_dict) - .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword) - .Case("custom", Token::kw_custom) - .Case("functional-type", Token::kw_functional_type) - .Case("operands", Token::kw_operands) - .Case("ref", Token::kw_ref) - .Case("regions", Token::kw_regions) - .Case("results", Token::kw_results) - .Case("successors", Token::kw_successors) - .Case("type", Token::kw_type) - .Default(Token::identifier); - return Token(kind, str); -} - //===----------------------------------------------------------------------===// // FormatParser //===----------------------------------------------------------------------===// @@ -2318,8 +2029,8 @@ class FormatParser { public: FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op) - : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op), - seenOperandTypes(op.getNumOperands()), + : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format), + op(op), seenOperandTypes(op.getNumOperands()), seenResultTypes(op.getNumResults()) {} /// Parse the operation assembly format. 
@@ -2421,7 +2132,8 @@ LogicalResult parseCustomDirectiveParameter( std::vector> ¶meters); LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, - Token tok, ParserContext context); + FormatToken tok, + ParserContext context); LogicalResult parseOperandsDirective(std::unique_ptr &element, llvm::SMLoc loc, ParserContext context); LogicalResult parseReferenceDirective(std::unique_ptr &element, @@ -2433,8 +2145,8 @@ LogicalResult parseSuccessorsDirective(std::unique_ptr &element, llvm::SMLoc loc, ParserContext context); - LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, - ParserContext context); + LogicalResult parseTypeDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context); LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, bool isRefChild = false); @@ -2444,12 +2156,12 @@ /// Advance the current lexer onto the next token. void consumeToken() { - assert(curToken.getKind() != Token::eof && - curToken.getKind() != Token::error && + assert(curToken.getKind() != FormatToken::eof && + curToken.getKind() != FormatToken::error && "shouldn't advance past EOF or errors"); curToken = lexer.lexToken(); } - LogicalResult parseToken(Token::Kind kind, const Twine &msg) { + LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) { if (curToken.getKind() != kind) return emitError(curToken.getLoc(), msg); consumeToken(); @@ -2470,7 +2182,7 @@ //===--------------------------------------------------------------------===// FormatLexer lexer; - Token curToken; + FormatToken curToken; OperationFormat &fmt; Operator &op; @@ -2490,7 +2202,7 @@ llvm::SMLoc loc = curToken.getLoc(); // Parse each of the format elements into the main format. 
- while (curToken.getKind() != Token::eof) { + while (curToken.getKind() != FormatToken::eof) { std::unique_ptr element; if (failed(parseElement(element, TopLevelContext))) return ::mlir::failure(); @@ -2804,13 +2516,13 @@ if (curToken.isKeyword()) return parseDirective(element, context); // Literals. - if (curToken.getKind() == Token::literal) + if (curToken.getKind() == FormatToken::literal) return parseLiteral(element, context); // Optionals. - if (curToken.getKind() == Token::l_paren) + if (curToken.getKind() == FormatToken::l_paren) return parseOptional(element, context); // Variables. - if (curToken.getKind() == Token::variable) + if (curToken.getKind() == FormatToken::variable) return parseVariable(element, context); return emitError(curToken.getLoc(), "expected directive, literal, variable, or optional group"); @@ -2818,7 +2530,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr &element, ParserContext context) { - Token varTok = curToken; + FormatToken varTok = curToken; consumeToken(); StringRef name = varTok.getSpelling().drop_front(); @@ -2898,31 +2610,31 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr &element, ParserContext context) { - Token dirTok = curToken; + FormatToken dirTok = curToken; consumeToken(); switch (dirTok.getKind()) { - case Token::kw_attr_dict: + case FormatToken::kw_attr_dict: return parseAttrDictDirective(element, dirTok.getLoc(), context, /*withKeyword=*/false); - case Token::kw_attr_dict_w_keyword: + case FormatToken::kw_attr_dict_w_keyword: return parseAttrDictDirective(element, dirTok.getLoc(), context, /*withKeyword=*/true); - case Token::kw_custom: + case FormatToken::kw_custom: return parseCustomDirective(element, dirTok.getLoc(), context); - case Token::kw_functional_type: + case FormatToken::kw_functional_type: return parseFunctionalTypeDirective(element, dirTok, context); - case Token::kw_operands: + case FormatToken::kw_operands: return parseOperandsDirective(element, dirTok.getLoc(), context); 
- case Token::kw_regions: + case FormatToken::kw_regions: return parseRegionsDirective(element, dirTok.getLoc(), context); - case Token::kw_results: + case FormatToken::kw_results: return parseResultsDirective(element, dirTok.getLoc(), context); - case Token::kw_successors: + case FormatToken::kw_successors: return parseSuccessorsDirective(element, dirTok.getLoc(), context); - case Token::kw_ref: + case FormatToken::kw_ref: return parseReferenceDirective(element, dirTok.getLoc(), context); - case Token::kw_type: + case FormatToken::kw_type: return parseTypeDirective(element, dirTok, context); default: @@ -2932,7 +2644,7 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr &element, ParserContext context) { - Token literalTok = curToken; + FormatToken literalTok = curToken; if (context != TopLevelContext) { return emitError( literalTok.getLoc(), @@ -2954,7 +2666,7 @@ } // Check that the parsed literal is valid. - if (!LiteralElement::isValidLiteral(value)) + if (!isValidLiteral(value)) return emitError(literalTok.getLoc(), "expected valid literal"); element = std::make_unique(value); @@ -2975,14 +2687,15 @@ do { if (failed(parseOptionalChildElement(thenElements, anchorIdx))) return ::mlir::failure(); - } while (curToken.getKind() != Token::r_paren); + } while (curToken.getKind() != FormatToken::r_paren); consumeToken(); // Parse the `else` elements of this optional group. 
- if (curToken.getKind() == Token::colon) { + if (curToken.getKind() == FormatToken::colon) { consumeToken(); - if (failed(parseToken(Token::l_paren, "expected '(' to start else branch " - "of optional group"))) + if (failed(parseToken(FormatToken::l_paren, + "expected '(' to start else branch " + "of optional group"))) return failure(); do { llvm::SMLoc childLoc = curToken.getLoc(); @@ -2991,11 +2704,12 @@ failed(verifyOptionalChildElement(elseElements.back().get(), childLoc, /*isAnchor=*/false))) return failure(); - } while (curToken.getKind() != Token::r_paren); + } while (curToken.getKind() != FormatToken::r_paren); consumeToken(); } - if (failed(parseToken(Token::question, "expected '?' after optional group"))) + if (failed(parseToken(FormatToken::question, + "expected '?' after optional group"))) return ::mlir::failure(); // The optional group is required to have an anchor. @@ -3030,7 +2744,7 @@ return ::mlir::failure(); // Check to see if this element is the anchor of the optional group. - bool isAnchor = curToken.getKind() == Token::caret; + bool isAnchor = curToken.getKind() == FormatToken::caret; if (isAnchor) { if (anchorIdx) return emitError(childLoc, "only one element can be marked as the anchor " @@ -3134,16 +2848,16 @@ return emitError(loc, "'custom' is only valid as a top-level directive"); // Parse the custom directive name. 
- if (failed( - parseToken(Token::less, "expected '<' before custom directive name"))) + if (failed(parseToken(FormatToken::less, + "expected '<' before custom directive name"))) return ::mlir::failure(); - Token nameTok = curToken; - if (failed(parseToken(Token::identifier, + FormatToken nameTok = curToken; + if (failed(parseToken(FormatToken::identifier, "expected custom directive name identifier")) || - failed(parseToken(Token::greater, + failed(parseToken(FormatToken::greater, "expected '>' after custom directive name")) || - failed(parseToken(Token::l_paren, + failed(parseToken(FormatToken::l_paren, "expected '(' before custom directive parameters"))) return ::mlir::failure(); @@ -3152,12 +2866,12 @@ do { if (failed(parseCustomDirectiveParameter(elements))) return ::mlir::failure(); - if (curToken.getKind() != Token::comma) + if (curToken.getKind() != FormatToken::comma) break; consumeToken(); } while (true); - if (failed(parseToken(Token::r_paren, + if (failed(parseToken(FormatToken::r_paren, "expected ')' after custom directive parameters"))) return ::mlir::failure(); @@ -3194,9 +2908,8 @@ return ::mlir::success(); } -LogicalResult -FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element, - Token tok, ParserContext context) { +LogicalResult FormatParser::parseFunctionalTypeDirective( + std::unique_ptr &element, FormatToken tok, ParserContext context) { llvm::SMLoc loc = tok.getLoc(); if (context != TopLevelContext) return emitError( @@ -3204,11 +2917,14 @@ // Parse the main operand. 
std::unique_ptr inputs, results; - if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || failed(parseTypeDirectiveOperand(inputs)) || - failed(parseToken(Token::comma, "expected ',' after inputs argument")) || + failed(parseToken(FormatToken::comma, + "expected ',' after inputs argument")) || failed(parseTypeDirectiveOperand(results)) || - failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); element = std::make_unique(std::move(inputs), std::move(results)); @@ -3239,9 +2955,11 @@ return emitError(loc, "'ref' is only valid within a `custom` directive"); std::unique_ptr operand; - if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || failed(parseElement(operand, RefDirectiveContext)) || - failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); element = std::make_unique(std::move(operand)); @@ -3300,17 +3018,19 @@ } LogicalResult -FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, - ParserContext context) { +FormatParser::parseTypeDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context) { llvm::SMLoc loc = tok.getLoc(); if (context == TypeDirectiveContext) return emitError(loc, "'type' cannot be used as a child of another `type`"); bool isRefChild = context == RefDirectiveContext; std::unique_ptr operand; - if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || failed(parseTypeDirectiveOperand(operand, isRefChild)) || - 
failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); element = std::make_unique(std::move(operand)); @@ -3383,7 +3103,7 @@ OperationFormat format(op); if (failed(FormatParser(mgr, format, op).parse())) { // Exit the process if format errors are treated as fatal. - if (formatErrorIsFatal) { + if (format::formatErrorIsFatal) { // Invoke the interrupt handlers to run the file cleanup handlers. llvm::sys::RunInterruptHandlers(); std::exit(1); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -219,6 +219,7 @@ "//mlir:SideEffects", "//mlir:StandardOps", "//mlir:StandardOpsTransforms", + "//mlir:Support", "//mlir:TensorDialect", "//mlir:TransformUtils", "//mlir:Transforms",