[mlir] InterfaceMap: allocate interface models per map instead of static singletons

Summary (derived from the hunks below):
  * The interface `Trait` struct drops its function-local static `instance()`
    singleton (`static Model singleton;`) and instead exposes a `ModelT` alias
    for its model type, so the map can allocate models itself.
  * InterfaceMap::get() is split into two enable_if'd overloads selected by
    `num_interface_types` (the size of the filtered list of types that provide
    a static getInterfaceID()). With zero interfaces it returns a
    default-constructed map (presumably leaving `interfaces` null, so nothing
    is allocated). Otherwise it filters the type list and calls getImpl().
  * getImpl() allocates each model with malloc + placement new, collects the
    (interface ID, model pointer) pairs in a local array, and hands them to a
    private constructor that builds the map via std::make_unique from that
    element range. From the code visible here, getImpl() is only reached when
    at least one interface type is present, so `elements[]` is never
    zero-sized.
  * The new destructor walks the map and free()s every stored model. The
    defaulted move constructor should leave the moved-from map's `interfaces`
    null (assuming it is a std::unique_ptr, as the removed constructor's
    parameter suggests). The destructor's `if (interfaces)` check tolerates
    that state.
  * lookup() now yields null when `interfaces` is null. Otherwise it returns
    the result of `interfaces->lookup(T::getInterfaceID())` cast to
    `typename T::Concept *`.

Review notes:
  * NOTE(review): ~InterfaceMap() calls free() on each model without running
    its destructor. That is only correct if every `Ts::ModelT` is trivially
    destructible (or its destructor has no needed side effects). Confirm this,
    e.g. with a static_assert in getImpl().
  * NOTE(review): malloc()'s result is not checked before the placement new.
    A null return there is undefined behavior, so consider an allocation
    helper that reports failure.
  * NOTE(review): if the same interface appears more than once in `Ts`,
    getImpl() allocates one model per occurrence. A map built from the element
    range will presumably keep only one entry per interface ID, because
    DenseMap-style insert does not overwrite an existing key. The surplus
    model would then never be freed. Confirm against the declared type of
    `interfaces`, or deduplicate / free on failed insertion.
  * NOTE(review): because the move constructor and destructor are
    user-declared, InterfaceMap gets no implicit copy or move assignment
    operator. Callers can move-construct a map but not reassign one. Confirm
    that this is intended.
  * NOTE(review): the trailing context comment "This is constructed statically
    for each set of interfaces." no longer matches the code. After this patch
    the map is built per InterfaceMap instance in its constructor.
  * NOTE(review): this copy of the patch has been whitespace-collapsed, and
    its angle-bracketed template argument lists appear to have been stripped
    (e.g. `template struct Trait : public BaseTrait {`,
    `std::make_unique>(`). It will not apply as-is, so regenerate it from the
    source tree rather than hand-repairing it.

diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -91,22 +91,10 @@ /// This is a special trait that registers a given interface with an object. template struct Trait : public BaseTrait { + using ModelT = Model; + /// Define an accessor for the ID of this interface. static TypeID getInterfaceID() { return TypeID::get(); } - - /// Provide an accessor to a static instance of the interface model for the - /// concrete T type. - /// The implementation is inspired from Sean Parent's concept-based - /// polymorphism. A key difference is that the set of classes erased is - /// statically known, which alleviates the need for using dynamic memory - /// allocation. - /// We use a zero-sized templated class `Model` to emit the - /// virtual table and generate a singleton object for each instantiation of - /// this class. - static Concept &instance() { - static Model singleton; - return singleton; - } }; protected: @@ -144,60 +132,65 @@ /// This class provides an efficient mapping between a given `Interface` type, /// and a particular implementation of its concept. class InterfaceMap { + /// Trait to check if T provides a static 'getInterfaceID' method. + template + using has_get_interface_id = decltype(T::getInterfaceID()); + template + using detect_get_interface_id = llvm::is_detected; + template + using num_interface_types = typename std::tuple_size< + typename FilterTypes::type>; + public: + InterfaceMap(InterfaceMap &&) = default; + ~InterfaceMap() { + if (interfaces) { + for (auto &it : *interfaces) + free(it.second); + } + } + /// Construct an InterfaceMap with the given set of template types. For /// convenience given that object trait lists may contain other non-interface /// types, not all of the types need to be interfaces.
The provided types that /// do not represent interfaces are not added to the interface map. - template static InterfaceMap get() { - return InterfaceMap(MapBuilder::create()); + template + static std::enable_if_t::value != 0, + InterfaceMap> + get() { + // Filter the provided types for those that are interfaces. + using FilteredTupleType = + typename FilterTypes::type; + return getImpl((FilteredTupleType *)nullptr); + } + + template + static std::enable_if_t::value == 0, + InterfaceMap> + get() { + return InterfaceMap(); + } /// Returns an instance of the concept object for the given interface if it /// was registered to this map, null otherwise. template typename T::Concept *lookup() const { - if (!interfaces) - return nullptr; - return reinterpret_cast( - interfaces->lookup(T::getInterfaceID())); + void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr; + return reinterpret_cast(inst); } private: - /// This struct provides support for building a map of interfaces. - class MapBuilder { - public: - template - static std::unique_ptr> create() { - // Filter the provided types for those that are interfaces. This reduces - // the amount of maps that are generated. - return createImpl((typename FilterTypes::type *)nullptr); - } - - private: - /// Trait to check if T provides a static 'getInterfaceID' method. - template - using has_get_interface_id = decltype(T::getInterfaceID()); - template - using detect_get_interface_id = llvm::is_detected; - - template - static std::unique_ptr> - createImpl(std::tuple *) { - // Only create an instance of the map if there are any interface types.
- if (sizeof...(Ts) == 0) - return std::unique_ptr>(); - - auto map = std::make_unique>(); - (void)std::initializer_list{ - 0, (map->try_emplace(Ts::getInterfaceID(), &Ts::instance()), 0)...}; - return map; - } - }; - -private: - InterfaceMap(std::unique_ptr> interfaces) - : interfaces(std::move(interfaces)) {} + InterfaceMap() = default; + InterfaceMap(MutableArrayRef> elements) + : interfaces(std::make_unique>( + elements.begin(), elements.end())) {} + + template + static InterfaceMap getImpl(std::tuple *) { + std::pair elements[] = {std::make_pair( + Ts::getInterfaceID(), + new (malloc(sizeof(typename Ts::ModelT))) typename Ts::ModelT())...}; + return InterfaceMap(elements); + } /// The internal map of interfaces. This is constructed statically for each /// set of interfaces.