diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -336,11 +336,11 @@ /// The definition of a dynamic op. A dynamic op is an op that is defined at /// runtime, and that can be registered at runtime by an extensible dialect (a -/// dialect inheriting ExtensibleDialect). This class stores the functions that -/// are in the OperationName class, and in addition defines the TypeID of the op -/// that will be defined. -/// Each dynamic operation definition refers to one instance of this class. -class DynamicOpDefinition { +/// dialect inheriting ExtensibleDialect). This class implements the methods +/// exposed by the OperationName class, and in addition defines the TypeID of +/// the op that will be defined. Each dynamic operation definition refers to one +/// instance of this class. +class DynamicOpDefinition : public OperationName::Impl { public: /// Create a new op at runtime. The op is registered only after passing it to /// the dialect using registerDynamicOp.
@@ -412,6 +412,29 @@ populateDefaultAttrsFn = std::move(populateDefaultAttrs); } + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return foldHookFn(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + getCanonicalizationPatternsFn(set, context); + } + bool hasTrait(TypeID id) final { return false; } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; } + void populateDefaultAttrs(const RegisteredOperationName &name, + NamedAttrList &attrs) final { + populateDefaultAttrsFn(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + printFn(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); } + LogicalResult verifyRegionInvariants(Operation *op) final { + return verifyRegionFn(op); + } + private: DynamicOpDefinition( StringRef name, ExtensibleDialect *dialect, @@ -424,15 +447,8 @@ &&getCanonicalizationPatternsFn, OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn); - /// Unique identifier for this operation. - TypeID typeID; - - /// Name of the operation. - /// The name is prefixed with the dialect name. - std::string name; - /// Dialect defining this operation. - ExtensibleDialect *dialect; + ExtensibleDialect *getdialect(); OperationName::VerifyInvariantsFn verifyFn; OperationName::VerifyRegionInvariantsFn verifyRegionFn; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1810,10 +1810,6 @@ return [](TypeID id) { return op_definition_impl::hasTrait(id); }; } - /// Implementation of `ParseAssemblyFn` OperationName hook. - static OperationName::ParseAssemblyFn getParseAssemblyFn() { - return &ConcreteType::parse; - } /// Implementation of `PrintAssemblyFn` OperationName hook.
static OperationName::PrintAssemblyFn getPrintAssemblyFn() { if constexpr (detect_has_print::value) diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -23,6 +23,7 @@ #include "mlir/Support/InterfaceSupport.h" #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/TrailingObjects.h" #include @@ -69,7 +70,7 @@ Operation *, ArrayRef, SmallVectorImpl &) const>; using HasTraitFn = llvm::unique_function; using ParseAssemblyFn = - llvm::unique_function; + llvm::function_ref; // Note: RegisteredOperationName is passed as reference here as the derived // class is defined below. using PopulateDefaultAttrsFn = llvm::unique_function; -protected: - /// This class represents a type erased version of an operation. It contains - /// all of the components necessary for opaquely interacting with an - /// operation. If the operation is not registered, some of these components - /// may not be populated.
- struct Impl { - Impl(StringAttr name) - : name(name), dialect(nullptr), interfaceMap(std::nullopt) {} + /// Abstract interface for the type-erased hooks of an operation (folding, + /// canonicalization, trait queries, parsing, printing, verification). `Impl` + /// below adds the common data; concrete models such as `UnregisteredOpModel` + /// and `RegisteredOperationName::Model` implement the hooks. + struct InterfaceConcept { + virtual ~InterfaceConcept() = default; + virtual LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) = 0; + virtual void getCanonicalizationPatterns(RewritePatternSet &, + MLIRContext *) = 0; + virtual bool hasTrait(TypeID) = 0; + virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0; + virtual void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &) = 0; + virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0; + virtual LogicalResult verifyInvariants(Operation *) = 0; + virtual LogicalResult verifyRegionInvariants(Operation *) = 0; + }; + +public: + class Impl : public InterfaceConcept { + public: + Impl(StringRef name, Dialect *dialect, TypeID typeID, + detail::InterfaceMap interfaceMap); + Impl(StringAttr name, Dialect *dialect, TypeID typeID, + detail::InterfaceMap interfaceMap) + : name(name), typeID(typeID), dialect(dialect), + interfaceMap(std::move(interfaceMap)) {} + + /// Returns true if this is a registered operation. + bool isRegistered() const { return typeID != TypeID::get(); } + detail::InterfaceMap &getInterfaceMap() { return interfaceMap; } + Dialect *getDialect() const { return dialect; } + StringAttr getName() const { return name; } + TypeID getTypeID() const { return typeID; } + ArrayRef getAttributeNames() const { return attributeNames; } + + protected: /// The name of the operation. StringAttr name; + /// The unique identifier of the derived Op class. + TypeID typeID; //===------------------------------------------------------------------===// // Registered Operation Info /// The following fields are only populated when the operation is /// registered. - /// Returns true if the operation has been registered, i.e.
if the - /// registration info has been populated. - bool isRegistered() const { return dialect; } - /// This is the dialect that this operation belongs to. Dialect *dialect; - /// The unique identifier of the derived Op class. - TypeID typeID; - /// A map of interfaces that were registered to this operation. detail::InterfaceMap interfaceMap; - /// Internal callback hooks provided by the op implementation. - FoldHookFn foldHookFn; - GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; - HasTraitFn hasTraitFn; - ParseAssemblyFn parseAssemblyFn; - PopulateDefaultAttrsFn populateDefaultAttrsFn; - PrintAssemblyFn printAssemblyFn; - VerifyInvariantsFn verifyInvariantsFn; - VerifyRegionInvariantsFn verifyRegionInvariantsFn; - /// A list of attribute names registered to this operation in StringAttr /// form. This allows for operation classes to use StringAttr for attribute /// lookup/creation/etc., as opposed to raw strings. ArrayRef attributeNames; + + friend class RegisteredOperationName; + }; + +protected: + /// Default implementation for unregistered operations. + struct UnregisteredOpModel : public Impl { + using Impl::Impl; + LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) final; + void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final; + bool hasTrait(TypeID) final; + OperationName::ParseAssemblyFn getParseAssemblyFn() final; + void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &) final; + void printAssembly(Operation *, OpAsmPrinter &, StringRef) final; + LogicalResult verifyInvariants(Operation *) final; + LogicalResult verifyRegionInvariants(Operation *) final; }; public: OperationName(StringRef name, MLIRContext *context); /// Return if this operation is registered.
- bool isRegistered() const { return impl->isRegistered(); } + bool isRegistered() const { return getImpl()->isRegistered(); } /// If this operation is registered, returns the registered information, /// std::nullopt otherwise. @@ -145,9 +179,7 @@ bool hasTrait() const { return hasTrait(TypeID::get()); } - bool hasTrait(TypeID traitID) const { - return isRegistered() && impl->hasTraitFn(traitID); - } + bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); } /// Returns true if the operation *might* have the provided trait. This /// means that either the operation is unregistered, or it was registered with @@ -157,7 +189,7 @@ return mightHaveTrait(TypeID::get()); } bool mightHaveTrait(TypeID traitID) const { - return !isRegistered() || impl->hasTraitFn(traitID); + return !isRegistered() || getImpl()->hasTrait(traitID); } /// Returns an instance of the concept object for the given interface if it @@ -165,7 +197,7 @@ /// directly. template typename T::Concept *getInterface() const { - return impl->interfaceMap.lookup(); + return getImpl()->getInterfaceMap().lookup(); } /// Returns true if this operation has the given interface registered to it. @@ -174,7 +206,7 @@ return hasInterface(TypeID::get()); } bool hasInterface(TypeID interfaceID) const { - return impl->interfaceMap.contains(interfaceID); + return getImpl()->getInterfaceMap().contains(interfaceID); } /// Returns true if the operation *might* have the provided interface. This @@ -191,7 +223,8 @@ /// Return the dialect this operation is registered to if the dialect is /// loaded in the context, or nullptr if the dialect isn't loaded. Dialect *getDialect() const { - return isRegistered() ? impl->dialect : impl->name.getReferencedDialect(); + return isRegistered() ? getImpl()->getDialect() + : getImpl()->getName().getReferencedDialect(); } /// Return the name of the dialect this operation is registered to.
@@ -204,7 +237,7 @@ StringRef getStringRef() const { return getIdentifier(); } /// Return the name of this operation as a StringAttr. - StringAttr getIdentifier() const { return impl->name; } + StringAttr getIdentifier() const { return getImpl()->getName(); } void print(raw_ostream &os) const; void dump() const; @@ -222,12 +255,17 @@ protected: OperationName(Impl *impl) : impl(impl) {} + Impl *getImpl() const { return impl; } + void setImpl(Impl *rhs) { impl = rhs; } +private: /// The internal implementation of the operation name. - Impl *impl; + Impl *impl = nullptr; /// Allow access to the Impl struct. friend MLIRContextImpl; + friend DenseMapInfo; + friend DenseMapInfo; }; inline raw_ostream &operator<<(raw_ostream &os, OperationName info) { @@ -250,78 +288,100 @@ /// the concrete operation types. class RegisteredOperationName : public OperationName { public: + /// Implementation of the InterfaceConcept for operation APIs that are + /// forwarded to a concrete op implementation. + template struct Model : public Impl { + explicit Model(Dialect *dialect) + : Impl(ConcreteOp::getOperationName(), dialect, + TypeID::get(), ConcreteOp::getInterfaceMap()) {} + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return ConcreteOp::getFoldHookFn()(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + ConcreteOp::getGetCanonicalizationPatternsFn()(set, context); + } + bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { + return ConcreteOp::parse; + } + void populateDefaultAttrs(const RegisteredOperationName &name, + NamedAttrList &attrs) final { + ConcreteOp::populateDefaultAttrs(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + ConcreteOp::getPrintAssemblyFn()(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final {
return ConcreteOp::getVerifyInvariantsFn()(op); + } + LogicalResult verifyRegionInvariants(Operation *op) final { + return ConcreteOp::getVerifyRegionInvariantsFn()(op); + } + }; + /// Lookup the registered operation information for the given operation. /// Returns std::nullopt if the operation isn't registered. static Optional lookup(StringRef name, MLIRContext *ctx); /// Register a new operation in a Dialect object. - /// This constructor is used by Dialect objects when they register the list of - /// operations they contain. - template - static void insert(Dialect &dialect) { - insert(T::getOperationName(), dialect, TypeID::get(), - T::getParseAssemblyFn(), T::getPrintAssemblyFn(), - T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(), - T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), - T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(), - T::getPopulateDefaultAttrsFn()); + /// This method is used by Dialect objects when they register the list + /// of operations they contain. + template static void insert(Dialect &dialect) { + insert(std::make_unique>(&dialect), T::getAttributeNames()); } /// The use of this method is in general discouraged in favor of /// 'insert(dialect)'. - static void - insert(StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, - VerifyInvariantsFn &&verifyInvariants, - VerifyRegionInvariantsFn &&verifyRegionInvariants, - FoldHookFn &&foldHook, - GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef attrNames, - PopulateDefaultAttrsFn &&populateDefaultAttrs); + static void insert(std::unique_ptr ownedImpl, + ArrayRef attrNames); /// Return the dialect this operation is registered to.
- Dialect &getDialect() const { return *impl->dialect; } + Dialect &getDialect() const { return *getImpl()->getDialect(); } /// Return the unique identifier of the derived Op class. - TypeID getTypeID() const { return impl->typeID; } + TypeID getTypeID() const { return getImpl()->getTypeID(); } /// Use the specified object to parse this ops custom assembly format. ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; /// Return the static hook for parsing this operation assembly. - const ParseAssemblyFn &getParseAssemblyFn() const { - return impl->parseAssemblyFn; + ParseAssemblyFn getParseAssemblyFn() const { + return getImpl()->getParseAssemblyFn(); } /// This hook implements the AsmPrinter for this operation. void printAssembly(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) const { - return impl->printAssemblyFn(op, p, defaultDialect); + return getImpl()->printAssembly(op, p, defaultDialect); } - /// These hooks implement the verifiers for this operation. It should emits - /// an error message and returns failure if a problem is detected, or returns - /// success if everything is ok. + /// These hooks implement the verifiers for this operation. They should emit + /// an error message and return failure if a problem is detected, or return + /// success if everything is ok. LogicalResult verifyInvariants(Operation *op) const { - return impl->verifyInvariantsFn(op); + return getImpl()->verifyInvariants(op); } LogicalResult verifyRegionInvariants(Operation *op) const { - return impl->verifyRegionInvariantsFn(op); + return getImpl()->verifyRegionInvariants(op); } - /// This hook implements a generalized folder for this operation. Operations + /// This hook implements a generalized folder for this operation. Operations /// can implement this to provide simplifications rules that are applied by /// the Builder::createOrFold API and the canonicalization pass.
/// - /// This is an intentionally limited interface - implementations of this hook - /// can only perform the following changes to the operation: + /// This is an intentionally limited interface - implementations of this + /// hook can only perform the following changes to the operation: /// /// 1. They can leave the operation alone and without changing the IR, and /// return failure. - /// 2. They can mutate the operation in place, without changing anything else + /// 2. They can mutate the operation in place, without changing + /// anything else /// in the IR. In this case, return success. - /// 3. They can return a list of existing values that can be used instead of + /// 3. They can return a list of existing values that can be used + /// instead of /// the operation. In this case, fill in the results list and return /// success. The caller will remove the operation and use those results /// instead. @@ -331,37 +391,35 @@ /// generalized constant folding. LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) const { - return impl->foldHookFn(op, operands, results); + return getImpl()->foldHook(op, operands, results); } - /// This hook returns any canonicalization pattern rewrites that the operation - /// supports, for use by the canonicalization pass. + /// This hook returns any canonicalization pattern rewrites that the + /// operation supports, for use by the canonicalization pass. void getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) const { - return impl->getCanonicalizationPatternsFn(results, context); + return getImpl()->getCanonicalizationPatterns(results, context); } - /// Attach the given models as implementations of the corresponding interfaces - /// for the concrete operation. - template - void attachInterface() { - impl->interfaceMap.insert(); + /// Attach the given models as implementations of the corresponding + /// interfaces for the concrete operation.
+ template void attachInterface() { + getImpl()->getInterfaceMap().insert(); } /// Returns true if the operation has a particular trait. template