diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td @@ -20,7 +20,7 @@ //===----------------------------------------------------------------------===// class PDL_Type - : TypeDef { + : TypeDef { let mnemonic = typeMnemonic; } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -23,7 +23,7 @@ // Base class for Builtin dialect attributes. class Builtin_Attr - : AttrDef { + : AttrDef { let mnemonic = ?; } diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td --- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td +++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td @@ -17,7 +17,7 @@ // Base class for Builtin dialect location attributes. class Builtin_LocationAttr - : AttrDef { + : AttrDef { let cppClassName = name; let mnemonic = ?; } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -22,7 +22,7 @@ // Base class for Builtin dialect types. class Builtin_Type - : TypeDef { + : TypeDef { let mnemonic = ?; } @@ -65,8 +65,7 @@ //===----------------------------------------------------------------------===// // Base class for Builtin dialect float types.
-class Builtin_FloatType : TypeDef { +class Builtin_FloatType : Builtin_Type { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1734,42 +1734,58 @@ class VariadicSuccessor : Successor; + //===----------------------------------------------------------------------===// -// OpTrait definitions +// Trait definitions //===----------------------------------------------------------------------===// -// OpTrait represents a trait regarding an op. -class OpTrait; +// Trait represents a trait regarding an attribute, operation, or type. +class Trait; -// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The -// purpose to wrap around C++ symbol string with this class is to make -// traits specified for ops in TableGen less alien and more integrated. -class NativeOpTrait : OpTrait { +// NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose of +// wrapping the C++ symbol string in this class is to make traits specified +// for entities in TableGen less alien and more integrated. +class NativeTrait : Trait { string trait = name; - string cppNamespace = "::mlir::OpTrait"; + string cppNamespace = "::mlir::" # entityType # "Trait"; } -// ParamNativeOpTrait corresponds to the template-parameterized traits in the -// C++ implementation. MLIR uses nested class templates to implement such -// traits leading to constructs of the form "TraitName::Impl". Use -// the value in `prop` as the trait name and the value in `params` as -// parameters to construct the native trait class name. -class ParamNativeOpTrait - : NativeOpTrait::Impl">; +// ParamNativeTrait corresponds to the template-parameterized traits in the C++ +// implementation. MLIR uses nested class templates to implement such traits +// leading to constructs of the form "TraitName::Impl".
Use the +// value in `prop` as the trait name and the value in `params` as parameters to +// construct the native trait class name. +class ParamNativeTrait + : NativeTrait::Impl", entityType>; -// GenInternalOpTrait is an op trait that does not have direct C++ mapping but -// affects op definition generator internals, like how op builders and +// GenInternalTrait is a trait that does not have direct C++ mapping but affects +// an entity's definition generator internals, like how operation builders and // operand/attribute/result getters are generated. -class GenInternalOpTrait : OpTrait { - string trait = "::mlir::OpTrait::" # prop; +class GenInternalTrait : Trait { + string trait = "::mlir::" # entityType # "Trait::" # prop; } -// PredOpTrait is an op trait implemented by way of a predicate on the op. -class PredOpTrait : OpTrait { +// PredTrait is a trait implemented by way of a predicate on an entity. +class PredTrait : Trait { string summary = descr; Pred predicate = pred; } +//===----------------------------------------------------------------------===// +// OpTrait definitions +//===----------------------------------------------------------------------===// + +// OpTrait represents a trait regarding an operation. +// TODO: Remove this class in favor of using Trait. +class OpTrait; + +// These classes are used to define operation-specific traits. +class NativeOpTrait : NativeTrait, OpTrait; +class ParamNativeOpTrait + : ParamNativeTrait, OpTrait; +class GenInternalOpTrait : GenInternalTrait, OpTrait; +class PredOpTrait : PredTrait, OpTrait; + // Op defines an affine scope. def AffineScope : NativeOpTrait<"AffineScope">; // Op defines an automatic allocation scope. @@ -1895,23 +1911,28 @@ string defaultValue = value; } -// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in -// C++. The purpose to wrap around C++ symbol string with this class is to make +// InterfaceTrait corresponds to a specific 'Interface' class defined in C++.
+// The purpose of wrapping the C++ symbol string in this class is to make // interfaces specified for ops in TableGen less alien and more integrated. -class OpInterfaceTrait - : NativeOpTrait<""> { +class InterfaceTrait : NativeTrait<"", ""> { let trait = name # "::Trait"; let cppNamespace = ""; - // Specify the body of the verification function. `$_op` will be replaced with - // the operation being verified. - code verify = verifyBody; - // An optional code block containing extra declarations to place in the // interface trait declaration. code extraTraitClassDeclaration = ""; } +// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in +// C++. The purpose of wrapping the C++ symbol string in this class is to make +// interfaces specified for ops in TableGen less alien and more integrated. +class OpInterfaceTrait + : InterfaceTrait, OpTrait { + // Specify the body of the verification function. `$_op` will be replaced with + // the operation being verified. + code verify = verifyBody; +} + // This class represents a single, optionally static, interface method. // Note: non-static interface methods have an implicit parameter, either // $_op/$_attr/$_type corresponding to an instance of the derived value. @@ -1967,39 +1988,52 @@ } // AttrInterface represents an interface registered to an attribute. -class AttrInterface : Interface { - // An optional code block containing extra declarations to place in the - // interface trait declaration. - code extraTraitClassDeclaration = ""; -} +class AttrInterface : Interface, InterfaceTrait; // OpInterface represents an interface registered to an operation. class OpInterface : Interface, OpInterfaceTrait; // TypeInterface represents an interface registered to a type. -class TypeInterface : Interface { - // An optional code block containing extra declarations to place in the - // interface trait declaration.
- code extraTraitClassDeclaration = ""; -} +class TypeInterface : Interface, InterfaceTrait; -// Whether to declare the op interface methods in the op's header. This class -// simply wraps an OpInterface but is used to indicate that the method +// Whether to declare the interface methods in the user entity's header. This +// class simply wraps an Interface but is used to indicate that the method // declarations should be generated. This class takes an optional set of methods // that should have declarations generated even if the method has a default // implementation. +class DeclareInterfaceMethods overridenMethods = []> { + // This field contains a set of method names that should always have their + // declarations generated. This allows for generating declarations for + // methods with default implementations that need to be overridden. + list alwaysOverriddenMethods = overridenMethods; +} +class DeclareAttrInterfaceMethods overridenMethods = []> + : DeclareInterfaceMethods, + AttrInterface { + let description = interface.description; + let cppClassName = interface.cppClassName; + let cppNamespace = interface.cppNamespace; + let methods = interface.methods; +} class DeclareOpInterfaceMethods overridenMethods = []> - : OpInterface { + : DeclareInterfaceMethods, + OpInterface { + let description = interface.description; + let cppClassName = interface.cppClassName; + let cppNamespace = interface.cppNamespace; + let methods = interface.methods; +} +class DeclareTypeInterfaceMethods overridenMethods = []> + : DeclareInterfaceMethods, + TypeInterface { let description = interface.description; let cppClassName = interface.cppClassName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; - - // This field contains a set of method names that should always have their - // declarations generated. This allows for generating declarations for - // methods with default implementations that need to be overridden.
- list alwaysOverriddenMethods = overridenMethods; } //===----------------------------------------------------------------------===// @@ -2609,7 +2643,8 @@ // Define a new attribute or type, named `name`, that inherits from the given // C++ base class. -class AttrOrTypeDef { +class AttrOrTypeDef defTraits, + string baseCppClass> { // The name of the C++ base class to use for this def. string cppBaseClassName = baseCppClass; @@ -2664,6 +2699,9 @@ // Note that builders should only be provided when a def has parameters. list builders = ?; + // The list of traits attached to this def. + list traits = defTraits; + // Use the lowercased name as the keyword for parsing/printing. Specify only // if you want tblgen to generate declarations and/or definitions of // the printer/parser. @@ -2692,10 +2730,10 @@ // Define a new attribute, named `name`, belonging to `dialect` that inherits // from the given C++ base class. -class AttrDef traits = [], string baseCppClass = "::mlir::Attribute"> : DialectAttr, /*descr*/"">, - AttrOrTypeDef<"Attr", name, baseCppClass> { + AttrOrTypeDef<"Attr", name, traits, baseCppClass> { // The name of the C++ Attribute class. string cppClassName = name # "Attr"; @@ -2728,10 +2766,10 @@ // Define a new type, named `name`, belonging to `dialect` that inherits from // the given C++ base class. -class TypeDef traits = [], string baseCppClass = "::mlir::Type"> : DialectType, /*descr*/"", name # "Type">, - AttrOrTypeDef<"Type", name, baseCppClass> { + AttrOrTypeDef<"Type", name, traits, baseCppClass> { // A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters), "$_builder.getType<" # dialect.cppNamespace # diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/Trait.h" namespace llvm { class DagInit; @@ -120,6 +121,9 @@ // Returns the builders of this def. ArrayRef getBuilders() const { return builders; } + // Returns the traits of this def. + ArrayRef getTraits() const { return traits; } + // Returns whether two AttrOrTypeDefs are equal by checking the equality of // the underlying record. bool operator==(const AttrOrTypeDef &other) const; @@ -136,8 +140,11 @@ protected: const llvm::Record *def; - // The builders of this type definition. + // The builders of this definition. SmallVector builders; + + // The traits of this definition. + SmallVector traits; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -18,9 +18,9 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Builder.h" #include "mlir/TableGen/Dialect.h" -#include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Region.h" #include "mlir/TableGen/Successor.h" +#include "mlir/TableGen/Trait.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" @@ -176,9 +176,7 @@ var_decorator_range getArgDecorators(int index) const; // Returns the trait wrapper for the given MLIR C++ `trait`. - // TODO: We should add a C++ wrapper class for TableGen OpTrait instead of - // requiring the raw MLIR trait here.
- const OpTrait *getTrait(llvm::StringRef trait) const; + const Trait *getTrait(llvm::StringRef trait) const; // Regions. using const_region_iterator = const NamedRegion *; @@ -209,7 +207,7 @@ unsigned getNumVariadicSuccessors() const; // Trait. - using const_trait_iterator = const OpTrait *; + using const_trait_iterator = const Trait *; const_trait_iterator trait_begin() const; const_trait_iterator trait_end() const; llvm::iterator_range getTraits() const; @@ -325,7 +323,7 @@ SmallVector successors; // The traits of the op. - SmallVector traits; + SmallVector traits; // The regions of this op. SmallVector regions; diff --git a/mlir/include/mlir/TableGen/SideEffects.h b/mlir/include/mlir/TableGen/SideEffects.h --- a/mlir/include/mlir/TableGen/SideEffects.h +++ b/mlir/include/mlir/TableGen/SideEffects.h @@ -41,7 +41,7 @@ // This class represents an instance of a side effect interface applied to an // operation. This is a wrapper around an OpInterfaceTrait that also includes // the effects that are applied. -class SideEffectTrait : public InterfaceOpTrait { +class SideEffectTrait : public InterfaceTrait { public: // Return the effects that are attached to the side effect interface. Operator::var_decorator_range getEffects() const; @@ -49,7 +49,7 @@ // Return the name of the base C++ effect. StringRef getBaseEffectName() const; - static bool classof(const OpTrait *t); + static bool classof(const Trait *t); }; } // end namespace tblgen } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/Trait.h rename from mlir/include/mlir/TableGen/OpTrait.h rename to mlir/include/mlir/TableGen/Trait.h --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/Trait.h @@ -1,4 +1,4 @@ -//===- OpTrait.h - OpTrait wrapper class ------------------------*- C++ -*-===// +//===- Trait.h - Trait wrapper class ----------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// OpTrait wrapper to simplify using TableGen Record defining an MLIR OpTrait. +// Trait wrapper to simplify using TableGen Record defining an MLIR Trait. // //===----------------------------------------------------------------------===// -#ifndef MLIR_TABLEGEN_OPTRAIT_H_ -#define MLIR_TABLEGEN_OPTRAIT_H_ +#ifndef MLIR_TABLEGEN_TRAIT_H_ +#define MLIR_TABLEGEN_TRAIT_H_ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" @@ -25,28 +25,28 @@ namespace mlir { namespace tblgen { -struct OpInterface; +class Interface; -// Wrapper class with helper methods for accessing OpTrait constraints defined -// in TableGen. -class OpTrait { +// Wrapper class with helper methods for accessing Trait constraints defined in +// TableGen. +class Trait { public: - // Discriminator for kinds of op traits. + // Discriminator for kinds of traits. enum class Kind { - // OpTrait corresponding to C++ class. + // Trait corresponding to a C++ class. Native, - // OpTrait corresponding to predicate on operation. + // Trait corresponding to a predicate. Pred, - // OpTrait controlling op definition generator internals. + // Trait controlling definition generator internals. Internal, - // OpTrait corresponding to OpInterface. + // Trait corresponding to an Interface. Interface }; - explicit OpTrait(Kind kind, const llvm::Record *def); + explicit Trait(Kind kind, const llvm::Record *def); - // Returns an OpTrait corresponding to the init provided. - static OpTrait create(const llvm::Init *init); + // Returns a Trait corresponding to the init provided. + static Trait create(const llvm::Init *init); Kind getKind() const { return kind; } @@ -59,17 +59,17 @@ Kind kind; }; -// OpTrait corresponding to a native C++ OpTrait. -class NativeOpTrait : public OpTrait { +// Trait corresponding to a native C++ Trait.
+class NativeTrait : public Trait { public: // Returns the trait corresponding to a C++ trait class. - std::string getTrait() const; + std::string getFullyQualifiedTraitName() const; - static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; } + static bool classof(const Trait *t) { return t->getKind() == Kind::Native; } }; -// OpTrait corresponding to a predicate on the operation. -class PredOpTrait : public OpTrait { +// Trait corresponding to a predicate on an entity. +class PredTrait : public Trait { public: // Returns the template for constructing the predicate. std::string getPredTemplate() const; @@ -77,30 +77,28 @@ // Returns the description of what the predicate is verifying. StringRef getSummary() const; - static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; } + static bool classof(const Trait *t) { return t->getKind() == Kind::Pred; } }; -// OpTrait controlling op definition generator internals. -class InternalOpTrait : public OpTrait { +// Trait controlling definition generator internals. +class InternalTrait : public Trait { public: // Returns the trait controlling op definition generator internals. - StringRef getTrait() const; + StringRef getFullyQualifiedTraitName() const; - static bool classof(const OpTrait *t) { - return t->getKind() == Kind::Internal; - } + static bool classof(const Trait *t) { return t->getKind() == Kind::Internal; } }; -// OpTrait corresponding to an OpInterface on the operation. -class InterfaceOpTrait : public OpTrait { +// Trait corresponding to an Interface on an entity. +class InterfaceTrait : public Trait { public: - // Returns member function definitions corresponding to the trait, - OpInterface getOpInterface() const; + // Returns the interface corresponding to the trait. + Interface getInterface() const; // Returns the trait corresponding to a C++ trait class.
- std::string getTrait() const; + std::string getFullyQualifiedTraitName() const; - static bool classof(const OpTrait *t) { + static bool classof(const Trait *t) { return t->getKind() == Kind::Interface; } @@ -115,4 +113,4 @@ } // end namespace tblgen } // end namespace mlir -#endif // MLIR_TABLEGEN_OPTRAIT_H_ +#endif // MLIR_TABLEGEN_TRAIT_H_ diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -8,6 +8,7 @@ #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Dialect.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -46,6 +47,15 @@ builders.emplace_back(builder); } } + + // Populate the traits in declaration order, skipping duplicates. + if (auto *traitList = def->getValueAsListInit("traits")) { + SmallPtrSet traitSet; + traits.reserve(traitList->size()); + for (auto *traitInit : *traitList) + if (traitSet.insert(traitInit).second) + traits.push_back(Trait::create(traitInit)); + } } Dialect AttrOrTypeDef::getDialect() const { diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -19,13 +19,13 @@ Interfaces.cpp Operator.cpp OpClass.cpp - OpTrait.cpp Pass.cpp Pattern.cpp Predicate.cpp Region.cpp SideEffects.cpp Successor.cpp + Trait.cpp Type.cpp DISABLE_LLVM_LINK_LLVM_DYLIB diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp deleted file mode 100644 --- a/mlir/lib/TableGen/OpTrait.cpp +++ /dev/null @@ -1,75 +0,0 @@ -//===- OpTrait.cpp - OpTrait class ----------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// OpTrait wrapper to simplify using TableGen Record defining a MLIR OpTrait. -// -//===----------------------------------------------------------------------===// - -#include "mlir/TableGen/OpTrait.h" -#include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/Predicate.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/Record.h" - -using namespace mlir; -using namespace mlir::tblgen; - -OpTrait OpTrait::create(const llvm::Init *init) { - auto def = cast(init)->getDef(); - if (def->isSubClassOf("PredOpTrait")) - return OpTrait(Kind::Pred, def); - if (def->isSubClassOf("GenInternalOpTrait")) - return OpTrait(Kind::Internal, def); - if (def->isSubClassOf("OpInterfaceTrait")) - return OpTrait(Kind::Interface, def); - assert(def->isSubClassOf("NativeOpTrait")); - return OpTrait(Kind::Native, def); -} - -OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {} - -std::string NativeOpTrait::getTrait() const { - llvm::StringRef trait = def->getValueAsString("trait"); - llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); - return cppNamespace.empty() ?
trait.str() - : (cppNamespace + "::" + trait).str(); -} - -llvm::StringRef InternalOpTrait::getTrait() const { - return def->getValueAsString("trait"); -} - -std::string PredOpTrait::getPredTemplate() const { - auto pred = Pred(def->getValueInit("predicate")); - return pred.getCondition(); -} - -llvm::StringRef PredOpTrait::getSummary() const { - return def->getValueAsString("summary"); -} - -OpInterface InterfaceOpTrait::getOpInterface() const { - return OpInterface(def); -} - -std::string InterfaceOpTrait::getTrait() const { - llvm::StringRef trait = def->getValueAsString("trait"); - llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); - return cppNamespace.empty() ? trait.str() - : (cppNamespace + "::" + trait).str(); -} - -bool InterfaceOpTrait::shouldDeclareMethods() const { - return def->isSubClassOf("DeclareOpInterfaceMethods"); -} - -std::vector InterfaceOpTrait::getAlwaysDeclaredMethods() const { - return def->getValueAsListOfStrings("alwaysOverriddenMethods"); -} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -11,8 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Operator.h" -#include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Predicate.h" +#include "mlir/TableGen/Trait.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" @@ -158,17 +158,17 @@ return *arg->getValueAsListInit("decorators"); } -const OpTrait *Operator::getTrait(StringRef trait) const { +const Trait *Operator::getTrait(StringRef trait) const { for (const auto &t : traits) { - if (const auto *opTrait = dyn_cast(&t)) { - if (opTrait->getTrait() == trait) - return opTrait; - } else if (const auto *opTrait = dyn_cast(&t)) { - if (opTrait->getTrait() == trait) - return opTrait; - } else if (const auto *opTrait = dyn_cast(&t)) { - if
(opTrait->getTrait() == trait) - return opTrait; + if (const auto *traitDef = dyn_cast(&t)) { + if (traitDef->getFullyQualifiedTraitName() == trait) + return traitDef; + } else if (const auto *traitDef = dyn_cast(&t)) { + if (traitDef->getFullyQualifiedTraitName() == trait) + return traitDef; + } else if (const auto *traitDef = dyn_cast(&t)) { + if (traitDef->getFullyQualifiedTraitName() == trait) + return traitDef; } } return nullptr; @@ -314,7 +314,7 @@ return found; }; - for (const OpTrait &trait : traits) { + for (const Trait &trait : traits) { const llvm::Record &def = trait.getDef(); // If the infer type op interface was manually added, then treat it as // intention that the op needs special handling. @@ -323,8 +323,8 @@ if (def.isSubClassOf( llvm::formatv("{0}::Trait", inferTypeOpInterface).str())) return; - if (const auto *opTrait = dyn_cast(&trait)) - if (&opTrait->getDef() == inferTrait) + if (const auto *traitDef = dyn_cast(&trait)) + if (&traitDef->getDef() == inferTrait) return; if (!def.isSubClassOf("AllTypesMatch")) @@ -344,7 +344,7 @@ // If the types could be computed, then add type inference trait. if (allResultsHaveKnownTypes) - traits.push_back(OpTrait::create(inferTrait->getDefInit())); + traits.push_back(Trait::create(inferTrait->getDefInit())); } void Operator::populateOpStructure() { @@ -489,7 +489,7 @@ for (auto *traitInit : *traitList) { // Keep traits in the same order while skipping over duplicates.
if (traitSet.insert(traitInit).second) - traits.push_back(OpTrait::create(traitInit)); + traits.push_back(Trait::create(traitInit)); } } diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp --- a/mlir/lib/TableGen/SideEffects.cpp +++ b/mlir/lib/TableGen/SideEffects.cpp @@ -53,6 +53,6 @@ return def->getValueAsString("baseEffectName"); } -bool SideEffectTrait::classof(const OpTrait *t) { +bool SideEffectTrait::classof(const Trait *t) { return t->getDef().isSubClassOf("SideEffectsTraitBase"); } diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/Trait.cpp @@ -0,0 +1,93 @@ +//===- Trait.cpp ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Trait wrapper to simplify using TableGen Record defining an MLIR Trait.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Trait.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/Predicate.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Trait +//===----------------------------------------------------------------------===// + +Trait Trait::create(const llvm::Init *init) { + auto def = cast(init)->getDef(); + if (def->isSubClassOf("PredTrait")) + return Trait(Kind::Pred, def); + if (def->isSubClassOf("GenInternalTrait")) + return Trait(Kind::Internal, def); + if (def->isSubClassOf("InterfaceTrait")) + return Trait(Kind::Interface, def); + assert(def->isSubClassOf("NativeTrait")); + return Trait(Kind::Native, def); +} + +Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {} + +//===----------------------------------------------------------------------===// +// NativeTrait +//===----------------------------------------------------------------------===// + +std::string NativeTrait::getFullyQualifiedTraitName() const { + llvm::StringRef trait = def->getValueAsString("trait"); + llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); + return cppNamespace.empty() ?
trait.str() + : (cppNamespace + "::" + trait).str(); +} + +//===----------------------------------------------------------------------===// +// InternalTrait +//===----------------------------------------------------------------------===// + +llvm::StringRef InternalTrait::getFullyQualifiedTraitName() const { + return def->getValueAsString("trait"); +} + +//===----------------------------------------------------------------------===// +// PredTrait +//===----------------------------------------------------------------------===// + +std::string PredTrait::getPredTemplate() const { + auto pred = Pred(def->getValueInit("predicate")); + return pred.getCondition(); +} + +llvm::StringRef PredTrait::getSummary() const { + return def->getValueAsString("summary"); +} + +//===----------------------------------------------------------------------===// +// InterfaceTrait +//===----------------------------------------------------------------------===// + +Interface InterfaceTrait::getInterface() const { return Interface(def); } + +std::string InterfaceTrait::getFullyQualifiedTraitName() const { + llvm::StringRef trait = def->getValueAsString("trait"); + llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); + return cppNamespace.empty() ? trait.str() + : (cppNamespace + "::" + trait).str(); +} + +bool InterfaceTrait::shouldDeclareMethods() const { + return def->isSubClassOf("DeclareInterfaceMethods"); +} + +std::vector InterfaceTrait::getAlwaysDeclaredMethods() const { + return def->getValueAsListOfStrings("alwaysOverriddenMethods"); +} diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -14,6 +14,7 @@ // A type interface used to test the ODS generation of type interfaces.
def TestTypeInterface : TypeInterface<"TestTypeInterface"> { + let cppNamespace = "::mlir::test"; let methods = [ InterfaceMethod<"Prints the type name.", "void", "printTypeA", (ins "Location":$loc), [{ diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -15,9 +15,11 @@ // To get the test dialect def. include "TestOps.td" +include "mlir/Interfaces/DataLayoutInterfaces.td" // All of the types will extend this class. -class Test_Type : TypeDef { } +class Test_Type traits = []> + : TypeDef; def SimpleTypeA : Test_Type<"SimpleA"> { let mnemonic = "smpla"; @@ -151,4 +153,27 @@ let mnemonic = "struct"; } +def TestType : Test_Type<"Test", [ + DeclareTypeInterfaceMethods +]> { + let mnemonic = "test_type"; +} + +def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [ + DeclareTypeInterfaceMethods +]> { + let mnemonic = "test_type_with_layout"; + let parameters = (ins "unsigned":$key); + let extraClassDeclaration = [{ + LogicalResult verifyEntries(DataLayoutEntryListRef params, + Location loc) const; + + private: + unsigned extractKind(DataLayoutEntryListRef params, + StringRef expectedKind) const; + + public: + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -41,25 +41,14 @@ } // namespace test } // namespace mlir +#include "TestTypeInterfaces.h.inc" + #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.h.inc" namespace mlir { namespace test { -#include "TestTypeInterfaces.h.inc" - -/// This class is a simple test type that uses a generated interface. -struct TestType : public Type::TypeBase { - using Base::Base; - - /// Provide a definition for the necessary interface methods.
- void printTypeC(Location loc) const { - emitRemark(loc) << *this << " - TestC"; - } -}; - /// Storage for simple named recursive types, where the type is identified by /// its name and can "contain" another type, including itself. struct TestRecursiveTypeStorage : public TypeStorage { @@ -108,62 +97,6 @@ StringRef getName() { return getImpl()->name; } }; -struct TestTypeWithLayoutStorage : public TypeStorage { - using KeyTy = unsigned; - - explicit TestTypeWithLayoutStorage(unsigned key) : key(key) {} - bool operator==(const KeyTy &other) const { return other == key; } - - static TestTypeWithLayoutStorage *construct(TypeStorageAllocator &allocator, - const KeyTy &key) { - return new (allocator.allocate()) - TestTypeWithLayoutStorage(key); - } - - unsigned key; -}; - -class TestTypeWithLayout - : public Type::TypeBase { -public: - using Base::Base; - - static TestTypeWithLayout get(MLIRContext *ctx, unsigned key) { - return Base::get(ctx, key); - } - - unsigned getKey() { return getImpl()->key; } - - unsigned getTypeSizeInBits(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { - return extractKind(params, "size"); - } - - unsigned getABIAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { - return extractKind(params, "alignment"); - } - - unsigned getPreferredAlignment(const DataLayout &dataLayout, - DataLayoutEntryListRef params) const { - return extractKind(params, "preferred"); - } - - bool areCompatible(DataLayoutEntryListRef oldLayout, - DataLayoutEntryListRef newLayout) const { - unsigned old = extractKind(oldLayout, "alignment"); - return old == 1 || extractKind(newLayout, "alignment") <= old; - } - - LogicalResult verifyEntries(DataLayoutEntryListRef params, - Location loc) const; - -private: - unsigned extractKind(DataLayoutEntryListRef params, - StringRef expectedKind) const; -}; - } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp
b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -58,6 +58,34 @@ } } +// The functions don't need to be in the header file, but need to be in the mlir +// namespace. Declare them here, then define them immediately below. Separating +// the declaration and definition adheres to the LLVM coding standards. +namespace mlir { +namespace test { +// FieldInfo is used as part of a parameter, so equality comparison is +// compulsory. +static bool operator==(const FieldInfo &a, const FieldInfo &b); +// FieldInfo is used as part of a parameter, so a hash will be computed. +static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT +} // namespace test +} // namespace mlir + +// FieldInfo is used as part of a parameter, so equality comparison is +// compulsory. +static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) { + return a.name == b.name && a.type == b.type; +} + +// FieldInfo is used as part of a parameter, so a hash will be computed. +static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT + return llvm::hash_combine(fi.name, fi.type); +} + +//===----------------------------------------------------------------------===// +// CompoundAType +//===----------------------------------------------------------------------===// + Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) { int widthOfSomething; Type oneType; @@ -87,29 +115,9 @@ printer << "]>"; } -// The functions don't need to be in the header file, but need to be in the mlir -// namespace. Declare them here, then define them immediately below. Separating -// the declaration and definition adheres to the LLVM coding standards. -namespace mlir { -namespace test { -// FieldInfo is used as part of a parameter, so equality comparison is -// compulsory. 
-static bool operator==(const FieldInfo &a, const FieldInfo &b); -// FieldInfo is used as part of a parameter, so a hash will be computed. -static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT -} // namespace test -} // namespace mlir - -// FieldInfo is used as part of a parameter, so equality comparison is -// compulsory. -static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) { - return a.name == b.name && a.type == b.type; -} - -// FieldInfo is used as part of a parameter, so a hash will be computed. -static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT - return llvm::hash_combine(fi.name, fi.type); -} +//===----------------------------------------------------------------------===// +// TestIntegerType +//===----------------------------------------------------------------------===// // Example type validity checker. LogicalResult @@ -122,18 +130,58 @@ } //===----------------------------------------------------------------------===// -// Tablegen Generated Definitions +// TestType //===----------------------------------------------------------------------===// -#define GET_TYPEDEF_CLASSES -#include "TestTypeDefs.cpp.inc" +void TestType::printTypeC(Location loc) const { + emitRemark(loc) << *this << " - TestC"; +} + +//===----------------------------------------------------------------------===// +// TestTypeWithLayout +//===----------------------------------------------------------------------===// + +Type TestTypeWithLayoutType::parse(MLIRContext *ctx, DialectAsmParser &parser) { + unsigned val; + if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater()) + return Type(); + return TestTypeWithLayoutType::get(ctx, val); +} + +void TestTypeWithLayoutType::print(DialectAsmPrinter &printer) const { + printer << "test_type_with_layout<" << getKey() << ">"; +} + +unsigned +TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return 
extractKind(params, "size"); +} -LogicalResult TestTypeWithLayout::verifyEntries(DataLayoutEntryListRef params, - Location loc) const { +unsigned +TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + return extractKind(params, "alignment"); +} + +unsigned TestTypeWithLayoutType::getPreferredAlignment( + const DataLayout &dataLayout, DataLayoutEntryListRef params) const { + return extractKind(params, "preferred"); +} + +bool TestTypeWithLayoutType::areCompatible( + DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { + unsigned old = extractKind(oldLayout, "alignment"); + return old == 1 || extractKind(newLayout, "alignment") <= old; +} + +LogicalResult +TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, + Location loc) const { for (DataLayoutEntryInterface entry : params) { // This is for testing purposes only, so assert well-formedness. assert(entry.isTypeEntry() && "unexpected identifier entry"); - assert(entry.getKey().get().isa() && + assert(entry.getKey().get().isa() && "wrong type passed in"); auto array = entry.getValue().dyn_cast(); assert(array && array.getValue().size() == 2 && @@ -149,8 +197,8 @@ return success(); } -unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params, - StringRef expectedKind) const { +unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, + StringRef expectedKind) const { for (DataLayoutEntryInterface entry : params) { ArrayRef pair = entry.getValue().cast().getValue(); StringRef kind = pair.front().cast().getValue(); @@ -160,12 +208,19 @@ return 1; } +//===----------------------------------------------------------------------===// +// Tablegen Generated Definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.cpp.inc" + //===----------------------------------------------------------------------===// // 
TestDialect //===----------------------------------------------------------------------===// void TestDialect::registerTypes() { - addTypes(); @@ -183,17 +238,6 @@ if (parseResult.hasValue()) return genType; } - if (typeTag == "test_type") - return TestType::get(parser.getBuilder().getContext()); - - if (typeTag == "test_type_with_layout") { - unsigned val; - if (parser.parseLess() || parser.parseInteger(val) || - parser.parseGreater()) { - return Type(); - } - return TestTypeWithLayout::get(parser.getBuilder().getContext(), val); - } if (typeTag != "test_rec") { parser.emitError(parser.getNameLoc()) << "unknown type!"; @@ -234,15 +278,6 @@ llvm::SetVector &stack) { if (succeeded(generatedTypePrinter(type, printer))) return; - if (type.isa()) { - printer << "test_type"; - return; - } - - if (auto t = type.dyn_cast()) { - printer << "test_type_with_layout<" << t.getKey() << ">"; - return; - } auto rec = type.cast(); printer << "test_rec<" << rec.getName(); diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -78,16 +78,16 @@ // DEF: CompoundAAttrStorage ( // DEF-NEXT: : ::mlir::AttributeStorage(inner), -// DEF: bool operator==(const KeyTy &key) const { -// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key))) +// DEF: bool operator==(const KeyTy &tblgenKey) const { +// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey))) // DEF-NEXT: return false; -// DEF-NEXT: if (!(exampleTdType == std::get<1>(key))) +// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey))) // DEF-NEXT: return false; -// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key)))) +// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey)))) // DEF-NEXT: return false; -// DEF-NEXT: if (!(dims == std::get<3>(key))) +// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey))) // DEF-NEXT: return false; -// DEF-NEXT: if (!(getType() == std::get<4>(key))) +// DEF-NEXT: if 
(!(getType() == std::get<4>(tblgenKey))) // DEF-NEXT: return false; // DEF-NEXT: return true; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -11,8 +11,10 @@ #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -208,28 +210,29 @@ /// {1}: The name of the type base class. /// {2}: The name of the base value type, e.g. Attribute or Type. /// {3}: The tablegen record type prefix, e.g. Attr or Type. +/// {4}: The traits of the def class. static const char *const defDeclSingletonBeginStr = R"( - class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{ + class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{ public: /// Inherit some necessary constructors from '{3}Base'. using Base::Base; )"; -/// The code block for the start of a typeDef class declaration -- parametric -/// case. +/// The code block for the start of a class declaration -- parametric case. /// -/// {0}: The name of the typeDef class. -/// {1}: The name of the type base class. -/// {2}: The typeDef storage class namespace. +/// {0}: The name of the def class. +/// {1}: The name of the base class. +/// {2}: The def storage class namespace. /// {3}: The storage class name. /// {4}: The name of the base value type, e.g. Attribute or Type. /// {5}: The tablegen record type prefix, e.g. Attr or Type. +/// {6}: The traits of the def class. 
static const char *const defDeclParametricBeginStr = R"( namespace {2} { struct {3}; } // end namespace {2} class {0} : public ::mlir::{4}::{5}Base<{0}, {1}, - {2}::{3}> {{ + {2}::{3}{6}> {{ public: /// Inherit some necessary constructors from '{5}Base'. using Base::Base; @@ -309,19 +312,71 @@ } } +static void emitInterfaceMethodDecls(const InterfaceTrait *trait, + raw_ostream &os) { + Interface interface = trait->getInterface(); + + // Get the set of methods that should always be declared. + auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods(); + llvm::StringSet<> alwaysDeclaredMethods; + alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), + alwaysDeclaredMethodsVec.end()); + + for (const InterfaceMethod &method : interface.getMethods()) { + // Don't declare if the method has a body. + if (method.getBody()) + continue; + // Don't declare if the method has a default implementation and the def + // didn't request that it always be declared. + if (method.getDefaultImplementation() && + !alwaysDeclaredMethods.count(method.getName())) + continue; + + // Emit the method declaration. + os << " " << (method.isStatic() ? "static " : "") + << method.getReturnType() << " " << method.getName() << "("; + llvm::interleaveComma(method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { + os << arg.type << " " << arg.name; + }); + os << ")" << (method.isStatic() ? "" : " const") << ";\n"; + } +} + void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) { SmallVector params; def.getParameters(params); + // Build the trait list for this def. 
+ std::vector traitList; + StringSet<> traitSet; + for (const Trait &baseTrait : def.getTraits()) { + std::string traitStr; + if (const auto *trait = dyn_cast(&baseTrait)) + traitStr = trait->getFullyQualifiedTraitName(); + else if (const auto *trait = dyn_cast(&baseTrait)) + traitStr = trait->getFullyQualifiedTraitName(); + else + llvm_unreachable("unexpected Attribute/Type trait type"); + + if (traitSet.insert(traitStr).second) + traitList.emplace_back(std::move(traitStr)); + } + std::string traitStr; + if (!traitList.empty()) + traitStr = ", " + llvm::join(traitList, ", "); + // Emit the beginning string template: either the singleton or parametric // template. if (def.getNumParameters() == 0) { os << formatv(defDeclSingletonBeginStr, def.getCppClassName(), - def.getCppBaseClassName(), valueType, defTypePrefix); + def.getCppBaseClassName(), valueType, defTypePrefix, + traitStr); } else { os << formatv(defDeclParametricBeginStr, def.getCppClassName(), def.getCppBaseClassName(), def.getStorageNamespace(), - def.getStorageClassName(), valueType, defTypePrefix); + def.getStorageClassName(), valueType, defTypePrefix, + traitStr); } // Emit the extra declarations first in case there's a definition in there. @@ -362,6 +417,14 @@ } } + // Emit any interface method declarations. + for (const Trait &trait : def.getTraits()) { + if (const auto *traitDef = dyn_cast(&trait)) { + if (traitDef->shouldDeclareMethods()) + emitInterfaceMethodDecls(traitDef, os); + } + } + // End the decl. os << " };\n"; } @@ -452,7 +515,7 @@ /// Define a construction method for creating a new instance of this /// storage. static {0} *construct(::mlir::{1}StorageAllocator &allocator, - const KeyTy &key) {{ + const KeyTy &tblgenKey) {{ )"; /// The storage class' constructor return template. @@ -558,7 +621,7 @@ paramInitializer, parameterTypeList, valueType); // * Emit the comparison method. 
- os << " bool operator==(const KeyTy &key) const {\n"; + os << " bool operator==(const KeyTy &tblgenKey) const {\n"; for (auto it : llvm::enumerate(params)) { os << " if (!("; @@ -566,7 +629,7 @@ bool isSelfType = isa(it.value()); FmtContext context; context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName()) - .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)"); + .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)"); // Use the parameter specified comparator if possible, otherwise default to // operator==. @@ -577,13 +640,13 @@ os << " return true;\n }\n"; // * Emit the haskKey method. - os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; + os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n"; // Extract each parameter from the key. os << " return ::llvm::hash_combine("; llvm::interleaveComma( llvm::seq(0, params.size()), os, - [&](unsigned it) { os << "std::get<" << it << ">(key)"; }); + [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; }); os << ");\n }\n"; // * Emit the construct method. @@ -592,7 +655,7 @@ // here and then they can write the definition elsewhere. if (def.hasStorageCustomConstructor()) { os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator " - "&allocator, const KeyTy &key);\n", + "&allocator, const KeyTy &tblgenKey);\n", def.getStorageClassName(), valueType); // Otherwise, generate one. 
@@ -601,7 +664,7 @@ os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(), valueType); for (unsigned i = 0, e = params.size(); i < e; ++i) { - os << formatv(" auto {0} = std::get<{1}>(key);\n", + os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n", params[i].getName(), i); } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -15,8 +15,8 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -18,9 +18,9 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/SideEffects.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Path.h" @@ -430,7 +430,7 @@ void genOpInterfaceMethods(); // Generate op interface methods for the given interface. - void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait); + void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait); // Generate op interface method for the given interface method. If // 'declaration' is true, generates a declaration, else a definition. 
@@ -1719,8 +1719,8 @@ } } -void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) { - auto interface = opTrait->getOpInterface(); +void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { + Interface interface = opTrait->getInterface(); // Get the set of methods that should always be declared. auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); @@ -1757,7 +1757,7 @@ void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { - if (const auto *opTrait = dyn_cast(&trait)) + if (const auto *opTrait = dyn_cast(&trait)) if (opTrait->shouldDeclareMethods()) genOpInterfaceMethods(opTrait); } @@ -1866,9 +1866,9 @@ return; // Generate 'inferReturnTypes' method declaration using the interface method // declared in 'InferTypeOpInterface' op interface. - const auto *trait = dyn_cast( + const auto *trait = dyn_cast( op.getTrait("::mlir::InferTypeOpInterface::Trait")); - auto interface = trait->getOpInterface(); + Interface interface = trait->getInterface(); OpMethod *method = [&]() -> OpMethod * { for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { if (interfaceMethod.getName() == "inferReturnTypes") { @@ -1966,7 +1966,7 @@ genOperandResultVerifier(body, op.getResults(), "result"); for (auto &trait : op.getTraits()) { - if (auto *t = dyn_cast(&trait)) { + if (auto *t = dyn_cast(&trait)) { body << tgfmt(" if (!($0))\n " "return emitOpError(\"failed to verify that $1\");\n", &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), @@ -2187,10 +2187,10 @@ // Add the native and interface traits. 
for (const auto &trait : op.getTraits()) { - if (auto opTrait = dyn_cast(&trait)) - opClass.addTrait(opTrait->getTrait()); - else if (auto opTrait = dyn_cast(&trait)) - opClass.addTrait(opTrait->getTrait()); + if (auto opTrait = dyn_cast(&trait)) + opClass.addTrait(opTrait->getFullyQualifiedTraitName()); + else if (auto opTrait = dyn_cast(&trait)) + opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } } @@ -2379,12 +2379,14 @@ // Verify a few traits first so that we can use // getODSOperands()/getODSResults() in the rest of the verifier. for (auto &trait : op.getTraits()) { - if (auto *t = dyn_cast(&trait)) { - if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") { + if (auto *t = dyn_cast(&trait)) { + if (t->getFullyQualifiedTraitName() == + "::mlir::OpTrait::AttrSizedOperandSegments") { body << formatv(checkAttrSizedValueSegmentsCode, "operand_segment_sizes", op.getNumOperands(), "operand"); - } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") { + } else if (t->getFullyQualifiedTraitName() == + "::mlir::OpTrait::AttrSizedResultSegments") { body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes", op.getNumResults(), "result"); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -12,8 +12,8 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Trait.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SetVector.h" @@ -445,18 +445,12 @@ operandTypes.resize(op.getNumOperands(), TypeResolution()); resultTypes.resize(op.getNumResults(), TypeResolution()); - hasImplicitTermTrait = - llvm::any_of(op.getTraits(), [](const OpTrait &trait) { - return 
trait.getDef().isSubClassOf("SingleBlockImplicitTerminator"); - }); + hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) { + return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator"); + }); hasSingleBlockTrait = - hasImplicitTermTrait || - llvm::any_of(op.getTraits(), [](const OpTrait &trait) { - if (auto *native = dyn_cast(&trait)) - return native->getTrait() == "::mlir::OpTrait::SingleBlock"; - return false; - }); + hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock"); } /// Generate the operation parser from this format. @@ -2416,7 +2410,7 @@ // Check for any type traits that we can use for inferring types. llvm::StringMap variableTyResolver; - for (const OpTrait &trait : op.getTraits()) { + for (const Trait &trait : op.getTraits()) { const llvm::Record &def = trait.getDef(); if (def.isSubClassOf("AllTypesMatch")) { handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),