Add optional-field support to StructAttr.

Fields declared with OptionalAttr<...> are left out of the underlying
DictionaryAttr when they are null. The generated classof() tolerates their
absence while still rejecting unknown extra entries. The generated getters
return a null attribute for absent optional fields.

diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -135,8 +135,19 @@
   fields.emplace_back({0}_id, {0});
 )";
 
+  // Optional fields are only added to the dictionary when they are non-null.
+  const char *getFieldInfoOptional = R"(
+  if ({0}) {
+    auto {0}_id = mlir::Identifier::get("{0}", context);
+    fields.emplace_back({0}_id, {0});
+  }
+)";
+
   for (auto field : fields) {
-    os << llvm::formatv(getFieldInfo, field.getName());
+    if (field.getType().isOptional())
+      os << llvm::formatv(getFieldInfoOptional, field.getName());
+    else
+      os << llvm::formatv(getFieldInfo, field.getName());
   }
 
   const char *getEndInfo = R"(
@@ -154,15 +165,16 @@
 bool {0}::classof(mlir::Attribute attr))";
 
   const char *classofInfoHeader = R"(
-  auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
-  if (!derived)
-    return false;
-  if (derived.size() != {0})
-    return false;
+  if (!attr)
+    return false;
+  auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
+  if (!derived)
+    return false;
+  int empty_optionals = 0;
 )";
 
   os << llvm::formatv(classofInfo, structName) << " {";
-  os << llvm::formatv(classofInfoHeader, fields.size());
+  os << classofInfoHeader;
 
   FmtContext fctx;
   const char *classofArgInfo = R"(
@@ -170,19 +182,30 @@
   if (!{0} || !({1}))
     return false;
 )";
+  // Absent optional fields are counted so the size check still rejects extras.
+  const char *classofArgInfoOptional = R"(
+  auto {0} = derived.get("{0}");
+  if (!{0})
+    ++empty_optionals;
+  else if (!({1}))
+    return false;
+)";
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     std::string condition =
         std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
-    os << llvm::formatv(classofArgInfo, name, condition);
+    if (type.isOptional())
+      os << llvm::formatv(classofArgInfoOptional, name, condition);
+    else
+      os << llvm::formatv(classofArgInfo, name, condition);
   }
 
   const char *classofEndInfo = R"(
-  return true;
+  return derived.size() + empty_optionals == {0};
 }
 )";
-  os << classofEndInfo;
+  os << llvm::formatv(classofEndInfo, fields.size());
 }
 
 static void
@@ -198,11 +221,25 @@
   return {1}.cast<{0}>();
 }
 )";
+  // Optional getters return a null attribute when the field is absent.
+  const char *fieldInfoOptional = R"(
+{0} {2}::{1}() const {
+  auto derived = this->cast<mlir::DictionaryAttr>();
+  auto {1} = derived.get("{1}");
+  if (!{1})
+    return nullptr;
+  assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+  return {1}.cast<{0}>();
+}
+)";
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     auto storage = type.getStorageType();
-    os << llvm::formatv(fieldInfo, storage, name, structName);
+    if (type.isOptional())
+      os << llvm::formatv(fieldInfoOptional, storage, name, structName);
+    else
+      os << llvm::formatv(fieldInfo, storage, name, structName);
   }
 }
 
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -33,8 +33,10 @@
   auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
   auto elementsAttr =
       mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
+  mlir::IntegerAttr optionalAttr; // Null: the optional field is left unset.
 
-  return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
+  return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
+                               optionalAttr, context);
 }
 
 // Validates that test::TestStruct::classof correctly identifies a valid
@@ -159,4 +161,13 @@
   }
 }
 
+// Validates that a struct whose optional field is absent is still accepted by
+// classof and reports a null attribute from the optional getter.
+TEST(StructsGenTest, EmptyOptional) {
+  mlir::MLIRContext context;
+  auto structAttr = getTestStruct(&context);
+  EXPECT_TRUE(test::TestStruct::classof(structAttr));
+  EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
+}
+
 } // namespace mlir
diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td
--- a/mlir/unittests/TableGen/structs.td
+++ b/mlir/unittests/TableGen/structs.td
@@ -15,6 +15,8 @@
 def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
                 StructFieldAttr<"sample_integer", I32Attr>,
                 StructFieldAttr<"sample_float", F32Attr>,
-                StructFieldAttr<"sample_elements", I32ElementsAttr>] > {
+                StructFieldAttr<"sample_elements", I32ElementsAttr>,
+                StructFieldAttr<"sample_optional_integer",
+                                OptionalAttr<I32Attr>>] > {
   let description = "Structure for test data";
 }