diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -336,12 +336,15 @@ /// The definition of a dynamic op. A dynamic op is an op that is defined at /// runtime, and that can be registered at runtime by an extensible dialect (a -/// dialect inheriting ExtensibleDialect). This class stores the functions that -/// are in the OperationName class, and in addition defines the TypeID of the op -/// that will be defined. -/// Each dynamic operation definition refers to one instance of this class. -class DynamicOpDefinition { +/// dialect inheriting ExtensibleDialect). This class implements the methods +/// exposed by the OperationName class, and in addition defines the TypeID of +/// the op that will be defined. Each dynamic operation definition refers to one +/// instance of this class. +class DynamicOpDefinition : public OperationName::Impl { public: + using GetCanonicalizationPatternsFn = + llvm::unique_function; + /// Create a new op at runtime. The op is registered only after passing it to /// the dialect using registerDynamicOp. static std::unique_ptr @@ -361,8 +364,7 @@ OperationName::ParseAssemblyFn &&parseFn, OperationName::PrintAssemblyFn &&printFn, OperationName::FoldHookFn &&foldHookFn, - OperationName::GetCanonicalizationPatternsFn - &&getCanonicalizationPatternsFn, + GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn, OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn); /// Returns the op typeID. @@ -400,9 +402,8 @@ /// Set the hook returning any canonicalization pattern rewrites that the op /// supports, for use by the canonicalization pass. 
- void - setGetCanonicalizationPatternsFn(OperationName::GetCanonicalizationPatternsFn - &&getCanonicalizationPatterns) { + void setGetCanonicalizationPatternsFn( + GetCanonicalizationPatternsFn &&getCanonicalizationPatterns) { getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); } @@ -412,6 +413,29 @@ populateDefaultAttrsFn = std::move(populateDefaultAttrs); } + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return foldHookFn(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + getCanonicalizationPatternsFn(set, context); + } + bool hasTrait(TypeID id) final { return false; } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { return parseFn; } + void populateDefaultAttrs(const OperationName &name, + NamedAttrList &attrs) final { + populateDefaultAttrsFn(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + printFn(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final { return verifyFn(op); } + LogicalResult verifyRegionInvariants(Operation *op) final { + return verifyRegionFn(op); + } + private: DynamicOpDefinition( StringRef name, ExtensibleDialect *dialect, @@ -420,26 +444,18 @@ OperationName::ParseAssemblyFn &&parseFn, OperationName::PrintAssemblyFn &&printFn, OperationName::FoldHookFn &&foldHookFn, - OperationName::GetCanonicalizationPatternsFn - &&getCanonicalizationPatternsFn, + GetCanonicalizationPatternsFn &&getCanonicalizationPatternsFn, OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn); - /// Unique identifier for this operation. - TypeID typeID; - - /// Name of the operation. - /// The name is prefixed with the dialect name. - std::string name; - /// Dialect defining this operation. 
- ExtensibleDialect *dialect; + ExtensibleDialect *getdialect(); OperationName::VerifyInvariantsFn verifyFn; OperationName::VerifyRegionInvariantsFn verifyRegionFn; OperationName::ParseAssemblyFn parseFn; OperationName::PrintAssemblyFn printFn; OperationName::FoldHookFn foldHookFn; - OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn; friend ExtensibleDialect; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -183,8 +183,7 @@ MLIRContext *context) {} /// This hook populates any unset default attrs. - static void populateDefaultAttrs(const RegisteredOperationName &, - NamedAttrList &) {} + static void populateDefaultAttrs(const OperationName &, NamedAttrList &) {} protected: /// If the concrete type didn't implement a custom verifier hook, just fall @@ -1831,20 +1830,11 @@ return result; } - /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook. - static OperationName::GetCanonicalizationPatternsFn - getGetCanonicalizationPatternsFn() { - return &ConcreteType::getCanonicalizationPatterns; - } /// Implementation of `GetHasTraitFn` static OperationName::HasTraitFn getHasTraitFn() { return [](TypeID id) { return op_definition_impl::hasTrait(id); }; } - /// Implementation of `ParseAssemblyFn` OperationName hook. - static OperationName::ParseAssemblyFn getParseAssemblyFn() { - return &ConcreteType::parse; - } /// Implementation of `PrintAssemblyFn` OperationName hook. 
static OperationName::PrintAssemblyFn getPrintAssemblyFn() { if constexpr (detect_has_print::value) diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -505,11 +505,9 @@ /// Sets default attributes on unset attributes. void populateDefaultAttrs() { - if (auto registered = getRegisteredInfo()) { NamedAttrList attrs(getAttrDictionary()); - registered->populateDefaultAttrs(attrs); + name.populateDefaultAttrs(attrs); setAttrs(attrs.getDictionary(getContext())); - } } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -23,6 +23,7 @@ #include "mlir/Support/InterfaceSupport.h" #include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include "llvm/Support/TrailingObjects.h" #include @@ -63,17 +64,15 @@ class OperationName { public: - using GetCanonicalizationPatternsFn = - llvm::unique_function; using FoldHookFn = llvm::unique_function, SmallVectorImpl &) const>; using HasTraitFn = llvm::unique_function; using ParseAssemblyFn = - llvm::unique_function; + llvm::function_ref; // Note: RegisteredOperationName is passed as reference here as the derived // class is defined below. - using PopulateDefaultAttrsFn = llvm::unique_function; + using PopulateDefaultAttrsFn = + llvm::unique_function; using PrintAssemblyFn = llvm::unique_function; using VerifyInvariantsFn = @@ -81,63 +80,132 @@ using VerifyRegionInvariantsFn = llvm::unique_function; -protected: /// This class represents a type erased version of an operation. It contains /// all of the components necessary for opaquely interacting with an /// operation. 
If the operation is not registered, some of these components /// may not be populated. - struct Impl { - Impl(StringAttr name) - : name(name), dialect(nullptr), interfaceMap(std::nullopt) {} + struct InterfaceConcept { + virtual ~InterfaceConcept() = default; + virtual LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) = 0; + virtual void getCanonicalizationPatterns(RewritePatternSet &, + MLIRContext *) = 0; + virtual bool hasTrait(TypeID) = 0; + virtual OperationName::ParseAssemblyFn getParseAssemblyFn() = 0; + virtual void populateDefaultAttrs(const OperationName &, + NamedAttrList &) = 0; + virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0; + virtual LogicalResult verifyInvariants(Operation *) = 0; + virtual LogicalResult verifyRegionInvariants(Operation *) = 0; + }; + +public: + class Impl : public InterfaceConcept { + public: + Impl(StringRef, Dialect *dialect, TypeID typeID, + detail::InterfaceMap interfaceMap); + Impl(StringAttr name, Dialect *dialect, TypeID typeID, + detail::InterfaceMap interfaceMap) + : name(name), typeID(typeID), dialect(dialect), + interfaceMap(std::move(interfaceMap)) {} + + /// Returns true if this is a registered operation. + bool isRegistered() const { return typeID != TypeID::get(); } + detail::InterfaceMap &getInterfaceMap() { return interfaceMap; } + Dialect *getDialect() const { return dialect; } + StringAttr getName() const { return name; } + TypeID getTypeID() const { return typeID; } + ArrayRef getAttributeNames() const { return attributeNames; } + + protected: + //===------------------------------------------------------------------===// + // Registered Operation Info /// The name of the operation. StringAttr name; - //===------------------------------------------------------------------===// - // Registered Operation Info + /// The unique identifier of the derived Op class. + TypeID typeID; /// The following fields are only populated when the operation is /// registered. 
- /// Returns true if the operation has been registered, i.e. if the - /// registration info has been populated. - bool isRegistered() const { return dialect; } - /// This is the dialect that this operation belongs to. Dialect *dialect; - /// The unique identifier of the derived Op class. - TypeID typeID; - /// A map of interfaces that were registered to this operation. detail::InterfaceMap interfaceMap; - /// Internal callback hooks provided by the op implementation. - FoldHookFn foldHookFn; - GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; - HasTraitFn hasTraitFn; - ParseAssemblyFn parseAssemblyFn; - PopulateDefaultAttrsFn populateDefaultAttrsFn; - PrintAssemblyFn printAssemblyFn; - VerifyInvariantsFn verifyInvariantsFn; - VerifyRegionInvariantsFn verifyRegionInvariantsFn; - /// A list of attribute names registered to this operation in StringAttr /// form. This allows for operation classes to use StringAttr for attribute /// lookup/creation/etc., as opposed to raw strings. ArrayRef attributeNames; + + friend class RegisteredOperationName; + }; + +protected: + /// Default implementation for unregistered operations. + struct UnregisteredOpModel : public Impl { + using Impl::Impl; + LogicalResult foldHook(Operation *, ArrayRef, + SmallVectorImpl &) final; + void getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) final; + bool hasTrait(TypeID) final; + virtual OperationName::ParseAssemblyFn getParseAssemblyFn() final; + void populateDefaultAttrs(const OperationName &, NamedAttrList &) final; + void printAssembly(Operation *, OpAsmPrinter &, StringRef) final; + LogicalResult verifyInvariants(Operation *) final; + LogicalResult verifyRegionInvariants(Operation *) final; }; public: OperationName(StringRef name, MLIRContext *context); /// Return if this operation is registered. 
- bool isRegistered() const { return impl->isRegistered(); } + bool isRegistered() const { return getImpl()->isRegistered(); } + + /// Return the unique identifier of the derived Op class, or null if not + /// registered. + TypeID getTypeID() const { return getImpl()->getTypeID(); } /// If this operation is registered, returns the registered information, /// std::nullopt otherwise. Optional getRegisteredInfo() const; + /// This hook implements a generalized folder for this operation. Operations + /// can implement this to provide simplifications rules that are applied by + /// the Builder::createOrFold API and the canonicalization pass. + /// + /// This is an intentionally limited interface - implementations of this + /// hook can only perform the following changes to the operation: + /// + /// 1. They can leave the operation alone and without changing the IR, and + /// return failure. + /// 2. They can mutate the operation in place, without changing anything + /// else + /// in the IR. In this case, return success. + /// 3. They can return a list of existing values that can be used instead + /// of + /// the operation. In this case, fill in the results list and return + /// success. The caller will remove the operation and use those results + /// instead. + /// + /// This allows expression of some simple in-place canonicalizations (e.g. + /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as + /// generalized constant folding. + LogicalResult foldHook(Operation *op, ArrayRef operands, + SmallVectorImpl &results) const { + return getImpl()->foldHook(op, operands, results); + } + + /// This hook returns any canonicalization pattern rewrites that the + /// operation supports, for use by the canonicalization pass. 
+ void getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) const { + return getImpl()->getCanonicalizationPatterns(results, context); + } + /// Returns true if the operation was registered with a particular trait, e.g. /// hasTrait(). Returns false if the operation /// is unregistered. @@ -145,9 +213,7 @@ bool hasTrait() const { return hasTrait(TypeID::get()); } - bool hasTrait(TypeID traitID) const { - return isRegistered() && impl->hasTraitFn(traitID); - } + bool hasTrait(TypeID traitID) const { return getImpl()->hasTrait(traitID); } /// Returns true if the operation *might* have the provided trait. This /// means that either the operation is unregistered, or it was registered with @@ -157,7 +223,54 @@ return mightHaveTrait(TypeID::get()); } bool mightHaveTrait(TypeID traitID) const { - return !isRegistered() || impl->hasTraitFn(traitID); + return !isRegistered() || getImpl()->hasTrait(traitID); + } + + /// Return the static hook for parsing this operation assembly. + ParseAssemblyFn getParseAssemblyFn() const { + return getImpl()->getParseAssemblyFn(); + } + + /// This hook implements the method to populate default attributes that are + /// unset. + void populateDefaultAttrs(NamedAttrList &attrs) const { + getImpl()->populateDefaultAttrs(*this, attrs); + } + + /// This hook implements the AsmPrinter for this operation. + void printAssembly(Operation *op, OpAsmPrinter &p, + StringRef defaultDialect) const { + return getImpl()->printAssembly(op, p, defaultDialect); + } + + /// These hooks implement the verifiers for this operation. They should emit + /// an error message and return failure if a problem is detected, or + /// return success if everything is ok. 
+ LogicalResult verifyInvariants(Operation *op) const { + return getImpl()->verifyInvariants(op); + } + LogicalResult verifyRegionInvariants(Operation *op) const { + return getImpl()->verifyRegionInvariants(op); + } + + /// Return the list of cached attribute names registered to this operation. + /// The order of attributes cached here is unique to each type of operation, + /// and the interpretation of this attribute list should generally be driven + /// by the respective operation. In many cases, this caching removes the + /// need to use the raw string name of a known attribute. + /// + /// For example the ODS generator, with an op defining the following + /// attributes: + /// + /// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2); + /// + /// ... may produce an order here of ["attr1", "attr2"]. This allows for the + /// ODS generator to directly access the cached name for a known attribute, + /// greatly simplifying the cost and complexity of attribute usage produced + /// by the generator. + /// + ArrayRef getAttributeNames() const { + return getImpl()->getAttributeNames(); } /// Returns an instance of the concept object for the given interface if it @@ -165,7 +278,13 @@ /// directly. template typename T::Concept *getInterface() const { - return impl->interfaceMap.lookup(); + return getImpl()->getInterfaceMap().lookup(); + } + + /// Attach the given models as implementations of the corresponding + /// interfaces for the concrete operation. + template void attachInterface() { + getImpl()->getInterfaceMap().insert(); } /// Returns true if this operation has the given interface registered to it. @@ -174,7 +293,7 @@ return hasInterface(TypeID::get()); } bool hasInterface(TypeID interfaceID) const { - return impl->interfaceMap.contains(interfaceID); + return getImpl()->getInterfaceMap().contains(interfaceID); } /// Returns true if the operation *might* have the provided interface. 
This @@ -191,7 +310,8 @@ /// Return the dialect this operation is registered to if the dialect is /// loaded in the context, or nullptr if the dialect isn't loaded. Dialect *getDialect() const { - return isRegistered() ? impl->dialect : impl->name.getReferencedDialect(); + return isRegistered() ? getImpl()->getDialect() + : getImpl()->getName().getReferencedDialect(); } /// Return the name of the dialect this operation is registered to. @@ -204,7 +324,7 @@ StringRef getStringRef() const { return getIdentifier(); } /// Return the name of this operation as a StringAttr. - StringAttr getIdentifier() const { return impl->name; } + StringAttr getIdentifier() const { return getImpl()->getName(); } void print(raw_ostream &os) const; void dump() const; @@ -222,12 +342,17 @@ protected: OperationName(Impl *impl) : impl(impl) {} + Impl *getImpl() const { return impl; } + void setImpl(Impl *rhs) { impl = rhs; } +private: /// The internal implementation of the operation name. - Impl *impl; + Impl *impl = nullptr; /// Allow access to the Impl struct. friend MLIRContextImpl; + friend DenseMapInfo; + friend DenseMapInfo; }; inline raw_ostream &operator<<(raw_ostream &os, OperationName info) { @@ -250,137 +375,62 @@ /// the concrete operation types. class RegisteredOperationName : public OperationName { public: + /// Implementation of the InterfaceConcept for operation APIs that are + /// forwarded to a concrete op implementation. 
+ template struct Model : public Impl { + Model(Dialect *dialect) + : Impl(ConcreteOp::getOperationName(), dialect, + TypeID::get(), ConcreteOp::getInterfaceMap()) {} + LogicalResult foldHook(Operation *op, ArrayRef attrs, + SmallVectorImpl &results) final { + return ConcreteOp::getFoldHookFn()(op, attrs, results); + } + void getCanonicalizationPatterns(RewritePatternSet &set, + MLIRContext *context) final { + ConcreteOp::getCanonicalizationPatterns(set, context); + } + bool hasTrait(TypeID id) final { return ConcreteOp::getHasTraitFn()(id); } + OperationName::ParseAssemblyFn getParseAssemblyFn() final { + return ConcreteOp::parse; + } + void populateDefaultAttrs(const OperationName &name, + NamedAttrList &attrs) final { + ConcreteOp::populateDefaultAttrs(name, attrs); + } + void printAssembly(Operation *op, OpAsmPrinter &printer, + StringRef name) final { + ConcreteOp::getPrintAssemblyFn()(op, printer, name); + } + LogicalResult verifyInvariants(Operation *op) final { + return ConcreteOp::getVerifyInvariantsFn()(op); + } + LogicalResult verifyRegionInvariants(Operation *op) final { + return ConcreteOp::getVerifyRegionInvariantsFn()(op); + } + }; + /// Lookup the registered operation information for the given operation. /// Returns std::nullopt if the operation isn't registered. static Optional lookup(StringRef name, MLIRContext *ctx); /// Register a new operation in a Dialect object. - /// This constructor is used by Dialect objects when they register the list of - /// operations they contain. 
- template - static void insert(Dialect &dialect) { - insert(T::getOperationName(), dialect, TypeID::get(), - T::getParseAssemblyFn(), T::getPrintAssemblyFn(), - T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(), - T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), - T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(), - T::getPopulateDefaultAttrsFn()); + /// This constructor is used by Dialect objects when they register the list + /// of operations they contain. + template static void insert(Dialect &dialect) { + insert(std::make_unique>(&dialect), T::getAttributeNames()); } /// The use of this method is in general discouraged in favor of /// 'insert(dialect)'. - static void - insert(StringRef name, Dialect &dialect, TypeID typeID, - ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, - VerifyInvariantsFn &&verifyInvariants, - VerifyRegionInvariantsFn &&verifyRegionInvariants, - FoldHookFn &&foldHook, - GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef attrNames, - PopulateDefaultAttrsFn &&populateDefaultAttrs); + static void insert(std::unique_ptr ownedImpl, + ArrayRef attrNames); /// Return the dialect this operation is registered to. - Dialect &getDialect() const { return *impl->dialect; } - - /// Return the unique identifier of the derived Op class. - TypeID getTypeID() const { return impl->typeID; } + Dialect &getDialect() const { return *getImpl()->getDialect(); } /// Use the specified object to parse this ops custom assembly format. ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const; - /// Return the static hook for parsing this operation assembly. - const ParseAssemblyFn &getParseAssemblyFn() const { - return impl->parseAssemblyFn; - } - - /// This hook implements the AsmPrinter for this operation. 
- void printAssembly(Operation *op, OpAsmPrinter &p, - StringRef defaultDialect) const { - return impl->printAssemblyFn(op, p, defaultDialect); - } - - /// These hooks implement the verifiers for this operation. It should emits - /// an error message and returns failure if a problem is detected, or returns - /// success if everything is ok. - LogicalResult verifyInvariants(Operation *op) const { - return impl->verifyInvariantsFn(op); - } - LogicalResult verifyRegionInvariants(Operation *op) const { - return impl->verifyRegionInvariantsFn(op); - } - - /// This hook implements a generalized folder for this operation. Operations - /// can implement this to provide simplifications rules that are applied by - /// the Builder::createOrFold API and the canonicalization pass. - /// - /// This is an intentionally limited interface - implementations of this hook - /// can only perform the following changes to the operation: - /// - /// 1. They can leave the operation alone and without changing the IR, and - /// return failure. - /// 2. They can mutate the operation in place, without changing anything else - /// in the IR. In this case, return success. - /// 3. They can return a list of existing values that can be used instead of - /// the operation. In this case, fill in the results list and return - /// success. The caller will remove the operation and use those results - /// instead. - /// - /// This allows expression of some simple in-place canonicalizations (e.g. - /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as - /// generalized constant folding. - LogicalResult foldHook(Operation *op, ArrayRef operands, - SmallVectorImpl &results) const { - return impl->foldHookFn(op, operands, results); - } - - /// This hook returns any canonicalization pattern rewrites that the operation - /// supports, for use by the canonicalization pass. 
- void getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) const { - return impl->getCanonicalizationPatternsFn(results, context); - } - - /// Attach the given models as implementations of the corresponding interfaces - /// for the concrete operation. - template - void attachInterface() { - impl->interfaceMap.insert(); - } - - /// Returns true if the operation has a particular trait. - template