diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -131,14 +131,14 @@ /// to be overridden. struct Concept { virtual ~Concept(); - virtual unsigned getNumInputs(Operation *op) = 0; + virtual unsigned getNumInputs(Operation *op) const = 0; }; /// Define a model class that specializes a concept on a given operation type. template struct Model : public Concept { /// Override the method to dispatch on the concrete operation. - unsigned getNumInputs(Operation *op) final { + unsigned getNumInputs(Operation *op) const final { return llvm::cast(op).getNumInputs(); } }; @@ -151,7 +151,7 @@ using OpInterface::OpInterface; /// The interface dispatches to 'getImpl()', an instance of the concept. - unsigned getNumInputs() { + unsigned getNumInputs() const { return getImpl()->getNumInputs(getOperation()); } }; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1348,120 +1348,39 @@ traitID); } - /// Returns an opaque pointer to a concept instance of the interface with the - /// given ID if one was registered to this operation. - static void *getRawInterface(TypeID id) { - return InterfaceLookup::template lookup...>(id); - } - - struct InterfaceLookup { - /// Trait to check if T provides a static 'getInterfaceID' method. - template - using has_get_interface_id = decltype(T::getInterfaceID()); - - /// If 'T' is the same interface as 'interfaceID' return the concept - /// instance. - template - static typename std::enable_if< - llvm::is_detected::value, void *>::type - lookup(TypeID interfaceID) { - return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr; - } - - /// 'T' is known to not be an interface, return nullptr. - template - static typename std::enable_if< - !llvm::is_detected::value, void *>::type - lookup(TypeID) { - return nullptr; - } - - template - static void *lookup(TypeID interfaceID) { - auto *concept = lookup(interfaceID); - return concept ? concept : lookup(interfaceID); - } - }; + /// Returns an interface map for the interfaces registered to this operation. + static detail::InterfaceMap getInterfaceMap() { + return detail::InterfaceMap::template get...>(); + } - /// Allow access to 'hasTrait' and 'getRawInterface'. + /// Allow access to 'hasTrait' and 'getInterfaceMap'. friend AbstractOperation; }; -/// This class represents the base of an operation interface. Operation -/// interfaces provide access to derived *Op properties through an opaquely -/// Operation instance. Derived interfaces must also provide a 'Traits' class -/// that defines a 'Concept' and a 'Model' class. The 'Concept' class defines an -/// abstract virtual interface, where as the 'Model' class implements this -/// interface for a specific derived *Op type. Both of these classes *must* not -/// contain non-static data. A simple example is shown below: -/// -/// struct ExampleOpInterfaceTraits { -/// struct Concept { -/// virtual unsigned getNumInputs(Operation *op) = 0; -/// }; -/// template class Model { -/// unsigned getNumInputs(Operation *op) final { -/// return cast(op).getNumInputs(); -/// } -/// }; -/// }; -/// +/// This class represents the base of an operation interface. See the definition +/// of `detail::Interface` for requirements on the `Traits` type. template -class OpInterface : public Op { +class OpInterface + : public detail::Interface, OpTrait::TraitBase> { public: - using Concept = typename Traits::Concept; - template using Model = typename Traits::template Model; using Base = OpInterface; + using InterfaceBase = detail::Interface, OpTrait::TraitBase>; - OpInterface(Operation *op = nullptr) - : Op(op), impl(op ? getInterfaceFor(op) : nullptr) { - assert((!op || impl) && - "instantiating an interface with an unregistered operation"); - } - - /// Support 'classof' by checking if the given operation defines the concrete - /// interface. - static bool classof(Operation *op) { return getInterfaceFor(op); } - - /// Define an accessor for the ID of this interface. - static TypeID getInterfaceID() { return TypeID::get(); } - - /// This is a special trait that registers a given interface with an - /// operation. - template - struct Trait : public OpTrait::TraitBase { - /// Define an accessor for the ID of this interface. - static TypeID getInterfaceID() { return TypeID::get(); } - - /// Provide an accessor to a static instance of the interface model for the - /// concrete operation type. - /// The implementation is inspired from Sean Parent's concept-based - /// polymorphism. A key difference is that the set of classes erased is - /// statically known, which alleviates the need for using dynamic memory - /// allocation. - /// We use a zero-sized templated class `Model` to emit the - /// virtual table and generate a singleton object for each instantiation of - /// this class. - static Concept &instance() { - static Model singleton; - return singleton; - } - }; - -protected: - /// Get the raw concept in the correct derived concept type. - Concept *getImpl() { return impl; } + /// Inherit the base class constructor. + using InterfaceBase::InterfaceBase; private: /// Returns the impl interface instance for the given operation. - static Concept *getInterfaceFor(Operation *op) { + static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { // Access the raw interface from the abstract operation. auto *abstractOp = op->getAbstractOperation(); return abstractOp ? abstractOp->getInterface() : nullptr; } - /// A pointer to the impl concept object. - Concept *impl; + /// Allow access to `getInterfaceFor`. + friend InterfaceBase; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -19,7 +19,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/InterfaceSupport.h" #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/PointerLikeTypeTraits.h" @@ -136,8 +136,7 @@ /// was registered to this operation, null otherwise. This should not be used /// directly. template typename T::Concept *getInterface() const { - return reinterpret_cast( - getRawInterface(T::getInterfaceID())); + return interfaceMap.lookup(); } /// Returns if the operation has a particular trait. @@ -157,7 +156,7 @@ T::getOperationName(), dialect, T::getOperationProperties(), TypeID::get(), T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns, - T::getRawInterface, T::hasTrait); + T::getInterfaceMap(), T::hasTrait); } private: @@ -171,22 +170,19 @@ SmallVectorImpl &results), void (&getCanonicalizationPatterns)(OwningRewritePatternList &results, MLIRContext *context), - void *(&getRawInterface)(TypeID interfaceID), - bool (&hasTrait)(TypeID traitID)) + detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID)) : name(name), dialect(dialect), typeID(typeID), parseAssembly(parseAssembly), printAssembly(printAssembly), verifyInvariants(verifyInvariants), foldHook(foldHook), getCanonicalizationPatterns(getCanonicalizationPatterns), - opProperties(opProperties), getRawInterface(getRawInterface), + opProperties(opProperties), interfaceMap(std::move(interfaceMap)), hasRawTrait(hasTrait) {} /// The properties of the operation. const OperationProperties opProperties; - /// Returns a raw instance of the concept for the given interface id if it is - /// registered to this operation, nullptr otherwise. This should not be used - /// directly. - void *(&getRawInterface)(TypeID interfaceID); + /// A map of interfaces that were registered to this operation. + detail::InterfaceMap interfaceMap; /// This hook returns if the operation contains the trait corresponding /// to the given TypeID. diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -0,0 +1,212 @@ +//===- InterfaceSupport.h - MLIR Interface Support Classes ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines several support classes for defining interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SUPPORT_INTERFACESUPPORT_H +#define MLIR_SUPPORT_INTERFACESUPPORT_H + +#include "mlir/Support/TypeID.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/TypeName.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace detail { +//===----------------------------------------------------------------------===// +// Interface +//===----------------------------------------------------------------------===// + +/// This class represents an abstract interface. An interface is a simplified +/// mechanism for attaching concept based polymorphism to a class hierarchy. An +/// interace is comprised of two components: +/// * The derived interface class: This is what users interact with, and invoke +/// methods on. +/// * An interface `Trait` class: This is the class that is attached to the +/// object implementing the interface. It is the mechanism with which models +/// are specialized. +/// +/// Derived interfaces types must provide the following template types: +/// * ConcreteType: The CRTP derived type. +/// * ValueT: The opaque type the derived interface operates on. For example +/// `Operation*` for operation interfaces, or `Attribute` for +/// attribute interfaces. +/// * Traits: A class that contains definitions for a 'Concept' and a 'Model' +/// class. The 'Concept' class defines an abstract virtual interface, +/// where as the 'Model' class implements this interface for a +/// specific derived T type. Both of these classes *must* not contain +/// non-static data. A simple example is shown below: +/// +/// ```c++ +/// struct ExampleInterfaceTraits { +/// struct Concept { +/// virtual unsigned getNumInputs(T t) const = 0; +/// }; +/// template class Model { +/// unsigned getNumInputs(T t) const final { +/// return cast(t).getNumInputs(); +/// } +/// }; +/// }; +/// ``` +/// +/// * BaseType: A desired base type for the interface. This is a class that +/// provides that provides specific functionality for the `ValueT` +/// value. For instance the specific `Op` that will wrap the +/// `Operation*` for an `OpInterface`. +/// * BaseTrait: The base type for the interface trait. This is the base class +/// to use for the interface trait that will be attached to each +/// instance of `ValueT` that implements this interface. +/// +template class> class BaseTrait> +class Interface : public BaseType { +public: + using Concept = typename Traits::Concept; + template + using Model = typename Traits::template Model; + using InterfaceBase = + Interface; + + Interface(ValueT t = ValueT()) + : BaseType(t), impl(t ? ConcreteType::getInterfaceFor(t) : nullptr) { + assert((!t || impl) && + "instantiating an interface with an unregistered operation"); + } + + /// Support 'classof' by checking if the given object defines the concrete + /// interface. + static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } + + /// Define an accessor for the ID of this interface. + static TypeID getInterfaceID() { return TypeID::get(); } + + /// This is a special trait that registers a given interface with an object. + template + struct Trait : public BaseTrait { + /// Define an accessor for the ID of this interface. + static TypeID getInterfaceID() { return TypeID::get(); } + + /// Provide an accessor to a static instance of the interface model for the + /// concrete T type. + /// The implementation is inspired from Sean Parent's concept-based + /// polymorphism. A key difference is that the set of classes erased is + /// statically known, which alleviates the need for using dynamic memory + /// allocation. + /// We use a zero-sized templated class `Model` to emit the + /// virtual table and generate a singleton object for each instantiation of + /// this class. + static Concept &instance() { + static Model singleton; + return singleton; + } + }; + +protected: + /// Get the raw concept in the correct derived concept type. + const Concept *getImpl() const { return impl; } + Concept *getImpl() { return impl; } + +private: + /// A pointer to the impl concept object. + Concept *impl; +}; + +//===----------------------------------------------------------------------===// +// InterfaceMap +//===----------------------------------------------------------------------===// + +/// This class provides an efficient mapping between a given `Interface` type, +/// and a particular implementation of its concept. +class InterfaceMap { +public: + /// Construct an InterfaceMap with the given set of template types. For + /// convenience given that object trait lists may contain other non-interface + /// types, not all of the types need to be interfaces. The provided types that + /// do not represent interfaces are not added to the interface map. + template + static InterfaceMap get() { + return InterfaceMap(MapBuilder::create()); + } + + /// Returns an instance of the concept object for the given interface if it + /// was registered to this map, null otherwise. + template + typename T::Concept *lookup() const { + if (!interfaces) + return nullptr; + return reinterpret_cast( + interfaces->lookup(T::getInterfaceID())); + } + +private: + /// This struct provides support for building a map of interfaces. + class MapBuilder { + public: + template + static std::unique_ptr> create() { + // Filter the provided types for those that are interfaces. This reduces + // the amount of maps that are generated. + return createImpl((typename FilterTypes::type *)nullptr); + } + + private: + /// Trait to check if T provides a static 'getInterfaceID' method. + template + using has_get_interface_id = decltype(T::getInterfaceID()); + template + using detect_get_interface_id = llvm::is_detected; + + /// Utility to filter a given sequence of types base upon a predicate. + template + struct FilterTypeT { + template + using type = std::tuple; + }; + template <> + struct FilterTypeT { + template + using type = std::tuple<>; + }; + template