diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -340,7 +340,7 @@ /// are in the OperationName class, and in addition defines the TypeID of the op /// that will be defined. /// Each dynamic operation definition refers to one instance of this class. -class DynamicOpDefinition { +class DynamicOpDefinition : public OperationName::InterfaceConcept { public: /// Create a new op at runtime. The op is registered only after passing it to /// the dialect using registerDynamicOp. @@ -412,6 +412,32 @@ populateDefaultAttrsFn = std::move(populateDefaultAttrs); } + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return foldHookFn(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + getCanonicalizationPatternsFn(set, context); + } + bool hasTrait(TypeID id) final { return false; } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; } + void populateDefaultAttrs(const RegisteredOperationName &name, + NamedAttrList &attrs) final { + populateDefaultAttrsFn(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + printFn(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); } + LogicalResult verifyRegionInvariants(Operation *op) final { + return verifyRegionFn(op); + } + detail::InterfaceMap getInterfaceMap() final { + return detail::InterfaceMap::get<>(); + } + private: DynamicOpDefinition( StringRef name, ExtensibleDialect *dialect, @@ -541,6 +567,9 @@ /// This structure allows to get in O(1) a dynamic attribute given its name. llvm::StringMap nameToDynAttrs; + /// All the dynamic operations registered. + std::vector> dynamicOps; + /// Give DynamicOpDefinition access to allocateTypeID. 
friend DynamicOpDefinition; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1811,10 +1811,6 @@ return [](TypeID id) { return op_definition_impl::hasTrait(id); }; } - /// Implementation of `ParseAssemblyFn` OperationName hook. - static OperationName::ParseAssemblyFn getParseAssemblyFn() { - return &ConcreteType::parse; - } /// Implementation of `PrintAssemblyFn` OperationName hook. static OperationName::PrintAssemblyFn getPrintAssemblyFn() { if constexpr (detect_has_print::value) diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -23,6 +23,7 @@ #include "mlir/Support/InterfaceSupport.h" #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/TrailingObjects.h" #include @@ -69,7 +70,7 @@ Operation *, ArrayRef, SmallVectorImpl &) const>; using HasTraitFn = llvm::unique_function; using ParseAssemblyFn = - llvm::unique_function; + llvm::function_ref; // Note: RegisteredOperationName is passed as reference here as the derived // class is defined below. using PopulateDefaultAttrsFn = llvm::unique_function; -protected: /// This class represents a type erased version of an operation. It contains /// all of the components necessary for opaquely interacting with an /// operation. If the operation is not registered, some of these components /// may not be populated. 
+ struct InterfaceConcept { + virtual ~InterfaceConcept() = default; + virtual LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) = 0; + virtual void getCanonicalizationPatterns(RewritePatternSet &, + MLIRContext *) = 0; + virtual bool hasTrait(TypeID) = 0; + virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0; + virtual void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &) = 0; + virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0; + virtual LogicalResult verifyInvariants(Operation *) = 0; + virtual LogicalResult verifyRegionInvariants(Operation *) = 0; + virtual detail::InterfaceMap getInterfaceMap() = 0; + }; + + /// Default implementation for unregistered operations. + struct UnregisteredOpModel : InterfaceConcept { + LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) final; + void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final; + bool hasTrait(TypeID) final; + virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final; + void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &) final; + void printAssembly(Operation *, OpAsmPrinter &, StringRef) final; + LogicalResult verifyInvariants(Operation *) final; + LogicalResult verifyRegionInvariants(Operation *) final; + detail::InterfaceMap getInterfaceMap() final; + }; + +protected: struct Impl { - Impl(StringAttr name) - : name(name), dialect(nullptr), interfaceMap(std::nullopt) {} + Impl(StringAttr name, InterfaceConcept *model) + : name(name), dialect(nullptr), interfaceMap(std::nullopt), + model(model) {} /// The name of the operation. StringAttr name; @@ -113,14 +146,7 @@ detail::InterfaceMap interfaceMap; /// Internal callback hooks provided by the op implementation. 
- FoldHookFn foldHookFn; - GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; - HasTraitFn hasTraitFn; - ParseAssemblyFn parseAssemblyFn; - PopulateDefaultAttrsFn populateDefaultAttrsFn; - PrintAssemblyFn printAssemblyFn; - VerifyInvariantsFn verifyInvariantsFn; - VerifyRegionInvariantsFn verifyRegionInvariantsFn; + InterfaceConcept *model; /// A list of attribute names registered to this operation in StringAttr /// form. This allows for operation classes to use StringAttr for attribute @@ -132,7 +158,7 @@ OperationName(StringRef name, MLIRContext *context); /// Return if this operation is registered. - bool isRegistered() const { return impl->isRegistered(); } + bool isRegistered() const { return getImpl()->isRegistered(); } /// If this operation is registered, returns the registered information, /// std::nullopt otherwise. @@ -145,9 +171,7 @@ bool hasTrait() const { return hasTrait(TypeID::get()); } - bool hasTrait(TypeID traitID) const { - return isRegistered() && impl->hasTraitFn(traitID); - } + bool hasTrait(TypeID traitID) const { return getModel()->hasTrait(traitID); } /// Returns true if the operation *might* have the provided trait. This /// means that either the operation is unregistered, or it was registered with @@ -157,7 +181,7 @@ return mightHaveTrait(TypeID::get()); } bool mightHaveTrait(TypeID traitID) const { - return !isRegistered() || impl->hasTraitFn(traitID); + return !isRegistered() || getModel()->hasTrait(traitID); } /// Returns an instance of the concept object for the given interface if it @@ -165,7 +189,7 @@ /// directly. template typename T::Concept *getInterface() const { - return impl->interfaceMap.lookup(); + return getImpl()->interfaceMap.lookup(); } /// Returns true if this operation has the given interface registered to it. 
@@ -174,7 +198,7 @@ return hasInterface(TypeID::get()); } bool hasInterface(TypeID interfaceID) const { - return impl->interfaceMap.contains(interfaceID); + return getImpl()->interfaceMap.contains(interfaceID); } /// Returns true if the operation *might* have the provided interface. This @@ -191,7 +215,8 @@ /// Return the dialect this operation is registered to if the dialect is /// loaded in the context, or nullptr if the dialect isn't loaded. Dialect *getDialect() const { - return isRegistered() ? impl->dialect : impl->name.getReferencedDialect(); + return isRegistered() ? getImpl()->dialect + : getImpl()->name.getReferencedDialect(); } /// Return the name of the dialect this operation is registered to. @@ -204,7 +229,7 @@ StringRef getStringRef() const { return getIdentifier(); } /// Return the name of this operation as a StringAttr. - StringAttr getIdentifier() const { return impl->name; } + StringAttr getIdentifier() const { return getImpl()->name; } void print(raw_ostream &os) const; void dump() const; @@ -221,13 +246,27 @@ bool operator!=(const OperationName &rhs) const { return !(*this == rhs); } protected: - OperationName(Impl *impl) : impl(impl) {} + OperationName(Impl *impl) : impl(impl), model(impl ? impl->model : nullptr) {} + Impl *getImpl() const { return impl; } + InterfaceConcept *getModel() const { + assert(impl->model == model && "Inconsistency in model caching"); + return model; + } + void setImpl(Impl *rhs) { impl = rhs; } + void setModel(InterfaceConcept *rhs) { model = rhs; } +private: /// The internal implementation of the operation name. - Impl *impl; + Impl *impl = nullptr; + + /// Internal callback hooks provided by the op implementation. + /// This is set on the impl, but we cache it here as it'll be accessed often. + InterfaceConcept *model = nullptr; /// Allow access to the Impl struct. 
friend MLIRContextImpl; + friend DenseMapInfo; + friend DenseMapInfo; }; inline raw_ostream &operator<<(raw_ostream &os, OperationName info) { @@ -250,78 +289,102 @@ /// the concrete operation types. class RegisteredOperationName : public OperationName { public: + /// Implementation of the InterfaceConcept for operation APIs that forward + /// to a concrete op implementation. + template struct Model : public InterfaceConcept { + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return ConcreteOp::getFoldHookFn()(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + ConcreteOp::getGetCanonicalizationPatternsFn()(set, context); + } + bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { + return ConcreteOp::parse; + } + void populateDefaultAttrs(const RegisteredOperationName &name, + NamedAttrList &attrs) final { + ConcreteOp::populateDefaultAttrs(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + ConcreteOp::getPrintAssemblyFn()(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final { + return ConcreteOp::getVerifyInvariantsFn()(op); + } + LogicalResult verifyRegionInvariants(Operation *op) final { + return ConcreteOp::getVerifyRegionInvariantsFn()(op); + } + detail::InterfaceMap getInterfaceMap() final { + return ConcreteOp::getInterfaceMap(); + } + }; + /// Lookup the registered operation information for the given operation. /// Returns std::nullopt if the operation isn't registered. static Optional lookup(StringRef name, MLIRContext *ctx); /// Register a new operation in a Dialect object. - /// This constructor is used by Dialect objects when they register the list of - /// operations they contain. 
- template - static void insert(Dialect &dialect) { - insert(T::getOperationName(), dialect, TypeID::get(), - T::getParseAssemblyFn(), T::getPrintAssemblyFn(), - T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(), - T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), - T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(), - T::getPopulateDefaultAttrsFn()); + /// This constructor is used by Dialect objects when they register the list + /// of operations they contain. + template static void insert(Dialect &dialect) { + static Model model; + insert(T::getOperationName(), dialect, TypeID::get(), model, + T::getAttributeNames()); } /// The use of this method is in general discouraged in favor of /// 'insert(dialect)'. - static void - insert(StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, - VerifyInvariantsFn &&verifyInvariants, - VerifyRegionInvariantsFn &&verifyRegionInvariants, - FoldHookFn &&foldHook, - GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef attrNames, - PopulateDefaultAttrsFn &&populateDefaultAttrs); + static void insert(StringRef name, Dialect &dialect, TypeID typeID, + InterfaceConcept &model, ArrayRef attrNames); /// Return the dialect this operation is registered to. - Dialect &getDialect() const { return *impl->dialect; } + Dialect &getDialect() const { return *getImpl()->dialect; } /// Return the unique identifier of the derived Op class. - TypeID getTypeID() const { return impl->typeID; } + TypeID getTypeID() const { return getImpl()->typeID; } /// Use the specified object to parse this ops custom assembly format. ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; /// Return the static hook for parsing this operation assembly. 
- const ParseAssemblyFn &getParseAssemblyFn() const { - return impl->parseAssemblyFn; + ParseAssemblyFn getParseAssemblyFn() const { + return getModel()->getParseAssemblyFn(); } /// This hook implements the AsmPrinter for this operation. void printAssembly(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) const { - return impl->printAssemblyFn(op, p, defaultDialect); + return getModel()->printAssembly(op, p, defaultDialect); } /// These hooks implement the verifiers for this operation. It should emits - /// an error message and returns failure if a problem is detected, or returns - /// success if everything is ok. + /// an error message and returns failure if a problem is detected, or + /// returns success if everything is ok. LogicalResult verifyInvariants(Operation *op) const { - return impl->verifyInvariantsFn(op); + return getModel()->verifyInvariants(op); } LogicalResult verifyRegionInvariants(Operation *op) const { - return impl->verifyRegionInvariantsFn(op); + return getModel()->verifyRegionInvariants(op); } - /// This hook implements a generalized folder for this operation. Operations + /// This hook implements a generalized folder for this operation. Operations /// can implement this to provide simplifications rules that are applied by /// the Builder::createOrFold API and the canonicalization pass. /// - /// This is an intentionally limited interface - implementations of this hook - /// can only perform the following changes to the operation: + /// This is an intentionally limited interface - implementations of this + /// hook can only perform the following changes to the operation: /// /// 1. They can leave the operation alone and without changing the IR, and /// return failure. - /// 2. They can mutate the operation in place, without changing anything else + /// 2. They can mutate the operation in place, without changing anything + /// else /// in the IR. In this case, return success. - /// 3. 
They can return a list of existing values that can be used instead of + /// 3. They can return a list of existing values that can be used instead + /// of /// the operation. In this case, fill in the results list and return /// success. The caller will remove the operation and use those results /// instead. @@ -331,37 +394,35 @@ /// generalized constant folding. LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) const { - return impl->foldHookFn(op, operands, results); + return getModel()->foldHook(op, operands, results); } - /// This hook returns any canonicalization pattern rewrites that the operation - /// supports, for use by the canonicalization pass. + /// This hook returns any canonicalization pattern rewrites that the + /// operation supports, for use by the canonicalization pass. void getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) const { - return impl->getCanonicalizationPatternsFn(results, context); + return getModel()->getCanonicalizationPatterns(results, context); } - /// Attach the given models as implementations of the corresponding interfaces - /// for the concrete operation. - template - void attachInterface() { - impl->interfaceMap.insert(); + /// Attach the given models as implementations of the corresponding + /// interfaces for the concrete operation. + template void attachInterface() { + getImpl()->interfaceMap.insert(); } /// Returns true if the operation has a particular trait. - template