diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2074,11 +2074,15 @@ // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in // C++. The purpose to wrap around C++ symbol string with this class is to make // interfaces specified for ops in TableGen less alien and more integrated. -class OpInterfaceTrait +class OpInterfaceTrait traits = []> : InterfaceTrait, OpTrait { // Specify the body of the verification function. `$_op` will be replaced with // the operation being verified. code verify = verifyBody; + + // Specify the list of trait verifiers that need to be run before the verifier + // of this OpInterfaceTrait. + list dependentTraits = traits; } // This class represents a single, optionally static, interface method. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -281,8 +281,12 @@ /// Helper class for implementing traits. Clients are not expected to interact /// with this directly, so its members are all protected. -template class TraitType> +template class TraitType, + typename... DependentList> class TraitBase { +public: + using dependentTupleT = std::tuple; + protected: /// Return the ultimate Operation being worked on. Operation *getOperation() { @@ -1554,6 +1558,116 @@ static LogicalResult verifyTraits(Operation *op) { return verifyTraitsImpl(op, (TraitTupleT *)nullptr); } + +/// A trait verifier may specify a set of trait verifiers that need to +/// be run before itself. The dependent traits are labeled in the trailing +/// template arguments of TraitBase. The declaration order of traits in an Op +/// definition implies the execution order of trait. This class is used to +/// verify if there's any dependency violation in the declaration order. 
+/// For example, suppose we have two traits and an op defined as below, +/// +/// template +/// class TraitA : public TraitBase {}; +/// +/// // TraitB expects TraitA to be verified before itself. +/// template +/// class TraitB : public TraitBase> {}; +/// +/// class FooOp : Op {}; +/// +/// Suppose we have the verifiers for both traits. While verifying FooOp, the +/// verifier execution order is TraitB first, then TraitA. Because TraitB +/// claims that it wants TraitA to be verified first, it'll trigger a +/// static_assert on `TraitsOrderViolationFailure::value` and we can see the +/// template arguments of `TraitsOrderViolationFailure` from the compilation +/// message to find out which ordering was violated. +/// TODO: Using C++20 concepts would make the error message more concise. +template +struct VerifyTraitsOrder { +private: + /// Remove the first element of the tuple. + template + struct RemoveFront; + template + struct RemoveFront> { + using type = std::tuple; + }; + + /// Reverse the given tuple. NOTE(review): the class-scope `template <>` specialization below may be rejected by some compilers (see CWG 727) - confirm. + template + struct ReverseTuple; + template <> + struct ReverseTuple> { + using type = std::tuple<>; + }; + template + struct ReverseTuple> { + using type = decltype(std::tuple_cat( + std::declval>::type>(), + std::tuple())); + }; + + /// Always-false helper whose template arguments surface in the diagnostic. + template + struct TraitsOrderViolationFailure { + static constexpr bool value = false; + }; + + template + struct OrderViolationBetween : std::false_type { + static_assert( + TraitsOrderViolationFailure::value, + "There is a traits order violation happens, please check the template " + "arguments of TraitsOrderViolationFailure to see what are the traits " + "and check the dependency of those two traits"); + }; + + /// If `DependentTrait` is in `TraitsBefore`, it means the trait has already + /// been verified. 
+ template + struct HasVerified + : std::conditional< + llvm::is_one_of::value, + std::true_type, + OrderViolationBetween>::type {}; + + template + struct CheckEachDependent + : std::conditional< + HasVerified, + TraitsBefore...>::value, + typename CheckEachDependent< + Trait, typename RemoveFront::type, + TraitsBefore...>::type, + std::false_type>::type {}; + template + struct CheckEachDependent, TraitsBefore...> + : std::true_type {}; + + template + struct VerifyImpl + : std::conditional< + CheckEachDependent::value, + typename VerifyImpl::type, std::false_type>::type {}; + template + struct VerifyImpl + : std::conditional::value == + 0, + std::true_type, std::false_type>::type {}; + +public: + template + struct Verify; + template + struct Verify> : VerifyImpl::type {}; + template <> + struct Verify> : std::true_type {}; + + using TraitTupleTy = std::tuple; + using ReverseTraitTupleTy = typename ReverseTuple::type; + using type = typename Verify::type; +}; } // namespace op_definition_impl //===----------------------------------------------------------------------===// @@ -1823,14 +1937,16 @@ /// This class represents the base of an operation interface. See the definition /// of `detail::Interface` for requirements on the `Traits` type. -template +template class OpInterface : public detail::Interface, OpTrait::TraitBase> { + Op, OpTrait::TraitBase, + dependentTraits...> { public: using Base = OpInterface; - using InterfaceBase = detail::Interface, OpTrait::TraitBase>; + using InterfaceBase = + detail::Interface, + OpTrait::TraitBase, dependentTraits...>; /// Inherit the base class constructor. 
using InterfaceBase::InterfaceBase; diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -65,9 +65,10 @@ /// to use for the interface trait that will be attached to each /// instance of `ValueT` that implements this interface. /// -template class> class BaseTrait> +template < + typename ConcreteType, typename ValueT, typename Traits, typename BaseType, + template class, typename...> class BaseTrait, + typename... dependentTraits> class Interface : public BaseType { public: using Concept = typename Traits::Concept; @@ -81,7 +82,7 @@ /// This is a special trait that registers a given interface with an object. template - struct Trait : public BaseTrait { + struct Trait : public BaseTrait { using ModelT = Model; /// Define an accessor for the ID of this interface. diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -10,6 +10,7 @@ #define MLIR_TABLEGEN_INTERFACES_H_ #include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -94,6 +95,8 @@ // Return the verify method body if it has one. llvm::Optional getVerify() const; + llvm::ArrayRef getDependentTraits() const; + // Returns the Tablegen definition this interface was constructed from. const llvm::Record &getDef() const { return *def; } @@ -101,6 +104,8 @@ // The TableGen definition of this interface. const llvm::Record *def; + llvm::SmallVector dependentTraits; + // The methods of this interface. 
SmallVector methods; }; diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp --- a/mlir/lib/TableGen/Class.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -8,6 +8,7 @@ #include "mlir/TableGen/Class.h" +#include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" @@ -293,4 +294,16 @@ } os << "};\n"; + + // Emit a `VerifyTraitsOrder<...>` variable over this class's traits. + // FIXME: Move to cpp file and try to only instantiate rather than have a + // variable declaration. + os << "namespace {\n"; + os << "::mlir::op_definition_impl::VerifyTraitsOrder<"; + llvm::interleave( + traits.begin(), traits.end(), + [&](StringRef str) { os << str << "<" << className << ">"; }, + [&]() { os << ", "; }); + os << "> " << className << "TraitsVerification;\n"; + os << "} // namespace\n"; } diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -77,6 +77,31 @@ auto *listInit = dyn_cast(def->getValueInit("methods")); for (llvm::Init *init : listInit->getValues()) methods.emplace_back(cast(init)->getDef()); + + if (!isa(this)) + return; + + if (auto *traitList = dyn_cast( + def->getValueAsListInit("dependentTraits"))) { + std::function insert = + [&](llvm::ListInit *traitList) { + for (auto *traitInit : *traitList) { + auto *def = cast(traitInit)->getDef(); + if (def->isSubClassOf("OpTraitList")) { + insert(def->getValueAsListInit("traits")); + continue; + } + // Only a few dependent traits are expected, so a linear search suffices. + if (llvm::none_of(dependentTraits, [&](Trait &trait) { + auto def = cast(traitInit)->getDef(); + return def == &trait.getDef(); + })) + dependentTraits.push_back(Trait::create(traitInit)); + } + }; + + insert(traitList); + } } // Return the name of this interface. @@ -119,6 +144,10 @@ return value.empty() ? 
llvm::Optional() : value; } +llvm::ArrayRef Interface::getDependentTraits() const { + return dependentTraits; +} + //===----------------------------------------------------------------------===// // AttrInterface //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -88,7 +88,7 @@ void emitModelDecl(Interface &interface); void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName); + StringRef interfaceTraitsName, StringRef dependentTraits); void emitInterfaceDecl(Interface interface); /// The set of interface records to emit. @@ -376,12 +376,13 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName) { + StringRef interfaceTraitsName, + StringRef dependentTraits) { os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," - " detail::{1}>::Trait<{3}> {{\n", + " detail::{1}{4}>::Trait<{3}> {{\n", interfaceName, interfaceTraitsName, interfaceBaseType, - valueTemplate); + valueTemplate, dependentTraits); // Insert the default implementation for any methods. bool isOpInterface = isa(interface); @@ -445,12 +446,25 @@ os << "\n} // end namespace detail\n"; + std::string dependentTraits; + { + llvm::raw_string_ostream os(dependentTraits); + + for (auto &trait : interface.getDependentTraits()) { + if (const auto *nativeTrait = dyn_cast(&trait)) + os << ", " << nativeTrait->getFullyQualifiedTraitName(); + else if (const auto *interfaceTrait = + dyn_cast(&trait)) + os << ", " << interfaceTrait->getFullyQualifiedTraitName(); + } + } + // Emit the main interface class declaration. 
- os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" + os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}{4}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", + " using ::mlir::{3}<{1}, detail::{2}{4}>::{3};\n", interfaceName, interfaceName, interfaceTraitsName, - interfaceBaseType); + interfaceBaseType, dependentTraits); // Emit a utility wrapper trait class. os << llvm::formatv(" template \n" @@ -473,7 +487,7 @@ os << "};\n"; os << "namespace detail {\n"; - emitTraitDecl(interface, interfaceName, interfaceTraitsName); + emitTraitDecl(interface, interfaceName, interfaceTraitsName, dependentTraits); os << "}// namespace detail\n"; emitModelMethodsDef(interface);