diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2076,11 +2076,15 @@ // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in // C++. The purpose to wrap around C++ symbol string with this class is to make // interfaces specified for ops in TableGen less alien and more integrated. -class OpInterfaceTrait +class OpInterfaceTrait traits = []> : InterfaceTrait, OpTrait { // Specify the body of the verification function. `$_op` will be replaced with // the operation being verified. code verify = verifyBody; + + // Specify the list of trait verifiers that need to be run before the verifier + // of this OpInterfaceTrait. + list dependentTraits = traits; } // This class represents a single, optionally static, interface method. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -287,8 +287,12 @@ /// Helper class for implementing traits. Clients are not expected to interact /// with this directly, so its members are all protected. -template class TraitType> +template class TraitType, + typename... DependentList> class TraitBase { +public: + using dependentTupleT = std::tuple; // The traits that must be verified before this one. + protected: /// Return the ultimate Operation being worked on. Operation *getOperation() { @@ -1560,6 +1564,127 @@ static LogicalResult verifyTraits(Operation *op) { return verifyTraitsImpl(op, (TraitTupleT *)nullptr); } + +/// A trait verifier may specify a set of trait verifiers that need to +/// be run before itself. The dependent traits are listed in the trailing +/// template arguments of TraitBase. The declaration order of traits in an Op +/// definition implies the execution order of their verifiers. This class is +/// used to check for any dependency violation in the declaration order.
+/// For example, suppose we have two traits and an op defined as below, +/// +/// template +/// class TraitA : public TraitBase {}; +/// +/// // TraitB expects TraitA to be verified before itself. +/// template +/// class TraitB : public TraitBase> {}; +/// +/// class FooOp : Op {}; +/// +/// Suppose we have the verifiers for both traits. While verifying FooOp, the +/// verifier of TraitB runs first and then the verifier of TraitA. Because TraitB +/// claims that it wants TraitA to be verified first, it'll trigger a +/// static_assert on `TraitsOrderViolationFailure::value` and we can see the +/// template arguments of `TraitsOrderViolationFailure` from the compilation +/// message to find out which traits violate the order. +/// TODO: Using C++20 concepts would make the error message more concise. +template +struct VerifyTraitsOrder { +private: + /// Remove the first element of a tuple. + template + struct RemoveFront; + template + struct RemoveFront> { + using type = std::tuple; + }; + + /// Reverse the given tuple. NOTE(review): class-scope `template <>` specializations (here and in Verify) rely on CWG 727, which GCC may reject; confirm on a GCC build. + template + struct ReverseTuple; + template <> + struct ReverseTuple> { + using type = std::tuple<>; + }; + template + struct ReverseTuple> { + using type = decltype(std::tuple_cat( + std::declval>::type>(), + std::tuple())); + }; + + /// This is a helper to abbreviate a duplicated template argument. + template + struct TypeValueConditional : std::conditional { + }; + + /// This is used to highlight the arguments of the template instantiation. + template + struct TraitsOrderViolationFailure : std::false_type { + using LType = L; + /// The `RType` may be a std::tuple type, which means several traits + /// violate the declaration order with respect to `LType`.
+ using RType = R; + }; + + template + struct OrderViolationBetween : TraitsOrderViolationFailure { + static_assert( + !StaticAssert || TraitsOrderViolationFailure::value, + "Traits order verification failed, please check the template arguments " + "of TraitsOrderViolationFailure to see what are the traits " + "and check the dependency of those two traits"); + }; + + /// If `DependentTrait` is in `TraitsBefore`, it means the trait will already + /// have been verified. + template + struct HasVerified + : std::conditional_t< + llvm::is_one_of::value, + std::true_type, OrderViolationBetween> {}; + + /// `DependentTraits` is a std::tuple containing all the dependent trait types + /// of `Trait`. `TraitsBefore` are the traits declared before `Trait`. + template + struct CheckEachDependent + : TypeValueConditional< + HasVerified::type, + TraitsBefore...>, + CheckEachDependent::type, + TraitsBefore...>>::type {}; + template + struct CheckEachDependent, TraitsBefore...> + : std::true_type {}; + + template + struct VerifyImpl + : TypeValueConditional< + CheckEachDependent, + VerifyImpl>::type {}; + template + struct VerifyImpl + : std::conditional_t< + std::tuple_size::value == 0, + std::true_type, + OrderViolationBetween> {}; + +public: + template + struct Verify; + template + struct Verify> : VerifyImpl {}; + template <> + struct Verify> : std::true_type {}; + + using TraitTupleTy = std::tuple; + using ReverseTraitTupleTy = typename ReverseTuple::type; + using result = Verify; // `result::value` is false when an order violation is found. +}; } // namespace op_definition_impl //===----------------------------------------------------------------------===// @@ -1829,14 +1954,16 @@ /// This class represents the base of an operation interface. See the definition
-template +template class OpInterface : public detail::Interface, OpTrait::TraitBase> { + Op, OpTrait::TraitBase, + dependentTraits...> { public: using Base = OpInterface; - using InterfaceBase = detail::Interface, OpTrait::TraitBase>; + using InterfaceBase = + detail::Interface, + OpTrait::TraitBase, dependentTraits...>; /// Inherit the base class constructor. using InterfaceBase::InterfaceBase; diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -65,9 +65,10 @@ /// to use for the interface trait that will be attached to each /// instance of `ValueT` that implements this interface. /// -template class> class BaseTrait> +template < + typename ConcreteType, typename ValueT, typename Traits, typename BaseType, + template class, typename...> class BaseTrait, + typename... dependentTraits> // Traits to verify before this interface's trait; forwarded to BaseTrait. class Interface : public BaseType { public: using Concept = typename Traits::Concept; @@ -81,7 +82,7 @@ /// This is a special trait that registers a given interface with an object. template - struct Trait : public BaseTrait { + struct Trait : public BaseTrait { using ModelT = Model; /// Define an accessor for the ID of this interface. diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -458,6 +458,10 @@ templateParams.insert(std::begin(container), std::end(container)); } + ArrayRef getTemplateParams() const { // Returns the template parameters of this parent class. + return templateParams.getArrayRef(); + } + /// Write the parent class declaration.
void writeTo(raw_indented_ostream &os) const; @@ -679,6 +683,11 @@ raw_indented_ostream os(rawOs); writeDefTo(os); } + void writeTraitOrderVerification(raw_ostream &rawOs, + bool staticVerifyTraitOrder) const { + raw_indented_ostream os(rawOs); // Forward to the raw_indented_ostream overload. + writeTraitOrderVerification(os, staticVerifyTraitOrder); + } /// Write the declaration of this class, all declarations, and definitions of /// inline functions. @@ -686,6 +695,10 @@ /// Write the definitions of thiss class's out-of-line constructors and /// methods. void writeDefTo(raw_indented_ostream &os) const; + /// Write a variable whose type is `VerifyTraitsOrder<...>::result`, which + /// checks that each trait is declared after the traits it depends on. + void writeTraitOrderVerification(raw_indented_ostream &os, + bool staticVerifyTraitOrder) const; /// Add a declaration. The declaration is appended directly to the list of /// class declarations. diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -10,6 +10,7 @@ #define MLIR_TABLEGEN_INTERFACES_H_ #include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -94,6 +95,8 @@ // Return the verify method body if it has one. llvm::Optional getVerify() const; + llvm::ArrayRef getDependentTraits() const; // Traits that must be verified before this interface. + // Returns the Tablegen definition this interface was constructed from. const llvm::Record &getDef() const { return *def; } @@ -101,6 +104,8 @@ // The TableGen definition of this interface. const llvm::Record *def; + llvm::SmallVector dependentTraits; // Deduplicated, with OpTraitList entries flattened. + // The methods of this interface.
SmallVector methods; }; diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp --- a/mlir/lib/TableGen/Class.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Class.h" + +#include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" @@ -287,6 +289,22 @@ decl->writeDeclTo(os); } +void Class::writeTraitOrderVerification(raw_indented_ostream &os, + bool staticVerifyTraitOrder) const { + // Every parent template parameter after the first names a trait template, + // instantiated here on this class (the first is presumably the class itself). + SmallVector<std::string, 8> traits; + for (const auto &parent : parents) + for (StringRef param : parent.getTemplateParams().drop_front()) + traits.push_back((param + "<" + className + ">").str()); + auto anonymousScope = os.scope("namespace {\n", "}\n", /*indent=*/true); + os << "::mlir::op_definition_impl::VerifyTraitsOrder<" + << (staticVerifyTraitOrder ? "true" : "false"); + for (const std::string &trait : traits) + os << ", " << trait; + os << ">::result " << className << "Verifier;\n"; +} + void Class::writeDefTo(raw_indented_ostream &os) const { // Print all the definitions. for (auto &decl : declarations) diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -77,6 +77,32 @@ auto *listInit = dyn_cast(def->getValueInit("methods")); for (llvm::Init *init : listInit->getValues()) methods.emplace_back(cast(init)->getDef()); + + if (!isa(this)) + return; + + if (auto *traitList = dyn_cast( + def->getValueAsListInit("dependentTraits"))) { + std::function insert = + [&](llvm::ListInit *traitList) { + for (auto *traitInit : *traitList) { + auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef(); + if (traitDef->isSubClassOf("OpTraitList")) { + insert(traitDef->getValueAsListInit("traits")); + continue; + } + // Only a handful of dependent traits are expected, so a linear + // search is enough to skip duplicates.
+ bool isDuplicate = llvm::any_of( + dependentTraits, + [&](const Trait &trait) { return &trait.getDef() == traitDef; }); + if (!isDuplicate) + dependentTraits.push_back(Trait::create(traitInit)); + } + }; + + insert(traitList); + } } // Return the name of this interface. @@ -119,6 +145,10 @@ return value.empty() ? llvm::Optional() : value; } +llvm::ArrayRef Interface::getDependentTraits() const { + return dependentTraits; +} + //===----------------------------------------------------------------------===// // AttrInterface //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/check-op-trait-order.mlir b/mlir/test/IR/check-op-trait-order.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/check-op-trait-order.mlir @@ -0,0 +1,7 @@ +// RUN: mlir-opt %s -test-trait-order + +// The function needs a body so that the function pass actually runs; its +// asserts check that the misordered test ops are flagged as violations. +func @test() { + return +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -25,7 +25,7 @@ set(LLVM_TARGET_DEFINITIONS TestOps.td) -mlir_tablegen(TestOps.h.inc -gen-op-decls) +mlir_tablegen(TestOps.h.inc -gen-op-decls-with-runtime-trait-order-check) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test) mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test) @@ -42,6 +42,7 @@ TestDialect.cpp TestInterfaces.cpp TestPatterns.cpp + TestTraitOrder.cpp TestTraits.cpp TestTypes.cpp diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -650,6 +650,12 @@ }]; } +def TraitA : NativeOpTrait<"TraitA"> {} +// TraitB depends on TraitA to be verified first.
+def TraitB : NativeOpTrait<"TraitB"> {} + +def OpTraitDependencyOp : TEST_Op<"op-trait-dependency", [TraitB, TraitA]> {} // Deliberately lists TraitB before its dependency TraitA. +def MissingDepTraitOp : TEST_Op<"missing-dep-trait", [TraitB]> {} // Deliberately omits TraitB's dependency TraitA. //===----------------------------------------------------------------------===// // Test Locations diff --git a/mlir/test/lib/Dialect/Test/TestTraitOrder.cpp b/mlir/test/lib/Dialect/Test/TestTraitOrder.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTraitOrder.cpp @@ -0,0 +1,50 @@ +//===- TestTraitOrder.cpp - Test pass for op trait order checks -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestTraits.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace test; + +namespace { +struct TestTraitOrder : public PassWrapper { // Inspects the generated trait order verifiers of the test ops. + StringRef getArgument() const final { return "test-trait-order"; } + StringRef getDescription() const final { + return "Check the dependency of op traits"; + } + void runOnFunction() override { + // These are plain `assert`s, so they check nothing in NDEBUG builds. + assert(decltype(OpTraitDependencyOpVerifier)::value == false); + assert( + (llvm::is_one_of, + OpTrait::TraitB>::value == true)); + assert( + (llvm::is_one_of, + OpTrait::TraitB>::value == true)); + + assert(decltype(MissingDepTraitOpVerifier)::value == false); + assert( + (llvm::is_one_of, + OpTrait::TraitB>::value == true)); + assert( + (llvm::is_one_of, + OpTrait::TraitB>::value == true)); + } +}; +} // namespace + +namespace mlir { +void registerTestTraitOrderPass() { PassRegistration(); } +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestTraits.h b/mlir/test/lib/Dialect/Test/TestTraits.h --- a/mlir/test/lib/Dialect/Test/TestTraits.h +++
b/mlir/test/lib/Dialect/Test/TestTraits.h @@ -14,6 +14,7 @@ #define MLIR_TESTTRAITS_H #include "mlir/IR/Attributes.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" namespace mlir { @@ -34,6 +35,18 @@ : public AttributeTrait::TraitBase {}; } // namespace AttributeTrait + +namespace OpTrait { + +template +class TraitA : public OpTrait::TraitBase {}; // A test trait with no dependencies. + +template +class TraitB + : public OpTrait::TraitBase> {}; // A test trait that depends on TraitA. + +} // namespace OpTrait } // namespace mlir #endif // MLIR_TESTTRAITS_H diff --git a/mlir/test/mlir-tblgen/op-trait-order.td b/mlir/test/mlir-tblgen/op-trait-order.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-trait-order.td @@ -0,0 +1,23 @@ +// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s +// RUN: mlir-tblgen -gen-op-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=INTERFACE_DEPS + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +class TEST_Op traits = []> : + Op; + +def TraitA : NativeOpTrait<"TraitA"> {} +def TraitB : NativeOpTrait<"TraitB"> {} + +// INTERFACE_DEPS: struct TestOpInterfaceTrait {{.*}}, ::mlir::OpTrait::TraitB, ::mlir::OpTrait::TraitA> +def TestOpInterface : OpInterface<"TestOpInterface"> { + let verify = [{ return dummyVerify($_op); }]; + let dependentTraits = [TraitB, TraitA]; +} + +// CHECK: VerifyTraitsOrder, ::mlir::OpTrait::TraitB, TestOpInterface::Trait +def OpTraitDependencyOp : TEST_Op<"op-trait-dependency", [TraitA, TraitB, TestOpInterface]> {} diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -50,6 +50,7 @@ void registerTestSpirvEntryPointABIPass(); void registerTestSpirvGLSLCanonicalizationPass(); void registerTestSpirvModuleCombinerPass(); +void registerTestTraitOrderPass(); void registerTestTraitsPass(); void registerTosaTestQuantUtilAPIPass(); void registerVectorizerTestPass(); @@ -139,6 +140,7 @@
registerTestSpirvEntryPointABIPass(); registerTestSpirvGLSLCanonicalizationPass(); registerTestSpirvModuleCombinerPass(); + registerTestTraitOrderPass(); registerTestTraitsPass(); registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -272,7 +272,8 @@ public: static void emitDecl(const Operator &op, raw_ostream &os, - const StaticVerifierFunctionEmitter &staticVerifierEmitter); + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + bool verifyTraitOrder = false, bool staticVerifyTraitOrder = false); static void emitDef(const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter); @@ -281,7 +282,8 @@ OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter); - void emitDecl(raw_ostream &os); + void emitDecl(raw_ostream &os, bool verifyTraitOrder = false, + bool staticVerifyTraitOrder = false); void emitDef(raw_ostream &os); // Generate methods for accessing the attribute names of this operation.
@@ -593,8 +595,10 @@ } void OpEmitter::emitDecl( const Operator &op, raw_ostream &os, - const StaticVerifierFunctionEmitter &staticVerifierEmitter) { - OpEmitter(op, staticVerifierEmitter).emitDecl(os); + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + bool verifyTraitOrder, bool staticVerifyTraitOrder) { + OpEmitter(op, staticVerifierEmitter) + .emitDecl(os, verifyTraitOrder, staticVerifyTraitOrder); } void OpEmitter::emitDef( @@ -603,9 +607,12 @@ OpEmitter(op, staticVerifierEmitter).emitDef(os); } -void OpEmitter::emitDecl(raw_ostream &os) { +void OpEmitter::emitDecl(raw_ostream &os, bool verifyTraitOrder, + bool staticVerifyTraitOrder) { opClass.finalize(); opClass.writeDeclTo(os); + if (verifyTraitOrder) + opClass.writeTraitOrderVerification(os, staticVerifyTraitOrder); } void OpEmitter::emitDef(raw_ostream &os) { @@ -2688,7 +2695,7 @@ // Emits the opcode enum and op classes. static void emitOpClasses(const RecordKeeper &recordKeeper, const std::vector &defs, raw_ostream &os, - bool emitDecl) { + bool emitDecl, bool staticVerifyTraitOrder) { // First emit forward declaration for each class, this allows them to refer // to each others in traits for example. if (emitDecl) { @@ -2719,7 +2726,9 @@ os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); - OpEmitter::emitDecl(op, os, staticVerifierEmitter); + OpEmitter::emitDecl(op, os, staticVerifierEmitter, + /*verifyTraitOrder=*/true, + /*staticVerifyTraitOrder=*/staticVerifyTraitOrder); } // Emit the TypeID explicit specialization to have a single definition.
if (!op.getCppNamespace().empty()) @@ -2752,11 +2761,13 @@ [&os]() { os << ",\n"; }); } -static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { +static bool emitOpDecls(const RecordKeeper &recordKeeper, + bool staticVerifyTraitOrder, raw_ostream &os) { emitSourceFileHeader("Op Declarations", os); std::vector defs = getRequestedOpDefinitions(recordKeeper); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); + emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true, + staticVerifyTraitOrder); return false; } @@ -2766,7 +2777,9 @@ std::vector defs = getRequestedOpDefinitions(recordKeeper); emitOpList(defs, os); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); + // The trait order verification is only emitted along with the declarations. + emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false, + /*staticVerifyTraitOrder=*/false); return false; } @@ -2774,9 +2787,17 @@ static mlir::GenRegistration genOpDecls("gen-op-decls", "Generate op declarations", [](const RecordKeeper &records, raw_ostream &os) { - return emitOpDecls(records, os); + return emitOpDecls(records, /*staticVerifyTraitOrder=*/true, + os); }); +// Same as -gen-op-decls, but trait order violations do not static_assert. +static mlir::GenRegistration genOpDeclsWithRuntimeTraitOrderCheck( + "gen-op-decls-with-runtime-trait-order-check", "Generate op declarations without static_assert on trait order violations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitOpDecls(records, /*staticVerifyTraitOrder=*/false, os); + }); + static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", [](const RecordKeeper &records, raw_ostream &os) { diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -88,7 +88,7 @@ void emitModelDecl(Interface &interface); void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName); + StringRef
interfaceTraitsName, StringRef dependentTraits); void emitInterfaceDecl(Interface interface); /// The set of interface records to emit. @@ -376,12 +376,13 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName) { + StringRef interfaceTraitsName, + StringRef dependentTraits) { os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," - " detail::{1}>::Trait<{3}> {{\n", + " detail::{1}{4}>::Trait<{3}> {{\n", interfaceName, interfaceTraitsName, interfaceBaseType, - valueTemplate); + valueTemplate, dependentTraits); // Insert the default implementation for any methods. bool isOpInterface = isa(interface); @@ -445,12 +446,27 @@ os << "\n} // namespace detail\n"; + // Each dependent trait is prefixed with ", " so it can follow the traits arg. + std::string dependentTraits; + { + llvm::raw_string_ostream os(dependentTraits); + + for (auto &trait : interface.getDependentTraits()) { + if (const auto *nativeTrait = dyn_cast(&trait)) + os << ", " << nativeTrait->getFullyQualifiedTraitName(); + else if (const auto *interfaceTrait = + dyn_cast(&trait)) + os << ", " << interfaceTrait->getFullyQualifiedTraitName(); + } + } + // Emit the main interface class declaration. + // The using-declaration must name the exact base, dependent traits included. - os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" + os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}{4}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", + " using ::mlir::{3}<{1}, detail::{2}{4}>::{3};\n", interfaceName, interfaceName, interfaceTraitsName, - interfaceBaseType); + interfaceBaseType, dependentTraits); // Emit a utility wrapper trait class. os << llvm::formatv(" template \n" @@ -473,7 +489,7 @@ os << "};\n"; os << "namespace detail {\n"; - emitTraitDecl(interface, interfaceName, interfaceTraitsName); + emitTraitDecl(interface, interfaceName, interfaceTraitsName, dependentTraits); os << "}// namespace detail\n"; emitModelMethodsDef(interface);