diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -150,12 +150,9 @@
   /// This method is used by derived classes to add their operations to the set.
   ///
   template <typename... Args> void addOperations() {
-    (void)std::initializer_list<int>{0, (addOperation<Args>(), 0)...};
-  }
-  template <typename Arg> void addOperation() {
-    addOperation(AbstractOperation::get<Arg>(*this));
+    (void)std::initializer_list<int>{
+        0, (AbstractOperation::insert<Args>(*this), 0)...};
   }
-  void addOperation(AbstractOperation opInfo);
 
   /// Register a set of type classes with this dialect.
   template <typename... Args> void addTypes() {
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -105,6 +105,12 @@
   /// Return the operation that this refers to.
   Operation *getOperation() { return state; }
 
+  /// Return the dialect of the operation that this refers to.
+  Dialect *getDialect() { return getOperation()->getDialect(); }
+
+  /// Return the parent Region of the operation that this refers to.
+  Region *getParentRegion() { return getOperation()->getParentRegion(); }
+
   /// Returns the closest surrounding operation that contains this operation
   /// or nullptr if this is a top-level operation.
   Operation *getParentOp() { return getOperation()->getParentOp(); }
@@ -238,7 +244,7 @@
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 
   // The fallback for the printer is to print it the generic assembly form.
-  void print(OpAsmPrinter &p);
+  static void print(Operation *op, OpAsmPrinter &p);
 
   /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
   /// so we can cast it away here.
@@ -246,6 +252,9 @@
 
 private:
   Operation *state;
+
+  /// Allow access to internal hook implementation methods.
+  friend AbstractOperation;
 };
 
 // Allow comparing operators.
@@ -267,117 +276,6 @@
   return os;
 }
 
-/// This template defines the foldHook as used by AbstractOperation.
-///
-/// The default implementation uses a general fold method that can be defined on
-/// custom ops which can return multiple results.
-template <typename ConcreteType, bool isSingleResult, typename = void>
-class FoldingHook {
-public:
-  /// This is an implementation detail of the constant folder hook for
-  /// AbstractOperation.
-  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
-                                SmallVectorImpl<OpFoldResult> &results) {
-    auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results);
-    // Failure to fold or in place fold both mean we can continue folding.
-    if (failed(operationFoldResult) || results.empty()) {
-      auto traitFoldResult = ConcreteType::foldTraits(op, operands, results);
-      // Only return the trait fold result if it is a success since
-      // operationFoldResult might have been a success originally.
-      if (succeeded(traitFoldResult))
-        return traitFoldResult;
-    }
-    return operationFoldResult;
-  }
-
-  /// This hook implements a generalized folder for this operation. Operations
-  /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::createOrFold API and the canonicalization pass.
-  ///
-  /// This is an intentionally limited interface - implementations of this hook
-  /// can only perform the following changes to the operation:
-  ///
-  /// 1. They can leave the operation alone and without changing the IR, and
-  /// return failure.
-  /// 2. They can mutate the operation in place, without changing anything else
-  /// in the IR. In this case, return success.
-  /// 3. They can return a list of existing values that can be used instead of
-  /// the operation. In this case, fill in the results list and return
-  /// success. The caller will remove the operation and use those results
-  /// instead.
-  ///
-  /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
-  /// generalized constant folding.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  LogicalResult fold(ArrayRef<Attribute> operands,
-                     SmallVectorImpl<OpFoldResult> &results) {
-    return failure();
-  }
-};
-
-/// This template specialization defines the foldHook as used by
-/// AbstractOperation for single-result operations. This gives the hook a nicer
-/// signature that is easier to implement.
-template <typename ConcreteType, bool isSingleResult>
-class FoldingHook<ConcreteType, isSingleResult,
-                  typename std::enable_if<isSingleResult>::type> {
-public:
-  /// If the operation returns a single value, then the Op can be implicitly
-  /// converted to an Value. This yields the value of the only result.
-  operator Value() {
-    return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
-  }
-
-  /// This is an implementation detail of the constant folder hook for
-  /// AbstractOperation.
-  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
-                                SmallVectorImpl<OpFoldResult> &results) {
-    auto result = cast<ConcreteType>(op).fold(operands);
-    // Failure to fold or in place fold both mean we can continue folding.
-    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
-      // Only consider the trait fold result if it is a success since
-      // the operation fold might have been a success originally.
-      if (auto traitFoldResult = ConcreteType::foldTraits(op, operands))
-        result = traitFoldResult;
-    }
-
-    if (!result)
-      return failure();
-
-    // Check if the operation was folded in place. In this case, the operation
-    // returns itself.
-    if (result.template dyn_cast<Value>() != op->getResult(0))
-      results.push_back(result);
-    return success();
-  }
-
-  /// This hook implements a generalized folder for this operation. Operations
-  /// can implement this to provide simplifications rules that are applied by
-  /// the Builder::createOrFold API and the canonicalization pass.
-  ///
-  /// This is an intentionally limited interface - implementations of this hook
-  /// can only perform the following changes to the operation:
-  ///
-  /// 1. They can leave the operation alone and without changing the IR, and
-  /// return nullptr.
-  /// 2. They can mutate the operation in place, without changing anything else
-  /// in the IR. In this case, return the operation itself.
-  /// 3. They can return an existing SSA value that can be used instead of
-  /// the operation. In this case, return that value. The caller will
-  /// remove the operation and use that result instead.
-  ///
-  /// This allows expression of some simple in-place canonicalizations (e.g.
-  /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
-  /// generalized constant folding.
-  ///
-  /// If not overridden, this fallback implementation always fails to fold.
-  ///
-  OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
-};
-
 //===----------------------------------------------------------------------===//
 // Operation Trait Types
 //===----------------------------------------------------------------------===//
@@ -441,30 +339,6 @@
     auto *base = static_cast<OpState *>(concrete);
     return base->getOperation();
   }
-
-  /// Provide default implementations of trait hooks. This allows traits to
-  /// provide exactly the overrides they care about.
-  static LogicalResult verifyTrait(Operation *op) { return success(); }
-  static AbstractOperation::OperationProperties getTraitProperties() {
-    return 0;
-  }
-
-  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
-    SmallVector<OpFoldResult, 1> results;
-    if (failed(foldTrait(op, operands, results)))
-      return {};
-    if (results.empty())
-      return op->getResult(0);
-    assert(results.size() == 1 &&
-           "Single result op cannot return multiple fold results");
-
-    return results[0];
-  }
-
-  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
-                                 SmallVectorImpl<OpFoldResult> &results) {
-    return failure();
-  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -738,6 +612,10 @@
   Value getResult() { return this->getOperation()->getResult(0); }
   Type getType() { return getResult().getType(); }
 
+  /// If the operation returns a single value, then the Op can be implicitly
+  /// converted to a Value. This yields the value of the only result.
+  operator Value() { return getResult(); }
+
   /// Replace all uses of 'this' value with the new value, updating anything in
   /// the IR that uses 'this' to use the other value instead. When this returns
   /// there are zero uses of 'this'.
@@ -1306,6 +1184,170 @@
 
 } // end namespace OpTrait
 
+//===----------------------------------------------------------------------===//
+// Internal Trait Utilities
+//===----------------------------------------------------------------------===//
+
+namespace op_definition_impl {
+//===----------------------------------------------------------------------===//
+// Trait Existence
+
+/// Returns true if the given trait ID matches the ID of any of the provided
+/// trait types `Traits`; returns false when `Traits` is empty.
+template <template <typename T> class... Traits>
+static bool hasTrait(TypeID traitID) {
+  // The leading `0` keeps this well-formed for ops with no traits at all.
+  bool found = false;
+  (void)std::initializer_list<int>{
+      0, (found = found || TypeID::get<Traits>() == traitID, 0)...};
+  return found;
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Folding
+
+/// Trait to check if T provides a 'foldTrait' method for single result
+/// operations.
+template <typename T, typename... Args>
+using has_single_result_fold_trait = decltype(T::foldTrait(
+    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
+template <typename T>
+using detect_has_single_result_fold_trait =
+    llvm::is_detected<has_single_result_fold_trait, T>;
+/// Trait to check if T provides a general 'foldTrait' method.
+template <typename T, typename... Args>
+using has_fold_trait =
+    decltype(T::foldTrait(std::declval<Operation *>(),
+                          std::declval<ArrayRef<Attribute>>(),
+                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
+template <typename T>
+using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
+/// Trait to check if T provides any `foldTrait` method.
+/// NOTE: This should use std::disjunction when C++17 is available.
+template <typename T>
+using detect_has_any_fold_trait =
+    std::conditional_t<bool(detect_has_fold_trait<T>::value),
+                       detect_has_fold_trait<T>,
+                       detect_has_single_result_fold_trait<T>>;
+
+/// Returns the result of folding a trait that implements a `foldTrait` function
+/// that is specialized for operations that have a single result.
+template <typename Trait>
+static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
+                        LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+          SmallVectorImpl<OpFoldResult> &results) {
+  assert(op->hasTrait<OpTrait::OneResult>() &&
+         "expected trait on non single-result operation to implement the "
+         "general `foldTrait` method");
+  // If a previous trait has already been folded and replaced this operation, we
+  // fail to fold this trait.
+  if (!results.empty())
+    return failure();
+
+  if (OpFoldResult result = Trait::foldTrait(op, operands)) {
+    if (result.template dyn_cast<Value>() != op->getResult(0))
+      results.push_back(result);
+    return success();
+  }
+  return failure();
+}
+/// Returns the result of folding a trait that implements a generalized
+/// `foldTrait` function that supports any operation type.
+template <typename Trait>
+static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
+foldTrait(Operation *op, ArrayRef<Attribute> operands,
+          SmallVectorImpl<OpFoldResult> &results) {
+  // If a previous trait has already been folded and replaced this operation, we
+  // fail to fold this trait.
+  return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
+}
+
+/// The internal implementation of `foldTraits` below that returns the result of
+/// folding a set of trait types `Ts` that implement a `foldTrait` method.
+template <typename... Ts>
+static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results,
+                                    std::tuple<Ts...> *) {
+  bool anyFolded = false;
+  (void)std::initializer_list<int>{
+      (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
+  return success(anyFolded);
+}
+
+/// Given a tuple type containing a set of traits that contain a `foldTrait`
+/// method, return the result of folding the given operation.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+           SmallVectorImpl<OpFoldResult> &results) {
+  return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
+}
+/// A variant of the method above, selected when there are no traits that
+/// contain a `foldTrait` method; with nothing to fold, it always fails.
+template <typename TraitTupleT>
+static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
+foldTraits(Operation *op, ArrayRef<Attribute> operands,
+           SmallVectorImpl<OpFoldResult> &results) {
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Properties
+
+/// Trait to check if T provides a `getTraitProperties` method.
+template <typename T, typename... Args>
+using has_get_trait_properties = decltype(T::getTraitProperties());
+template <typename T>
+using detect_has_get_trait_properties =
+    llvm::is_detected<has_get_trait_properties, T>;
+
+/// The internal implementation of `getTraitProperties` below that returns the
+/// OR of invoking `getTraitProperties` on all of the provided trait types `Ts`.
+template <typename... Ts>
+static AbstractOperation::OperationProperties
+getTraitPropertiesImpl(std::tuple<Ts...> *) {
+  AbstractOperation::OperationProperties result = 0;
+  (void)std::initializer_list<int>{(result |= Ts::getTraitProperties(), 0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `getTraitProperties` method, return the OR of all of the results of invoking
+/// those methods.
+template <typename TraitTupleT>
+static AbstractOperation::OperationProperties getTraitProperties() {
+  return getTraitPropertiesImpl((TraitTupleT *)nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Trait Verification
+
+/// Trait to check if T provides a `verifyTrait` method.
+template <typename T, typename... Args>
+using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
+template <typename T>
+using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
+
+/// The internal implementation of `verifyTraits` below. Verifies the operation
+/// with each trait type in `Ts` in order, and stops invoking `verifyTrait`
+/// once one of them has failed.
+template <typename... Ts>
+static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
+  LogicalResult result = success();
+  (void)std::initializer_list<int>{
+      (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `verifyTrait` method, return the result of verifying the given operation.
+template <typename TraitTupleT>
+static LogicalResult verifyTraits(Operation *op) {
+  return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
+}
+} // namespace op_definition_impl
+
 //===----------------------------------------------------------------------===//
 // Operation Definition classes
 //===----------------------------------------------------------------------===//
@@ -1314,21 +1356,17 @@
 /// argument 'ConcreteType' should be the concrete type by CRTP and the others
 /// are base classes by the policy pattern.
 template <typename ConcreteType, template <typename T> class... Traits>
-class Op : public OpState,
-           public Traits<ConcreteType>...,
-           public FoldingHook<ConcreteType,
-                              llvm::is_one_of<OpTrait::OneResult<ConcreteType>,
-                                              Traits<ConcreteType>...>::value> {
+class Op : public OpState, public Traits<ConcreteType>... {
 public:
+  /// Inherit getOperation from `OpState`.
+  using OpState::getOperation;
+
   /// Return if this operation contains the provided trait.
   template <template <typename T> class Trait>
   static constexpr bool hasTrait() {
     return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
   }
 
-  /// Return the operation that this refers to.
-  Operation *getOperation() { return OpState::getOperation(); }
-
   /// Create a deep copy of this operation.
   ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
 
@@ -1339,12 +1377,6 @@
     return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
   }
 
-  /// Return the dialect that this refers to.
-  Dialect *getDialect() { return getOperation()->getDialect(); }
-
-  /// Return the parent Region of this operation.
-  Region *getParentRegion() { return getOperation()->getParentRegion(); }
-
   /// Return true if this "op class" can match against the specified operation.
   static bool classof(Operation *op) {
     if (auto *abstractOp = op->getAbstractOperation())
@@ -1358,56 +1390,6 @@
     return false;
   }
 
-  /// This is the hook used by the AsmParser to parse the custom form of this
-  /// op from an .mlir file. Op implementations should provide a parse method,
-  /// which returns failure. On success, they should return fill in result with
-  /// the fields to use.
-  static ParseResult parseAssembly(OpAsmParser &parser,
-                                   OperationState &result) {
-    return ConcreteType::parse(parser, result);
-  }
-
-  /// This is the hook used by the AsmPrinter to emit this to the .mlir file.
-  /// Op implementations should provide a print method.
-  static void printAssembly(Operation *op, OpAsmPrinter &p) {
-    auto opPointer = dyn_cast<ConcreteType>(op);
-    assert(opPointer &&
-           "op's name does not match name of concrete type instantiated with");
-    opPointer.print(p);
-  }
-
-  /// This is the hook that checks whether or not this operation is well
-  /// formed according to the invariants of its opcode. It delegates to the
-  /// Traits for their policy implementations, and allows the user to specify
-  /// their own verify() method.
-  ///
-  /// On success this returns false; on failure it emits an error to the
-  /// diagnostic subsystem and returns true.
-  static LogicalResult verifyInvariants(Operation *op) {
-    return failure(
-        failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
-        failed(cast<ConcreteType>(op).verify()));
-  }
-
-  /// This is the hook that tries to fold the given operation according to its
-  /// traits. It delegates to the Traits for their policy implementations, and
-  /// allows the user to specify their own fold() method.
-  static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) {
-    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands);
-  }
-
-  static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                  SmallVectorImpl<OpFoldResult> &results) {
-    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands,
-                                                           results);
-  }
-
-  // Returns the properties of an operation by combining the properties of the
-  // traits of the op.
-  static AbstractOperation::OperationProperties getOperationProperties() {
-    return BaseProperties<Traits<ConcreteType>...>::getTraitProperties();
-  }
-
   /// Expose the type we are instantiated on to template machinery that may want
   /// to introspect traits on this operation.
   using ConcreteOpType = ConcreteType;
@@ -1430,95 +1412,166 @@
   }
 
 private:
-  template <typename... Types> struct BaseVerifier;
-
-  template <typename First, typename... Rest>
-  struct BaseVerifier<First, Rest...> {
-    static LogicalResult verifyTrait(Operation *op) {
-      return failure(failed(First::verifyTrait(op)) ||
-                     failed(BaseVerifier<Rest...>::verifyTrait(op)));
-    }
-  };
-
-  template <typename...> struct BaseVerifier {
-    static LogicalResult verifyTrait(Operation *op) { return success(); }
-  };
-
-  template <typename... Types> struct BaseProperties;
-
-  template <typename First, typename... Rest>
-  struct BaseProperties<First, Rest...> {
-    static AbstractOperation::OperationProperties getTraitProperties() {
-      return First::getTraitProperties() |
-             BaseProperties<Rest...>::getTraitProperties();
-    }
-  };
-
-  template <typename... Types>
-  struct BaseFolder;
-
-  template <typename First, typename... Rest>
-  struct BaseFolder<First, Rest...> {
-    static OpFoldResult foldTraits(Operation *op,
-                                   ArrayRef<Attribute> operands) {
-      auto result = First::foldTrait(op, operands);
-      // Failure to fold or in place fold both mean we can continue folding.
-      if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
-        // Only consider the trait fold result if it is a success since
-        // the operation fold might have been a success originally.
-        auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands);
-        if (resultRemaining)
-          result = resultRemaining;
-      }
-
-      return result;
-    }
-
-    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<OpFoldResult> &results) {
-      auto result = First::foldTrait(op, operands, results);
-      // Failure to fold or in place fold both mean we can continue folding.
-      if (failed(result) || results.empty()) {
-        auto resultRemaining =
-            BaseFolder<Rest...>::foldTraits(op, operands, results);
-        if (succeeded(resultRemaining))
-          result = resultRemaining;
-      }
+  /// Trait to check if T provides a 'fold' method for a single result op.
+  template <typename T, typename... Args>
+  using has_single_result_fold =
+      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
+  template <typename T>
+  using detect_has_single_result_fold =
+      llvm::is_detected<has_single_result_fold, T>;
+  /// Trait to check if T provides a general 'fold' method.
+  template <typename T, typename... Args>
+  using has_fold = decltype(
+      std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
+                             std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  template <typename T> using detect_has_fold = llvm::is_detected<has_fold, T>;
+  /// Trait to check if T provides a 'print' method.
+  template <typename T, typename... Args>
+  using has_print =
+      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
+  template <typename T>
+  using detect_has_print = llvm::is_detected<has_print, T>;
+  /// A tuple type containing the traits that have a `foldTrait` function.
+  using FoldableTraitsTupleT = typename detail::FilterTypes<
+      op_definition_impl::detect_has_any_fold_trait,
+      Traits<ConcreteType>...>::type;
+  /// A tuple type containing the traits that have a `verifyTrait` function.
+  using VerifiableTraitsTupleT =
+      typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
+                                   Traits<ConcreteType>...>::type;
+
+  /// Returns the properties of this operation by combining the properties
+  /// defined by the traits.
+  static AbstractOperation::OperationProperties getOperationProperties() {
+    return op_definition_impl::getTraitProperties<typename detail::FilterTypes<
+        op_definition_impl::detect_has_get_trait_properties,
+        Traits<ConcreteType>...>::type>();
+  }
 
-      return result;
-    }
-  };
+  /// Returns an interface map containing the interfaces registered to this
+  /// operation.
+  static detail::InterfaceMap getInterfaceMap() {
+    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+  }
 
-  template <typename...>
-  struct BaseFolder {
-    static OpFoldResult foldTraits(Operation *op,
-                                   ArrayRef<Attribute> operands) {
-      return {};
-    }
-    static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<OpFoldResult> &results) {
-      return failure();
+  /// The methods below return the internal implementations of each of the
+  /// AbstractOperation hooks. This one returns the `FoldHookFn` hook, chosen
+  /// by `getFoldHookFnImpl` based on which `fold` method the op defines.
+  static AbstractOperation::FoldHookFn getFoldHookFn() {
+    return getFoldHookFnImpl<ConcreteType>();
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the op has the `OneResult` trait and defines a single-result `fold`.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+                                          Traits<ConcreteOpT>...>::value &&
+                              detect_has_single_result_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    return &foldSingleResultHook<ConcreteOpT>;
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the op lacks the `OneResult` trait and defines a general `fold` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
+                                           Traits<ConcreteOpT>...>::value &&
+                              detect_has_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    return &foldHook<ConcreteOpT>;
+  }
+  /// The internal implementation of `getFoldHookFn` above that is invoked if
+  /// the operation defines neither kind of `fold` method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
+                              !detect_has_fold<ConcreteOpT>::value,
+                          AbstractOperation::FoldHookFn>
+  getFoldHookFnImpl() {
+    // In this case, we only need to fold the traits of the operation.
+    return &op_definition_impl::foldTraits<FoldableTraitsTupleT>;
+  }
+  /// Return the result of folding a single result operation that defines a
+  /// single-result `fold` method.
+  template <typename ConcreteOpT>
+  static LogicalResult
+  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
+                       SmallVectorImpl<OpFoldResult> &results) {
+    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
+
+    // If the fold failed or was in-place, try to fold the traits of the
+    // operation.
+    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
+      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+              op, operands, results)))
+        return success();
+      return success(static_cast<bool>(result));
     }
-  };
+    results.push_back(result);
+    return success();
+  }
+  /// Return the result of folding an operation that defines a general `fold`.
+  template <typename ConcreteOpT>
+  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                                SmallVectorImpl<OpFoldResult> &results) {
+    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
 
-  template <typename...> struct BaseProperties {
-    static AbstractOperation::OperationProperties getTraitProperties() {
-      return 0;
+    // If the fold failed or was in-place, try to fold the traits of the
+    // operation.
+    if (failed(result) || results.empty()) {
+      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
+              op, operands, results)))
+        return success();
     }
-  };
-
-  /// Returns true if this operation contains the trait for the given typeID.
-  static bool hasTrait(TypeID traitID) {
-    return llvm::is_contained(llvm::makeArrayRef({TypeID::get<Traits>()...}),
-                              traitID);
+    return result;
+  }
+
+  /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook.
+  static AbstractOperation::GetCanonicalizationPatternsFn
+  getGetCanonicalizationPatternsFn() {
+    return &ConcreteType::getCanonicalizationPatterns;
+  }
+  /// Implementation of `HasTraitFn` AbstractOperation hook.
+  static AbstractOperation::HasTraitFn getHasTraitFn() {
+    return &op_definition_impl::hasTrait<Traits...>;
+  }
+  /// Implementation of `ParseAssemblyFn` AbstractOperation hook.
+  static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() {
+    return &ConcreteType::parse;
+  }
+  /// Implementation of `PrintAssemblyFn` AbstractOperation hook.
+  static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() {
+    return getPrintAssemblyFnImpl<ConcreteType>();
+  }
+  /// Internal implementation of `getPrintAssemblyFn` used when the concrete
+  /// op defines no `print` method; it falls back to `OpState::print`.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
+                          AbstractOperation::PrintAssemblyFn>
+  getPrintAssemblyFnImpl() {
+    return &OpState::print;
+  }
+  /// Internal implementation of `getPrintAssemblyFn` used when the concrete
+  /// op defines a `print` method; `printAssembly` forwards to that method.
+  template <typename ConcreteOpT>
+  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
+                          AbstractOperation::PrintAssemblyFn>
+  getPrintAssemblyFnImpl() {
+    return &printAssembly;
   }
-
-  /// Returns an interface map for the interfaces registered to this operation.
-  static detail::InterfaceMap getInterfaceMap() {
-    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
+  static void printAssembly(Operation *op, OpAsmPrinter &p) {
+    return cast<ConcreteType>(op).print(p);
+  }
+  /// Implementation of `VerifyInvariantsFn` AbstractOperation hook.
+  static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() {
+    return &verifyInvariants;
+  }
+  static LogicalResult verifyInvariants(Operation *op) {
+    return failure(
+        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
+        failed(cast<ConcreteType>(op).verify()));
   }
 
-  /// Allow access to 'hasTrait' and 'getInterfaceMap'.
+  /// Allow access to internal implementation methods.
   friend AbstractOperation;
 };
 
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -80,6 +80,15 @@
 public:
   using OperationProperties = uint32_t;
 
+  using GetCanonicalizationPatternsFn = void (*)(OwningRewritePatternList &,
+                                                 MLIRContext *);
+  using FoldHookFn = LogicalResult (*)(Operation *, ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &);
+  using HasTraitFn = bool (*)(TypeID);
+  using ParseAssemblyFn = ParseResult (*)(OpAsmParser &, OperationState &);
+  using PrintAssemblyFn = void (*)(Operation *, OpAsmPrinter &);
+  using VerifyInvariantsFn = LogicalResult (*)(Operation *);
+
   /// This is the name of the operation.
   const Identifier name;
 
@@ -90,15 +99,19 @@
   TypeID typeID;
 
   /// Use the specified object to parse this ops custom assembly format.
-  ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
+  ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
   /// This hook implements the AsmPrinter for this operation.
-  void (&printAssembly)(Operation *op, OpAsmPrinter &p);
+  void printAssembly(Operation *op, OpAsmPrinter &p) const {
+    return printAssemblyFn(op, p);
+  }
 
   /// This hook implements the verifier for this operation. It should emits an
   /// error message and returns failure if a problem is detected, or returns
   /// success if everything is ok.
-  LogicalResult (&verifyInvariants)(Operation *op);
+  LogicalResult verifyInvariants(Operation *op) const {
+    return verifyInvariantsFn(op);
+  }
 
   /// This hook implements a generalized folder for this operation. Operations
   /// can implement this to provide simplifications rules that are applied by
@@ -119,13 +132,17 @@
   /// This allows expression of some simple in-place canonicalizations (e.g.
   /// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
   /// generalized constant folding.
-  LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results);
+  LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
+                         SmallVectorImpl<OpFoldResult> &results) const {
+    return foldHookFn(op, operands, results);
+  }
 
   /// This hook returns any canonicalization pattern rewrites that the operation
   /// supports, for use by the canonicalization pass.
-  void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
-                                      MLIRContext *context);
+  void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                   MLIRContext *context) const {
+    return getCanonicalizationPatternsFn(results, context);
+  }
 
   /// Returns whether the operation has a particular property.
   bool hasProperty(OperationProperty property) const {
@@ -141,7 +158,7 @@
 
   /// Returns true if the operation has a particular trait.
   template <template <typename T> class Trait> bool hasTrait() const {
-    return hasRawTrait(TypeID::get<Trait>());
+    return hasTraitFn(TypeID::get<Trait>());
   }
 
   /// Look up the specified operation in the specified MLIRContext and return a
@@ -151,26 +168,30 @@
 
-  /// This constructor is used by Dialect objects when they register the list of
-  /// operations they contain.
+  /// Insert an AbstractOperation for op type `T`, using the hooks that `T`
+  /// provides. Dialect objects use this when registering their operations.
- template <typename T> static AbstractOperation get(Dialect &dialect) { - return AbstractOperation( - T::getOperationName(), dialect, T::getOperationProperties(), - TypeID::get<T>(), T::parseAssembly, T::printAssembly, - T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns, - T::getInterfaceMap(), T::hasTrait); + template <typename T> static void insert(Dialect &dialect) { + insert(T::getOperationName(), dialect, T::getOperationProperties(), + TypeID::get<T>(), T::getParseAssemblyFn(), T::getPrintAssemblyFn(), + T::getVerifyInvariantsFn(), T::getFoldHookFn(), + T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(), + T::getHasTraitFn()); } private: - AbstractOperation( - StringRef name, Dialect &dialect, OperationProperties opProperties, - TypeID typeID, - ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result), - void (&printAssembly)(Operation *op, OpAsmPrinter &p), - LogicalResult (&verifyInvariants)(Operation *op), - LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands, - SmallVectorImpl<OpFoldResult> &results), - void (&getCanonicalizationPatterns)(OwningRewritePatternList &results, - MLIRContext *context), - detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID)); + static void insert(StringRef name, Dialect &dialect, + OperationProperties opProperties, TypeID typeID, + ParseAssemblyFn parseAssembly, + PrintAssemblyFn printAssembly, + VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, + GetCanonicalizationPatternsFn getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait); + + AbstractOperation(StringRef name, Dialect &dialect, + OperationProperties opProperties, TypeID typeID, + ParseAssemblyFn parseAssembly, + PrintAssemblyFn printAssembly, + VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, + GetCanonicalizationPatternsFn getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait); /// The properties of the 
operation. const OperationProperties opProperties; @@ -178,9 +199,13 @@ /// A map of interfaces that were registered to this operation. detail::InterfaceMap interfaceMap; - /// This hook returns if the operation contains the trait corresponding - /// to the given TypeID. - bool (&hasRawTrait)(TypeID traitID); + /// Internal callback hooks provided by the op implementation. + FoldHookFn foldHookFn; + GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + HasTraitFn hasTraitFn; + ParseAssemblyFn parseAssemblyFn; + PrintAssemblyFn printAssemblyFn; + VerifyInvariantsFn verifyInvariantsFn; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -621,22 +621,6 @@ return impl->registeredOperations.count(name); } -void Dialect::addOperation(AbstractOperation opInfo) { - assert((getNamespace().empty() || opInfo.dialect.name == getNamespace()) && - "op name doesn't start with dialect namespace"); - assert(&opInfo.dialect == this && "Dialect object mismatch"); - auto &impl = context->getImpl(); - assert(impl.multiThreadedExecutionContext == 0 && - "Registering a new operation kind while in a multi-threaded execution " - "context"); - StringRef opName = opInfo.name; - if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) { - llvm::errs() << "error: operation named '" << opInfo.name - << "' is already registered.\n"; - abort(); - } -} - void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { auto &impl = context->getImpl(); assert(impl.multiThreadedExecutionContext == 0 && @@ -661,6 +645,10 @@ llvm::report_fatal_error("Dialect Attribute already registered."); } +//===----------------------------------------------------------------------===// +// AbstractAttribute +//===----------------------------------------------------------------------===// + /// Get the dialect that 
registered the attribute with the provided typeid. const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, MLIRContext *context) { @@ -672,8 +660,17 @@ return *it->second; } +//===----------------------------------------------------------------------===// +// AbstractOperation +//===----------------------------------------------------------------------===// + +ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser, + OperationState &result) const { + return parseAssemblyFn(parser, result); +} + /// Look up the specified operation in the operation set and return a pointer -/// to it if present. Otherwise, return a null pointer. +/// to it if present. Otherwise, return a null pointer. const AbstractOperation *AbstractOperation::lookup(StringRef opName, MLIRContext *context) { auto &impl = context->getImpl(); @@ -683,26 +680,45 @@ return nullptr; } +void AbstractOperation::insert( + StringRef name, Dialect &dialect, OperationProperties opProperties, + TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, + VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, + GetCanonicalizationPatternsFn getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) { + AbstractOperation opInfo(name, dialect, opProperties, typeID, parseAssembly, + printAssembly, verifyInvariants, foldHook, + getCanonicalizationPatterns, std::move(interfaceMap), + hasTrait); + + auto &impl = dialect.getContext()->getImpl(); + assert(impl.multiThreadedExecutionContext == 0 && + "Registering a new operation kind while in a multi-threaded execution " + "context"); + if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) { + llvm::errs() << "error: operation named '" << name + << "' is already registered.\n"; + abort(); + } +} + AbstractOperation::AbstractOperation( StringRef name, Dialect &dialect, OperationProperties opProperties, - TypeID typeID, - ParseResult (&parseAssembly)(OpAsmParser &parser, 
OperationState &result), - void (&printAssembly)(Operation *op, OpAsmPrinter &p), - LogicalResult (&verifyInvariants)(Operation *op), - LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands, - SmallVectorImpl<OpFoldResult> &results), - void (&getCanonicalizationPatterns)(OwningRewritePatternList &results, - MLIRContext *context), - detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID)) + TypeID typeID, ParseAssemblyFn parseAssembly, PrintAssemblyFn printAssembly, + VerifyInvariantsFn verifyInvariants, FoldHookFn foldHook, + GetCanonicalizationPatternsFn getCanonicalizationPatterns, + detail::InterfaceMap &&interfaceMap, HasTraitFn hasTrait) : name(Identifier::get(name, dialect.getContext())), dialect(dialect), - typeID(typeID), parseAssembly(parseAssembly), - printAssembly(printAssembly), verifyInvariants(verifyInvariants), - foldHook(foldHook), - getCanonicalizationPatterns(getCanonicalizationPatterns), - opProperties(opProperties), interfaceMap(std::move(interfaceMap)), - hasRawTrait(hasTrait) {} - -/// Get the dialect that registered the type with the provided typeid. + typeID(typeID), opProperties(opProperties), + interfaceMap(std::move(interfaceMap)), foldHookFn(foldHook), + getCanonicalizationPatternsFn(getCanonicalizationPatterns), + hasTraitFn(hasTrait), parseAssemblyFn(parseAssembly), + printAssemblyFn(printAssembly), verifyInvariantsFn(verifyInvariants) {} + +//===----------------------------------------------------------------------===// +// AbstractType +//===----------------------------------------------------------------------===// + const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { auto &impl = context->getImpl(); auto it = impl.registeredTypes.find(typeID); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -649,7 +649,7 @@ } // The fallback for the printer is to print in the generic assembly form. 
-void OpState::print(OpAsmPrinter &p) { p.printGenericOp(getOperation()); } +void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } /// Emit an error about fatal conditions with this operation, reporting up to /// any diagnostic handlers that may be listening.