diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -207,6 +207,92 @@ llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n"; ``` +#### External Models for Attribute/Type Interfaces + +It may be desirable to provide an interface implementation for an attribute or a +type without modifying the definition of said attribute or type. Notably, this +allows one to implement interfaces for attributes and types outside of the dialect +that defines them and, in particular, to provide interfaces for built-in types. + +This is achieved by extending the concept-based polymorphism model with two more +classes derived from `Concept` as follows. + +```c++ +struct ExampleTypeInterfaceTraits { + struct Concept { + virtual unsigned exampleInterfaceHook(Type type) const = 0; + virtual unsigned exampleStaticInterfaceHook() const = 0; + }; + + template <typename ConcreteType> + struct Model : public Concept { /*...*/ }; + + + /// Unlike `Model`, `FallbackModel` passes the type object through to the + /// hook, making it accessible in the method body even if the method is not + /// defined in the class itself and thus has no `this` access. ODS + /// automatically generates this class for all interfaces. + template <typename ConcreteType> + struct FallbackModel : public Concept { + unsigned exampleInterfaceHook(Type type) const override { + return getImpl()->exampleInterfaceHook(type); + } + unsigned exampleStaticInterfaceHook() const override { + return ConcreteType::exampleStaticInterfaceHook(); + } + }; + + /// `ExternalModel` provides a place for default implementations of interface + /// methods by explicitly separating the model class, which implements the + /// interface, from the type class, for which the interface is being + /// implemented. Default implementations can then be defined generically by + /// making use of `cast<ConcreteType>`.
If `ConcreteType` does not provide the APIs + /// required by the default implementation, custom implementations may use + /// `FallbackModel` directly to override the default implementation. Since + /// `ExternalModel` is a class template, its defaults are then never + /// instantiated and cause no compilation errors. ODS automatically generates + /// this class and places default method implementations in it. + template <typename ConcreteModel, typename ConcreteType> + struct ExternalModel : public FallbackModel<ConcreteModel> { + unsigned exampleInterfaceHook(Type type) const override { + // Default implementation can be provided here. + return type.cast<ConcreteType>().callSomeTypeSpecificMethod(); + } + }; +}; +``` + +External models can be provided for attribute and type interfaces by deriving +from either `FallbackModel` or `ExternalModel` and by attaching the model class +to the attribute or type class in a given context. Other contexts will not see +the interface unless it is attached there as well. + +```c++ +/// External interface implementation for a concrete class. This does not +/// require modifying the definition of the type class itself. +struct ExternalModelExample + : public ExampleTypeInterface::ExternalModel<ExternalModelExample, IntegerType> { + static unsigned exampleStaticInterfaceHook() { + // Implementation is provided here. + return IntegerType::someStaticMethod(); + } + + // No need to define `exampleInterfaceHook` that has a default implementation + // in `ExternalModel`. But it can be overridden if desired. +}; + +int main() { + MLIRContext context; + /* ... */; + + // Attach the interface model to the type in the given context before + // using it. The dialect containing the type is expected to have been loaded + // at this point. + IntegerType::attachInterface<ExternalModelExample>(context); +} +``` + #### Dialect Fallback for OpInterface Some dialects have an open ecosystem and don't register all of the possible @@ -215,9 +301,9 @@ registered or does not provide an implementation for an interface, the query will fallback to the dialect itself.
-A second model is used for such cases and automatically generated when -using ODS (see below) with the name `FallbackModel`. This model can be implemented -for a particular dialect: +A second model is used for such cases and automatically generated when using ODS +(see below) with the name `FallbackModel`. This model can be implemented for a +particular dialect: ```c++ // This is the implementation of a dialect fallback for `ExampleOpInterface`. diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -59,14 +59,24 @@ : dialect(dialect), interfaceMap(std::move(interfaceMap)), typeID(typeID) {} + /// Give StorageUserBase access to the mutable lookup. + template class... Traits> + friend class detail::StorageUserBase; + + /// Look up the specified abstract attribute in the MLIRContext and return a + /// (mutable) pointer to it. Return a null pointer if the attribute could not + /// be found in the context. + static AbstractAttribute *lookupMutable(TypeID typeID, MLIRContext *context); + /// This is the dialect that this attribute was registered to. - Dialect &dialect; + const Dialect &dialect; /// This is a collection of the interfaces registered to this attribute. detail::InterfaceMap interfaceMap; /// The unique identifier of the derived Attribute class.
- TypeID typeID; + const TypeID typeID; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -31,6 +31,7 @@ using ImplType = AttributeStorage; using ValueType = void; + using Abstract = AbstractAttribute; constexpr Attribute() : impl(nullptr) {} /* implicit */ Attribute(const ImplType *impl) diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -87,6 +87,18 @@ return detail::InterfaceMap::template get...>(); } + /// Attach the given models as implementations of the corresponding interfaces + /// for ConcreteT in `context`; fails fatally if ConcreteT is not registered. + template + static void attachInterface(MLIRContext &context) { + typename ConcreteT::Abstract *abstract = + ConcreteT::Abstract::lookupMutable(TypeID::get(), &context); + if (!abstract) + llvm::report_fatal_error("registering an interface for an attribute/type " + "that is not itself registered"); + abstract->interfaceMap.template insert(); + } + /// Get or create a new ConcreteT instance within the ctx. This /// function is guaranteed to return a non null object and will assert if /// the arguments provided are invalid. diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -66,15 +66,24 @@ TypeID typeID) : dialect(dialect), interfaceMap(std::move(interfaceMap)), typeID(typeID) {} + /// Give StorageUserBase access to the mutable lookup. + template class... Traits> + friend class detail::StorageUserBase; + + /// Look up the specified abstract type in the MLIRContext and return a + /// (mutable) pointer to it.
Return a null pointer if the type could not + /// be found in the context. + static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); /// This is the dialect that this type was registered to. - Dialect &dialect; + const Dialect &dialect; /// This is a collection of the interfaces registered to this type. detail::InterfaceMap interfaceMap; /// The unique identifier of the derived Type class. - TypeID typeID; + const TypeID typeID; }; //===----------------------------------------------------------------------===// @@ -105,11 +114,11 @@ /// Set the abstract type for this storage instance. This is used by the /// TypeUniquer when initializing a newly constructed type storage object. void initialize(const AbstractType &abstractTy) { - abstractType = &abstractTy; + abstractType = const_cast(&abstractTy); } /// The abstract description for this type. - const AbstractType *abstractType; + AbstractType *abstractType; }; /// Default storage type for types that require no additional initialization or diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -79,6 +79,8 @@ using ImplType = TypeStorage; + using Abstract = AbstractType; + constexpr Type() : impl(nullptr) {} /* implicit */ Type(const ImplType *impl) : impl(const_cast(impl)) {} diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -76,6 +76,8 @@ using FallbackModel = typename Traits::template FallbackModel; using InterfaceBase = Interface; + template + using ExternalModel = typename Traits::template ExternalModel; /// This is a special trait that registers a given interface with an object. template @@ -199,6 +201,26 @@ }); } + /// Insert the given models as implementations of the corresponding interfaces + /// for the class (e.g. attribute or type) that owns this map.
+ template + void insert() { + std::pair elements[] = { + std::make_pair(IfaceModels::Interface::getInterfaceID(), + new (malloc(sizeof(IfaceModels))) IfaceModels())...}; + // Insert directly into the right position to keep the interfaces sorted. + for (auto &element : elements) { + TypeID id = element.first; + auto it = + llvm::lower_bound(interfaces, id, [](const auto &entry, TypeID key) { + return compare(entry.first, key); + }); + if (it != interfaces.end() && it->first == id) + llvm::report_fatal_error("Interface already registered"); + interfaces.insert(it, element); + } + } + private: /// Compare two TypeID instances by comparing the underlying pointer. static bool compare(TypeID lhs, TypeID rhs) { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -306,7 +306,7 @@ // Type uniquing //===--------------------------------------------------------------------===// - DenseMap registeredTypes; + DenseMap registeredTypes; StorageUniquer typeUniquer; /// Cached Type Instances. @@ -324,7 +324,7 @@ // Attribute uniquing //===--------------------------------------------------------------------===// - DenseMap registeredAttributes; + DenseMap registeredAttributes; StorageUniquer attributeUniquer; /// Cached Attribute Instances. @@ -669,12 +669,20 @@ /// Get the dialect that registered the attribute with the provided typeid.
const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, MLIRContext *context) { + const AbstractAttribute *abstract = lookupMutable(typeID, context); + if (!abstract) + llvm::report_fatal_error("Trying to create an Attribute that was not " + "registered in this MLIRContext."); + return *abstract; +} + +AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, + MLIRContext *context) { auto &impl = context->getImpl(); auto it = impl.registeredAttributes.find(typeID); if (it == impl.registeredAttributes.end()) - llvm::report_fatal_error("Trying to create an Attribute that was not " - "registered in this MLIRContext."); - return *it->second; + return nullptr; + return it->second; } //===----------------------------------------------------------------------===// @@ -740,12 +748,19 @@ //===----------------------------------------------------------------------===// const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { + const AbstractType *abstract = lookupMutable(typeID, context); + if (!abstract) + llvm::report_fatal_error( + "Trying to create a Type that was not registered in this MLIRContext."); + return *abstract; +} + +AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { auto &impl = context->getImpl(); auto it = impl.registeredTypes.find(typeID); if (it == impl.registeredTypes.end()) - llvm::report_fatal_error( - "Trying to create a Type that was not registered in this MLIRContext."); - return *it->second; + return nullptr; + return it->second; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -5,6 +5,8 @@ ) set(LLVM_TARGET_DEFINITIONS TestInterfaces.td) +mlir_tablegen(TestAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TestAttrInterfaces.cpp.inc
-gen-attr-interface-defs) mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls) mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs) mlir_tablegen(TestOpInterfaces.h.inc -gen-op-interface-decls) diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -21,6 +21,8 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "TestAttrInterfaces.h.inc" + #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -93,6 +93,8 @@ // Tablegen Generated Definitions //===----------------------------------------------------------------------===// +#include "TestAttrInterfaces.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -54,6 +54,41 @@ }]; } +def TestExternalTypeInterface : TypeInterface<"TestExternalTypeInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the bitwidth of the type plus 'arg'.", + "unsigned", "getBitwidthPlusArg", (ins "unsigned":$arg)>, + StaticInterfaceMethod<"Returns some value plus 'arg'.", + "unsigned", "staticGetSomeValuePlusArg", (ins "unsigned":$arg)>, + InterfaceMethod<"Returns the bitwidth of the type plus twice 'arg'.", + "unsigned", "getBitwidthPlusDoubleArgument", (ins "unsigned":$arg), "", + [{return $_type.getIntOrFloatBitWidth() + 2 * arg;}]>, + StaticInterfaceMethod<"Returns the argument.", + "unsigned", "staticGetArgument", (ins "unsigned":$arg), "", + [{return arg;}]>, + ]; +} + +def
TestExternalFallbackTypeInterface + : TypeInterface<"TestExternalFallbackTypeInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the bitwidth of the given integer type.", + "unsigned", "getBitwidth", (ins), "", [{return $_type.getWidth();}]>, + ]; +} + +def TestExternalAttrInterface : AttrInterface<"TestExternalAttrInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Gets the dialect pointer.", "const ::mlir::Dialect *", + "getDialectPtr", (ins)>, + StaticInterfaceMethod<"Returns some number.", "int", "getSomeNumber", + (ins)>, + ]; +} + def TestEffectOpInterface : EffectOpInterfaceBase<"TestEffectOpInterface", "::mlir::TestEffects::Effect"> { diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -217,9 +217,13 @@ } void InterfaceGenerator::emitModelDecl(Interface &interface) { + // Emit the basic model and the fallback model. for (const char *modelClass : {"Model", "FallbackModel"}) { os << " template\n"; os << " class " << modelClass << " : public Concept {\n public:\n"; + os << " using Interface = " << interface.getCppNamespace() + << (interface.getCppNamespace().empty() ? "" : "::") + << interface.getName() << ";\n"; os << " " << modelClass << "() : Concept{"; llvm::interleaveComma( interface.getMethods(), os, @@ -236,6 +240,40 @@ } os << " };\n"; } + + // Emit the template for the external model. + os << " template\n"; + os << " class ExternalModel : public FallbackModel {\n"; + os << " public:\n"; + + // Emit declarations for methods that have default implementations. Other + // methods are expected to be implemented by the concrete derived model.
+ for (auto &method : interface.getMethods()) { + if (!method.getDefaultImplementation()) + continue; + os << " "; + if (method.isStatic()) + os << "static "; + emitCPPType(method.getReturnType(), os); + os << method.getName() << "("; + if (!method.isStatic()) { + emitCPPType(valueType, os); + os << "tablegen_opaque_val"; + if (!method.arg_empty()) + os << ", "; + } + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + emitCPPType(arg.type, os); + os << arg.name; + }); + os << ")"; + if (!method.isStatic()) + os << " const"; + os << ";\n"; + } + os << " };\n"; } void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { @@ -298,6 +336,42 @@ [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n}\n"; } + + // Emit default implementations for the external model. + for (auto &method : interface.getMethods()) { + if (!method.getDefaultImplementation()) + continue; + os << "template\n"; + emitCPPType(method.getReturnType(), os); + os << "detail::" << interface.getName() + << "InterfaceTraits::ExternalModel::"; + + os << method.getName() << "("; + if (!method.isStatic()) { + emitCPPType(valueType, os); + os << "tablegen_opaque_val"; + if (!method.arg_empty()) + os << ", "; + } + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + emitCPPType(arg.type, os); + os << arg.name; + }); + os << ")"; + if (!method.isStatic()) + os << " const"; + + os << " {\n"; + + // Static methods have no value to substitute; format with an empty context. + tblgen::FmtContext ctx; + os << tblgen::tgfmt(method.getDefaultImplementation()->trim(), + method.isStatic() ?
&ctx : &nonStaticMethodFmt); + os << "\n}\n"; + } } void InterfaceGenerator::emitTraitDecl(Interface &interface, diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -227,6 +227,8 @@ << " }\n" << " };\n"; } + os << " template\n"; + os << " class ExternalModel : public FallbackModel {};\n"; } static void emitInterfaceDecl(const Availability &availability, diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,11 +1,17 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + InterfaceAttachmentTest.cpp MemRefTypeTest.cpp OperationSupportTest.cpp ShapedTypeTest.cpp SubElementInterfaceTest.cpp + + DEPENDS + MLIRTestInterfaceIncGen ) +target_include_directories(MLIRIRTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test") target_link_libraries(MLIRIRTests PRIVATE - MLIRIR) + MLIRIR + MLIRTestDialect) diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -0,0 +1,153 @@ +//===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements the tests for attaching interfaces to attributes and types +// without having to specify them on the attribute or type class directly.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestTypes.h" + +using namespace mlir; +using namespace mlir::test; + +namespace { + +/// External interface model for the integer type. Only provides non-default +/// methods. +struct Model + : public TestExternalTypeInterface::ExternalModel { + unsigned getBitwidthPlusArg(Type type, unsigned arg) const { + return type.getIntOrFloatBitWidth() + arg; + } + + static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } +}; + +/// External interface model for the float type. Provides non-default methods +/// and overrides the default ones. +struct OverridingModel + : public TestExternalTypeInterface::ExternalModel { + unsigned getBitwidthPlusArg(Type type, unsigned arg) const { + return type.getIntOrFloatBitWidth() + arg; + } + + static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } + + unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { + return 128; + } + + static unsigned staticGetArgument(unsigned arg) { return 420; } +}; + +TEST(InterfaceAttachment, Type) { + MLIRContext context; + + // Check that the type has no interface. + IntegerType i8 = IntegerType::get(&context, 8); + ASSERT_FALSE(i8.isa()); + + // Attach an interface and check that the type now has the interface. + IntegerType::attachInterface(context); + TestExternalTypeInterface iface = i8.dyn_cast(); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); + EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); + EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); + EXPECT_EQ(iface.staticGetArgument(17), 17u); + + // Same, but with the default implementation overridden.
+ FloatType flt = Float32Type::get(&context); + ASSERT_FALSE(flt.isa()); + Float32Type::attachInterface(context); + iface = flt.dyn_cast(); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); + EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); + EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); + EXPECT_EQ(iface.staticGetArgument(17), 420u); + + // Other contexts shouldn't have the interface attached. + MLIRContext other; + IntegerType i8other = IntegerType::get(&other, 8); + EXPECT_FALSE(i8other.isa()); +} + +/// The interface provides a default implementation that expects +/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this +/// just derives from the ExternalModel. +struct TestExternalFallbackTypeIntegerModel + : public TestExternalFallbackTypeInterface::ExternalModel< + TestExternalFallbackTypeIntegerModel, IntegerType> {}; + +/// The interface provides a default implementation that expects +/// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use +/// FallbackModel instead to override this and make sure the code still compiles +/// because we never instantiate the ExternalModel class template with a +/// template argument that would have led to compilation failures. +struct TestExternalFallbackTypeVectorModel + : public TestExternalFallbackTypeInterface::FallbackModel< + TestExternalFallbackTypeVectorModel> { + unsigned getBitwidth(Type type) const { + IntegerType elementType = type.cast() + .getElementType() + .dyn_cast_or_null(); + return elementType ? elementType.getWidth() : 0; + } +}; + +TEST(InterfaceAttachment, Fallback) { + MLIRContext context; + + // Just check that we can attach the interface. + IntegerType i8 = IntegerType::get(&context, 8); + ASSERT_FALSE(i8.isa()); + IntegerType::attachInterface(context); + ASSERT_TRUE(i8.isa()); + + // Attach a FallbackModel-based implementation to VectorType and call it.
+ VectorType vec = VectorType::get({42}, i8); + ASSERT_FALSE(vec.isa()); + VectorType::attachInterface(context); + ASSERT_TRUE(vec.isa()); + EXPECT_EQ(vec.cast().getBitwidth(), 8u); +} + +/// External model for attribute interfaces. +struct TestExternalIntegerAttrModel + : public TestExternalAttrInterface::ExternalModel< + TestExternalIntegerAttrModel, IntegerAttr> { + const Dialect *getDialectPtr(Attribute attr) const { + return &attr.cast().getDialect(); + } + + static int getSomeNumber() { return 42; } +}; + +TEST(InterfaceAttachment, Attribute) { + MLIRContext context; + + // Attribute interfaces use the exact same mechanism as types, so just check + // that the basics work for attributes. + IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); + ASSERT_FALSE(attr.isa()); + IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); + auto iface = attr.dyn_cast(); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); + EXPECT_EQ(iface.getSomeNumber(), 42); +} + +} // end namespace