[mlir] StructsGen: support DefaultValuedAttr struct fields

Struct fields whose attribute has a default value are now handled like
optional fields: they may be absent (null) when the struct is built and
when `classof` checks a DictionaryAttr. The `classof` counter is renamed
from `empty_optionals` to `num_absent_attrs` because it now counts absent
default-valued fields as well as absent optional ones.

The generated accessor for a default-valued field returns the stored
attribute when present; otherwise it materializes the default through the
attribute's constant builder template, with `$_builder` bound to a local
`::mlir::Builder` constructed from the attribute's context.

The unit-test struct gains `sample_default_valued_integer`
(DefaultValuedAttr<I32Attr, "42">). A new test checks that the accessor
returns 42 when the field was not provided.

diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -143,7 +143,7 @@
 )";
 
   for (auto field : fields) {
-    if (field.getType().isOptional())
+    if (field.getType().isOptional() || field.getType().hasDefaultValue())
       os << llvm::formatv(getFieldInfoOptional, field.getName());
     else
       os << llvm::formatv(getFieldInfo, field.getName());
@@ -169,7 +169,7 @@
   auto derived = attr.dyn_cast<::mlir::DictionaryAttr>();
   if (!derived)
     return false;
-  int empty_optionals = 0;
+  int num_absent_attrs = 0;
 )";
 
   os << llvm::formatv(classofInfo, structName) << " {";
@@ -184,7 +184,7 @@
   const char *classofArgInfoOptional = R"(
   auto {0} = derived.get("{0}");
   if (!{0})
-    ++empty_optionals;
+    ++num_absent_attrs;
   else if (!({1}))
     return false;
 )";
@@ -193,14 +193,14 @@
     auto type = field.getType();
     std::string condition =
         std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
-    if (type.isOptional())
+    if (type.isOptional() || type.hasDefaultValue())
       os << llvm::formatv(classofArgInfoOptional, name, condition);
     else
       os << llvm::formatv(classofArgInfo, name, condition);
   }
 
   const char *classofEndInfo = R"(
-  return derived.size() + empty_optionals == {0};
+  return derived.size() + num_absent_attrs == {0};
 }
 )";
   os << llvm::formatv(classofEndInfo, fields.size());
@@ -229,14 +229,35 @@
   return {1}.cast<{0}>();
 }
 )";
+  const char *fieldInfoDefaultValued = R"(
+{0} {2}::{1}() const {
+  auto derived = this->cast<::mlir::DictionaryAttr>();
+  auto {1} = derived.get("{1}");
+  if (!{1}) {
+    ::mlir::Builder builder(getContext());
+    return {3};
+  }
+  assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+  return {1}.cast<{0}>();
+}
+)";
+  FmtContext fmtCtx;
+  fmtCtx.withBuilder("builder");
+
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     auto storage = type.getStorageType();
-    if (type.isOptional())
+    if (type.isOptional()) {
       os << llvm::formatv(fieldInfoOptional, storage, name, structName);
-    else
+    } else if (type.hasDefaultValue()) {
+      std::string defaultValue = tgfmt(type.getConstBuilderTemplate(), &fmtCtx,
+                                       type.getDefaultValue());
+      os << llvm::formatv(fieldInfoDefaultValued, storage, name, structName,
+                          defaultValue);
+    } else {
       os << llvm::formatv(fieldInfo, storage, name, structName);
+    }
   }
 }
 
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/DenseMap.h"
@@ -34,9 +35,10 @@
   auto elementsAttr =
       mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
   auto optionalAttr = nullptr;
+  auto defaultValuedAttr = nullptr;
 
   return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
-                               optionalAttr, context);
+                               optionalAttr, defaultValuedAttr, context);
 }
 
 /// Validates that test::TestStruct::classof correctly identifies a valid
@@ -167,4 +169,12 @@
   EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
 }
 
+TEST(StructsGenTest, GetDefaultValuedAttr) {
+  mlir::MLIRContext context;
+  mlir::Builder builder(&context);
+  auto structAttr = getTestStruct(&context);
+  EXPECT_EQ(structAttr.sample_default_valued_integer(),
+            builder.getI32IntegerAttr(42));
+}
+
 } // namespace mlir
diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td
--- a/mlir/unittests/TableGen/structs.td
+++ b/mlir/unittests/TableGen/structs.td
@@ -17,6 +17,8 @@
                 StructFieldAttr<"sample_float", F32Attr>,
                 StructFieldAttr<"sample_elements", I32ElementsAttr>,
                 StructFieldAttr<"sample_optional_integer",
-                        OptionalAttr<I32Attr>>] > {
+                        OptionalAttr<I32Attr>>,
+                StructFieldAttr<"sample_default_valued_integer",
+                        DefaultValuedAttr<I32Attr, "42">>] > {
   let description = "Structure for test data";
 }