diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -182,6 +182,14 @@ cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)", baseAttrClass.constBuilderCall); let valueType = baseAttrClass.valueType; + + // C++ type wrapped by the attribute. + string cppType = cppNamespace # "::" # className; + + // Parser and printer code used by the EnumParameter class, to be provided by + // derived classes. + string parameterParser = ?; + string parameterPrinter = ?; } // An enum attribute backed by IntegerAttr. @@ -202,7 +210,25 @@ IntEnumAttrBase>; + summary)>> { + // Parse a keyword and pass it to `stringToSymbol`. Emit an error if the + // symbol is not valid. + let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> { + auto loc = $_parser.getCurrentLocation(); + ::llvm::StringRef enumKeyword; + if (::mlir::failed($_parser.parseKeyword(&enumKeyword))) + return ::mlir::failure(); + auto maybeEnum = }] # cppNamespace # "::" # + stringToSymbolFnName # [{(enumKeyword); + if (maybeEnum) + return *maybeEnum; + return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] # + cppType # [{ to be one of: }] # + !interleave(!foreach(enum, enumerants, enum.str), ", ") # [{")}; + }()}]; + // Print the enum by calling `symbolToString`. + let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)"; +} class I32EnumAttr cases> : IntEnumAttr { @@ -243,6 +269,36 @@ // The delimiter used to separate bit enum cases in strings. string separator = "|"; + + // Parsing function that corresponds to the enum separator. Only + // "," and "|" are supported by this definition. + string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar", + "parseOptionalComma"); + + // Parse one or more keywords separated by `separator`, OR-ing together the + // results of `stringToSymbol`. Emit an error at any keyword that is not valid.
+ let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> { + }] # cppType # [{ flags = {}; + ::llvm::StringRef enumKeyword; + do { + auto loc = $_parser.getCurrentLocation(); + if (::mlir::failed($_parser.parseKeyword(&enumKeyword))) + return ::mlir::failure(); + auto maybeEnum = }] # cppNamespace # "::" # + stringToSymbolFnName # [{(enumKeyword); + if (!maybeEnum) { + return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] # + cppType # [{ to be one of: }] # + !interleave(!foreach(enum, enumerants, enum.str), + ", ") # [{")}; + } + flags = flags | *maybeEnum; + } while(::mlir::succeeded($_parser.}] # parseSeparatorFn # [{())); + return flags; + }()}]; + // Print the enum by calling `symbolToString`. + let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)"; + } class I32BitEnumAttr : AttrParameter { - // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the - // symbol is not valid. - let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> { - auto loc = $_parser.getCurrentLocation(); - ::llvm::StringRef enumKeyword; - if (::mlir::failed($_parser.parseKeyword(&enumKeyword))) - return ::mlir::failure(); - auto maybeEnum = }] # enumInfo.cppNamespace # "::" # - enumInfo.stringToSymbolFnName # [{(enumKeyword); - if (maybeEnum) - return *maybeEnum; - return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] # - cppType # [{ to be one of: }] # - !interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")}; - }()}]; - // Print the enum by calling `symbolToString`. - let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)"; + let parser = enumInfo.parameterParser; + let printer = enumInfo.parameterPrinter; } // An attribute backed by a C++ enum.
The attribute contains a single diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -464,6 +464,12 @@ /// Parse a '*' token if present. virtual ParseResult parseOptionalStar() = 0; + /// Parse a '|' token. + virtual ParseResult parseVerticalBar() = 0; + + /// Parse a '|' token if present. + virtual ParseResult parseOptionalVerticalBar() = 0; + /// Parse a quoted string token. ParseResult parseString(std::string *string) { auto loc = getCurrentLocation(); diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -221,6 +221,16 @@ return success(parser.consumeIf(Token::plus)); } + /// Parse a '|' token. + ParseResult parseVerticalBar() override { + return parser.parseToken(Token::vertical_bar, "expected '|'"); + } + + /// Parse a '|' token if present. + ParseResult parseOptionalVerticalBar() override { + return success(parser.consumeIf(Token::vertical_bar)); + } + /// Parses a quoted string token if present. ParseResult parseOptionalString(std::string *string) override { if (!parser.getToken().is(Token::string)) diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -70,6 +70,7 @@ TOK_PUNCTUATION(r_paren, ")") TOK_PUNCTUATION(r_square, "]") TOK_PUNCTUATION(star, "*") +TOK_PUNCTUATION(vertical_bar, "|") // Keywords. These turn "foo" into Token::kw_foo enums.
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -407,6 +407,29 @@ // ----- +//===----------------------------------------------------------------------===// +// Test BitEnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.op_with_bit_enum + "test.op_with_bit_enum"() {value = #test.bit_enum} : () -> () + // CHECK: test.op_with_bit_enum + test.op_with_bit_enum + return +} + +// ----- + +func @disallowed_case_sticky_fail() { + // expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}} + // expected-error@+1 {{failed to parse TestBitEnumAttr}} + "test.op_with_bit_enum"() {value = #test.bit_enum} : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // Test FloatElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -22,6 +22,7 @@ #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" // Include this before the using namespace lines below to diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -310,6 +310,33 @@ "::test::TestEnum::Second">, ConstantAttr)>; +//===----------------------------------------------------------------------===// +// Test Bit Enum Attributes +//===----------------------------------------------------------------------===// + +// Define 
the C++ bit enum, using "," (instead of the default "|") as the case separator. +def TestBitEnum + : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [ + I32BitEnumAttrCaseBit<"Read", 0, "read">, + I32BitEnumAttrCaseBit<"Write", 1, "write">, + I32BitEnumAttrCaseBit<"Execute", 2, "execute">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; + let separator = ","; +} + +// Define the bit enum attribute, whose value is written between `<` and `>`. +def TestBitEnumAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// Define an op that contains the bit enum attribute and an optional tag. +def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> { + let arguments = (ins TestBitEnumAttr:$value, OptionalAttr:$tag); + let assemblyFormat = "$value (`tag` $tag^)? attr-dict"; +} + //===----------------------------------------------------------------------===// // Test Attribute Constraints //===----------------------------------------------------------------------===//