diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1283,10 +1283,8 @@ Some attributes can only take values from a predefined enum, e.g., the comparison kind of a comparison op. To define such attributes, ODS provides -several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`. +two mechanisms: `IntEnumAttr` and `BitEnumAttr`. -* `StrEnumAttr`: each enum case is a string, the attribute is stored as a - [`StringAttr`][StringAttr] in the op. * `IntEnumAttr`: each enum case is an integer, the attribute is stored as a [`IntegerAttr`][IntegerAttr] in the op. * `BitEnumAttr`: each enum case is a either the empty case, a single bit, diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1230,13 +1230,6 @@ string str = strVal; } -// An enum attribute case stored with StringAttr. -class StrEnumAttrCase : - EnumAttrCaseInfo, - StringBasedAttr< - CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">, - "case " # str>; - // An enum attribute case stored with IntegerAttr, which has an integer value, // its representation as a string and a C++ symbol name which may be different. class IntEnumAttrCaseBase : @@ -1393,22 +1386,6 @@ let valueType = baseAttrClass.valueType; } -// An enum attribute backed by StringAttr. -// -// Op attributes of this kind are stored as StringAttr. Extra verification will -// be generated on the string though: only the symbols of the allowed cases are -// permitted as the string value. -class StrEnumAttr cases> : - EnumAttrInfo]>, - !if(!empty(summary), "allowed string cases: " # - !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "), - summary)>> { - // Disable specialized Attribute class for `StringAttr` backend by default. - let genSpecializedAttr = 0; -} - // An enum attribute backed by IntegerAttr.
// // Op attributes of this kind are stored as IntegerAttr. Extra verification will diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -144,9 +144,6 @@ explicit EnumAttrCase(const llvm::Record *record); explicit EnumAttrCase(const llvm::DefInit *init); - // Returns true if this EnumAttrCase is backed by a StringAttr. - bool isStrCase() const; - // Returns the symbol of this enum attribute case. StringRef getSymbol() const; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -157,8 +157,6 @@ EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) : EnumAttrCase(init->getDef()) {} -bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } - StringRef EnumAttrCase::getSymbol() const { return def->getValueAsString("symbol"); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -345,29 +345,6 @@ // ----- -//===----------------------------------------------------------------------===// -// Test StrEnumAttr -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @allowed_cases_pass -func @allowed_cases_pass() { - // CHECK: test.str_enum_attr - %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 - // CHECK: test.str_enum_attr - %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32 - return -} - -// ----- - -func @disallowed_case_fail() { - // expected-error @+1 {{allowed string cases: 'A', 'B'}} - %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32 - return -} - -// ----- - //===----------------------------------------------------------------------===// // Test I32EnumAttr //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -191,17 +191,6 @@ let assemblyFormat = "$attr attr-dict"; } -def StrCaseA: StrEnumAttrCase<"A">; -def StrCaseB: StrEnumAttrCase<"B">; - -def SomeStrEnum: StrEnumAttr< - "SomeStrEnum", "", [StrCaseA, StrCaseB]>; - -def StrEnumAttrOp : TEST_Op<"str_enum_attr"> { - let arguments = (ins SomeStrEnum:$attr); - let results = (outs I32:$val); -} - def I32Case5: I32EnumAttrCase<"case5", 5>; def I32Case10: I32EnumAttrCase<"case10", 10>; @@ -1260,8 +1249,6 @@ def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>; def : Pat<(OpC $input), (OpB $input, ConstantAttr:$attr)>; -// Test string enum attribute in rewrites. -def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>; // Test integer enum attribute in rewrites. def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; @@ -1568,11 +1555,8 @@ // Test Legalization //===----------------------------------------------------------------------===// -def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">; -def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">; - -def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure", - [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>; +def Test_LegalizerEnum_Success : ConstantStrAttr; +def Test_LegalizerEnum_Failure : ConstantStrAttr; def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>; def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>; @@ -1582,7 +1566,7 @@ def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, - Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; + Arguments<(ins StrAttr:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs 
I32)>; def LegalOpC : TEST_Op<"legal_op_c">, Arguments<(ins I32)>, Results<(outs I32)>; diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -356,13 +356,6 @@ // Test Enum Attributes //===----------------------------------------------------------------------===// -// CHECK-LABEL: verifyStrEnumAttr -func @verifyStrEnumAttr() -> i32 { - // CHECK: "test.str_enum_attr"() {attr = "B"} - %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 - return %0 : i32 -} - // CHECK-LABEL: verifyI32EnumAttr func @verifyI32EnumAttr() -> i32 { // CHECK: "test.i32_enum_attr"() {attr = 10 : i32} diff --git a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp --- a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp +++ b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp @@ -35,7 +35,7 @@ // declarations, functions etc. // // Some OpenMP/OpenACC clauses accept only a fixed set of values as inputs. -// These can be represented as a String Enum Attribute (StrEnumAttr) in MLIR +// These can be represented as Enum Attributes (EnumAttrDef) in MLIR // ODS. The emitDecls function below currently generates these enumerations. The // name of the enumeration is specified in the enumClauseValue field of // Clause record in OMP.td.
This name can be used to specify the type of the diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -314,8 +314,6 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); Attribute baseAttr(baseAttrDef); @@ -341,28 +339,22 @@ os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n", attrClassName, enumName); - if (enumAttr.isSubClassOf("StrEnumAttr")) { - os << formatv(" ::mlir::StringAttr baseAttr = " - "::mlir::StringAttr::get(context, {0}(val));\n", - symToStrFnName); - } else { - StringRef underlyingType = enumAttr.getUnderlyingType(); - - // Assuming that it is IntegerAttr constraint - int64_t bitwidth = 64; - if (baseAttrDef->getValue("valueType")) { - auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType"); - if (valueTypeDef->getValue("bitwidth")) - bitwidth = valueTypeDef->getValueAsInt("bitwidth"); - } + StringRef underlyingType = enumAttr.getUnderlyingType(); - os << formatv(" ::mlir::IntegerType intType = " - "::mlir::IntegerType::get(context, {0});\n", - bitwidth); - os << formatv(" ::mlir::IntegerAttr baseAttr = " - "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n", - underlyingType); + // Assuming that it is IntegerAttr constraint + int64_t bitwidth = 64; + if (baseAttrDef->getValue("valueType")) { + auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType"); + if (valueTypeDef->getValue("bitwidth")) + bitwidth = valueTypeDef->getValueAsInt("bitwidth"); } + + os << formatv(" ::mlir::IntegerType intType = " + "::mlir::IntegerType::get(context, 
{0});\n", + bitwidth); + os << formatv(" ::mlir::IntegerAttr baseAttr = " + "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n", + underlyingType); os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName); os << "}\n"; @@ -371,14 +363,8 @@ os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName); - if (enumAttr.isSubClassOf("StrEnumAttr")) { - os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n", - strToSymFnName); - os << " return res.getValue();\n"; - } else { - os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", - enumName); - } + os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", + enumName); os << "}\n"; } @@ -483,8 +469,7 @@ )"; if (enumAttr.genSpecializedAttr()) { StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); - StringRef baseAttrClassName = - enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr"; + StringRef baseAttrClassName = "IntegerAttr"; os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1797,30 +1797,11 @@ // Get a string containing all of the cases that can't be represented with a // keyword. BitVector nonKeywordCases(cases.size()); - bool hasStrCase = false; for (auto &it : llvm::enumerate(cases)) { - hasStrCase = it.value().isStrCase(); if (!canFormatStringAsKeyword(it.value().getStr())) nonKeywordCases.set(it.index()); } - // If this is a string enum, use the case string to determine which cases - // need to use the string form. 
- if (hasStrCase) { - if (nonKeywordCases.any()) { - body << " if (llvm::is_contained(llvm::ArrayRef("; - llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) { - body << '"' << cases[it].getStr() << '"'; - }); - body << ")))\n" - " _odsPrinter << '\"' << caseValueStr << '\"';\n" - " else\n "; - } - body << " _odsPrinter << caseValueStr;\n" - " }\n"; - return; - } - // Otherwise if this is a bit enum attribute, don't allow cases that may // overlap with other cases. For simplicity sake, only allow cases with a // single bit value. diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1221,8 +1221,6 @@ } if (leaf.isEnumAttrCase()) { auto enumCase = leaf.getAsEnumAttrCase(); - if (enumCase.isStrCase()) - return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\""); // This is an enum case backed by an IntegerAttr. We need to get its value // to build the constant. std::string val = std::to_string(enumCase.getValue()); diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -27,12 +27,12 @@ /// Test namespaces and enum class/utility names. 
using Outer::Inner::ConvertToEnum; using Outer::Inner::ConvertToString; -using Outer::Inner::StrEnum; -using Outer::Inner::StrEnumAttr; +using Outer::Inner::FooEnum; +using Outer::Inner::FooEnumAttr; TEST(EnumsGenTest, GeneratedStrEnumDefinition) { - EXPECT_EQ(0u, static_cast(StrEnum::CaseA)); - EXPECT_EQ(10u, static_cast(StrEnum::CaseB)); + EXPECT_EQ(0u, static_cast(FooEnum::CaseA)); + EXPECT_EQ(1u, static_cast(FooEnum::CaseB)); } TEST(EnumsGenTest, GeneratedI32EnumDefinition) { @@ -41,23 +41,23 @@ } TEST(EnumsGenTest, GeneratedDenseMapInfo) { - llvm::DenseMap myMap; + llvm::DenseMap myMap; - myMap[StrEnum::CaseA] = "zero"; - myMap[StrEnum::CaseB] = "one"; + myMap[FooEnum::CaseA] = "zero"; + myMap[FooEnum::CaseB] = "one"; - EXPECT_EQ(myMap[StrEnum::CaseA], "zero"); - EXPECT_EQ(myMap[StrEnum::CaseB], "one"); + EXPECT_EQ(myMap[FooEnum::CaseA], "zero"); + EXPECT_EQ(myMap[FooEnum::CaseB], "one"); } TEST(EnumsGenTest, GeneratedSymbolToStringFn) { - EXPECT_EQ(ConvertToString(StrEnum::CaseA), "CaseA"); - EXPECT_EQ(ConvertToString(StrEnum::CaseB), "CaseB"); + EXPECT_EQ(ConvertToString(FooEnum::CaseA), "CaseA"); + EXPECT_EQ(ConvertToString(FooEnum::CaseB), "CaseB"); } TEST(EnumsGenTest, GeneratedStringToSymbolFn) { - EXPECT_EQ(llvm::Optional(StrEnum::CaseA), ConvertToEnum("CaseA")); - EXPECT_EQ(llvm::Optional(StrEnum::CaseB), ConvertToEnum("CaseB")); + EXPECT_EQ(llvm::Optional(FooEnum::CaseA), ConvertToEnum("CaseA")); + EXPECT_EQ(llvm::Optional(FooEnum::CaseB), ConvertToEnum("CaseB")); EXPECT_EQ(llvm::None, ConvertToEnum("X")); } @@ -155,19 +155,6 @@ EXPECT_EQ(intAttr, enumAttr); } -TEST(EnumsGenTest, GeneratedStringAttributeClass) { - mlir::MLIRContext ctx; - StrEnum rawVal = StrEnum::CaseA; - - StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal); - EXPECT_NE(enumAttr, nullptr); - EXPECT_EQ(enumAttr.getValue(), rawVal); - - mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA"); - EXPECT_TRUE(strAttr.isa()); - EXPECT_EQ(strAttr, enumAttr); -} - 
TEST(EnumsGenTest, GeneratedBitAttributeClass) { mlir::MLIRContext ctx; diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -8,10 +8,10 @@ include "mlir/IR/OpBase.td" -def CaseA: StrEnumAttrCase<"CaseA">; -def CaseB: StrEnumAttrCase<"CaseB", 10>; +def CaseA: I32EnumAttrCase<"CaseA", 0>; +def CaseB: I32EnumAttrCase<"CaseB", 1>; -def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> { +def FooEnum: I32EnumAttr<"FooEnum", "A test enum", [CaseA, CaseB]> { let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString";