diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1540,36 +1540,24 @@ // fail to fold this trait. return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); } +template +static inline std::enable_if_t::value, + LogicalResult> +foldTrait(Operation *, ArrayRef, SmallVectorImpl &) { + return failure(); +} -/// The internal implementation of `foldTraits` below that returns the result of -/// folding a set of trait types `Ts` that implement a `foldTrait` method. +/// Given a tuple type containing a set of traits, return the result of folding +/// the given operation. template -static LogicalResult foldTraitsImpl(Operation *op, ArrayRef operands, - SmallVectorImpl &results, - std::tuple *) { +static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { bool anyFolded = false; (void)std::initializer_list{ (anyFolded |= succeeded(foldTrait(op, operands, results)), 0)...}; return success(anyFolded); } -/// Given a tuple type containing a set of traits that contain a `foldTrait` -/// method, return the result of folding the given operation. -template -static std::enable_if_t::value != 0, LogicalResult> -foldTraits(Operation *op, ArrayRef operands, - SmallVectorImpl &results) { - return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr); -} -/// A variant of the method above that is specialized when there are no traits -/// that contain a `foldTrait` method. -template -static std::enable_if_t::value == 0, LogicalResult> -foldTraits(Operation *op, ArrayRef operands, - SmallVectorImpl &results) { - return failure(); -} - //===----------------------------------------------------------------------===// // Trait Verification @@ -1587,44 +1575,51 @@ using detect_has_verify_region_trait = llvm::is_detected; -/// The internal implementation of `verifyTraits` below that returns the result -/// of verifying the current operation with all of the provided trait types -/// `Ts`. +/// Verify the given trait if it provides a verifier. +template +std::enable_if_t::value, LogicalResult> +verifyTrait(Operation *op) { + return T::verifyTrait(op); +} +template +inline std::enable_if_t::value, LogicalResult> +verifyTrait(Operation *) { + return success(); +} + +/// Given a set of traits, return the result of verifying the given operation. template -static LogicalResult verifyTraitsImpl(Operation *op, std::tuple *) { +LogicalResult verifyTraits(Operation *op) { LogicalResult result = success(); (void)std::initializer_list{ - (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...}; + (result = succeeded(result) ? verifyTrait(op) : failure(), 0)...}; return result; } -/// Given a tuple type containing a set of traits that contain a -/// `verifyTrait` method, return the result of verifying the given operation. -template -static LogicalResult verifyTraits(Operation *op) { - return verifyTraitsImpl(op, (TraitTupleT *)nullptr); +/// Verify the given trait if it provides a region verifier. +template +std::enable_if_t::value, LogicalResult> +verifyRegionTrait(Operation *op) { + return T::verifyRegionTrait(op); +} +template +inline std::enable_if_t::value, + LogicalResult> +verifyRegionTrait(Operation *) { + return success(); } -/// The internal implementation of `verifyRegionTraits` below that returns the -/// result of verifying the current operation with all of the provided trait -/// types `Ts`. +/// Given a set of traits, return the result of verifying the regions of the +/// given operation. template -static LogicalResult verifyRegionTraitsImpl(Operation *op, - std::tuple *) { +LogicalResult verifyRegionTraits(Operation *op) { (void)op; LogicalResult result = success(); (void)std::initializer_list{ - (result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(), + (result = succeeded(result) ? verifyRegionTrait(op) : failure(), 0)...}; return result; } - -/// Given a tuple type containing a set of traits that contain a -/// `verifyTrait` method, return the result of verifying the given operation. -template -static LogicalResult verifyRegionTraits(Operation *op) { - return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr); -} } // namespace op_definition_impl //===----------------------------------------------------------------------===// @@ -1733,18 +1728,6 @@ decltype(std::declval().print(std::declval())); template using detect_has_print = llvm::is_detected; - /// A tuple type containing the traits that have a `foldTrait` function. - using FoldableTraitsTupleT = typename detail::FilterTypes< - op_definition_impl::detect_has_any_fold_trait, - Traits...>::type; - /// A tuple type containing the traits that have a verify function. - using VerifiableTraitsTupleT = - typename detail::FilterTypes...>::type; - /// A tuple type containing the region traits that have a verify function. - using VerifiableRegionTraitsTupleT = typename detail::FilterTypes< - op_definition_impl::detect_has_verify_region_trait, - Traits...>::type; /// Returns an interface map containing the interfaces registered to this /// operation. @@ -1794,8 +1777,8 @@ return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { // In this case, we only need to fold the traits of the operation. - return op_definition_impl::foldTraits(op, operands, - results); + return op_definition_impl::foldTraits...>( + op, operands, results); }; } /// Return the result of folding a single result operation that defines a @@ -1809,7 +1792,7 @@ // If the fold failed or was in-place, try to fold the traits of the // operation. if (!result || result.template dyn_cast() == op->getResult(0)) { - if (succeeded(op_definition_impl::foldTraits( + if (succeeded(op_definition_impl::foldTraits...>( op, operands, results))) return success(); return success(static_cast(result)); @@ -1826,7 +1809,7 @@ // If the fold failed or was in-place, try to fold the traits of the // operation. if (failed(result) || results.empty()) { - if (succeeded(op_definition_impl::foldTraits( + if (succeeded(op_definition_impl::foldTraits...>( op, operands, results))) return success(); } @@ -1879,7 +1862,7 @@ static_assert(hasNoDataMembers(), "Op class shouldn't define new data members"); return failure( - failed(op_definition_impl::verifyTraits(op)) || + failed(op_definition_impl::verifyTraits...>(op)) || failed(cast(op).verify())); } static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { @@ -1889,9 +1872,10 @@ static LogicalResult verifyRegionInvariants(Operation *op) { static_assert(hasNoDataMembers(), "Op class shouldn't define new data members"); - return failure(failed(op_definition_impl::verifyRegionTraits< - VerifiableRegionTraitsTupleT>(op)) || - failed(cast(op).verifyRegions())); + return failure( + failed(op_definition_impl::verifyRegionTraits...>( + op)) || + failed(cast(op).verifyRegions())); } static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() { return static_cast(&verifyRegionInvariants); diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -125,23 +125,17 @@ // InterfaceMap //===----------------------------------------------------------------------===// -/// Utility to filter a given sequence of types base upon a predicate. -template -struct FilterTypeT { - template - using type = std::tuple; -}; -template <> -struct FilterTypeT { - template - using type = std::tuple<>; -}; -template