[mlir] StructsGen: check struct fields against their attribute constraint

The generated classof() previously accepted a field whenever its value
satisfied isa<StorageType>() (the {1} argument was type.getStorageType()).
It now emits the field type's constraint condition instead: the string
from type.getConditionTemplate() is expanded with tgfmt using a
FmtContext whose self is the field's local variable (fctx.withSelf(name)),
and the generated check becomes `if (!{0} || !({1})) return false;`.

Test updates:
  * getTestStruct() now builds an F32 float attribute and uses
    DenseIntElementsAttr::get for the elements attribute, presumably so the
    fixture satisfies the F32Attr / I32ElementsAttr fields under the
    stricter check.
  * New test ClassofBadTypeFalse keeps all but the last field, re-adds the
    last field as an i64 elements attribute, and expects classof() to
    return false.
  * structs.td narrows sample_elements from ElementsAttr to I32ElementsAttr.

NOTE(review): this copy of the patch has been collapsed onto very few
physical lines. Restore the original line breaks, including blank context
lines, before applying it; the hunk line counts cannot be checked from
this form.

diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp --- a/mlir/tools/mlir-tblgen/StructsGen.cpp +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -27,6 +27,7 @@ using llvm::Record; using llvm::RecordKeeper; using llvm::StringRef; +using mlir::tblgen::FmtContext; using mlir::tblgen::StructAttr; static void @@ -163,16 +164,18 @@ os << llvm::formatv(classofInfo, structName) << " {"; os << llvm::formatv(classofInfoHeader, fields.size()); + FmtContext fctx; const char *classofArgInfo = R"( auto {0} = derived.get("{0}"); - if (!{0} || !{0}.isa<{1}>()) + if (!{0} || !({1})) return false; )"; for (auto field : fields) { auto name = field.getName(); auto type = field.getType(); - auto storage = type.getStorageType(); - os << llvm::formatv(classofArgInfo, name, storage); + std::string condition = + tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)); + os << llvm::formatv(classofArgInfo, name, condition); } const char *classofEndInfo = R"( diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -27,12 +27,12 @@ auto integerType = mlir::IntegerType::get(32, context); auto integerAttr = mlir::IntegerAttr::get(integerType, 127); - auto floatType = mlir::FloatType::getF16(context); + auto floatType = mlir::FloatType::getF32(context); auto floatAttr = mlir::FloatAttr::get(floatType, 0.25); auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType); auto elementsAttr = - mlir::DenseElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6}); + mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6}); return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context); } @@ -88,6 +88,31 @@ ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } +// Validates that test::TestStruct::classof fails when a NamedAttribute has an +// incorrect type. 
+TEST(StructsGenTest, ClassofBadTypeFalse) { + mlir::MLIRContext context; + mlir::DictionaryAttr structAttr = getTestStruct(&context); + auto expectedValues = structAttr.getValue(); + ASSERT_EQ(expectedValues.size(), 3u); + + // Copy every NamedAttribute except the last one. + llvm::SmallVector<mlir::NamedAttribute, 3> newValues( + expectedValues.begin(), expectedValues.end() - 1); + + // Re-add the last field with an i64 (not i32) elements attribute. + auto i64Type = mlir::IntegerType::get(64, &context); + auto elementsType = mlir::RankedTensorType::get({3}, i64Type); + auto elementsAttr = mlir::DenseIntElementsAttr::get( + elementsType, ArrayRef<int64_t>{1, 2, 3}); + mlir::Identifier id = expectedValues.back().first; + auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); + newValues.push_back(wrongAttr); + + auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + ASSERT_FALSE(test::TestStruct::classof(badDictionary)); +} + // Validates that test::TestStruct::classof fails when a NamedAttribute is // missing. TEST(StructsGenTest, ClassofMissingFalse) { diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td --- a/mlir/unittests/TableGen/structs.td +++ b/mlir/unittests/TableGen/structs.td @@ -15,6 +15,6 @@ def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [ StructFieldAttr<"sample_integer", I32Attr>, StructFieldAttr<"sample_float", F32Attr>, - StructFieldAttr<"sample_elements", ElementsAttr>] > { + StructFieldAttr<"sample_elements", I32ElementsAttr>] > { let description = "Structure for test data"; }