diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -21,16 +21,16 @@ ### Dialect Interfaces -Dialect interfaces are generally useful for transformation passes or -analyses that want to operate generically on a set of operations, -which might even be defined in different dialects. These -interfaces generally involve wide coverage over the entire dialect and are only -used for a handful of transformations/analyses. In these cases, registering the -interface directly on each operation is overly complex and cumbersome. The -interface is not core to the operation, just to the specific transformation. An -example of where this type of interface would be used is inlining. Inlining -generally queries high-level information about the operations within a dialect, -like legality and cost modeling, that often is not specific to one operation. +Dialect interfaces are generally useful for transformation passes or analyses +that want to operate generically on a set of attributes/operations/types, which +might even be defined in different dialects. These interfaces generally involve +wide coverage over the entire dialect and are only used for a handful of +transformations/analyses. In these cases, registering the interface directly on +each operation is overly complex and cumbersome. The interface is not core to +the operation, just to the specific transformation. An example of where this +type of interface would be used is inlining. Inlining generally queries +high-level information about the operations within a dialect, like legality and +cost modeling, that often is not specific to one operation. A dialect interface can be defined by inheriting from the CRTP base class `DialectInterfaceBase::Base`. This class provides the necessary utilities for @@ -106,24 +106,25 @@ ... ``` -### Operation Interfaces - -Operation interfaces, as the name suggests, are those registered at the -Operation level. 
These interfaces provide access to derived operations -by providing a virtual interface that must be implemented. As an example, the -`Linalg` dialect may implement an interface that provides general queries about -some of the dialects library operations. These queries may provide things like: -the number of parallel loops; the number of inputs and outputs; etc. - -Operation interfaces are defined by overriding the CRTP base class -`OpInterface`. This class takes, as a template parameter, a `Traits` class that -defines a `Concept` and a `Model` class. These classes provide an implementation -of concept-based polymorphism, where the Concept defines a set of virtual -methods that are overridden by the Model that is templated on the concrete -operation type. It is important to note that these classes should be pure in -that they contain no non-static data members. Operations that wish to override -this interface should add the provided trait `OpInterface<..>::Trait` upon -registration. +### Attribute/Operation/Type Interfaces + +Attribute/Operation/Type interfaces, as the names suggest, are those registered +at the level of a specific attribute/operation/type. These interfaces provide +access to derived objects by providing a virtual interface that must be +implemented. As an example, the `Linalg` dialect may implement an interface that +provides general queries about some of the dialect's library operations. These +queries may provide things like: the number of parallel loops; the number of +inputs and outputs; etc. + +These interfaces are defined by inheriting from the CRTP base class `AttrInterface`, +`OpInterface`, or `TypeInterface`, respectively. These classes take, as a +template parameter, a `Traits` class that defines a `Concept` and a `Model` +class. These classes provide an implementation of concept-based polymorphism, +where the Concept defines a set of virtual methods that are overridden by the +Model that is templated on the concrete object type.
It is important to note +that these classes should be pure in that they contain no non-static data +members. Objects that wish to implement this interface should add the provided +trait `*Interface<..>::Trait` to the trait list upon registration. ```c++ struct ExampleOpInterfaceTraits { @@ -182,8 +183,7 @@ Operation interfaces require a bit of boiler plate to connect all of the pieces together. The ODS(Operation Definition Specification) framework provides -simplified mechanisms for -[defining interfaces](OpDefinitions.md#operation-interfaces). +simplified mechanisms for [defining interfaces](OpDefinitions.md#interfaces). As an example, using the ODS framework would allow for defining the example interface above as: diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -346,20 +346,20 @@ template parameter to the `Op` class. They should be deriving from the `OpTrait` class. See [Constraints](#constraints) for more information. -### Operation interfaces +### Interfaces -[Operation interfaces](Interfaces.md#operation-interfaces) allow -operations to expose method calls without the -caller needing to know the exact operation type. Operation interfaces -defined in C++ can be accessed in the ODS framework via the -`OpInterfaceTrait` class. Aside from using pre-existing interfaces in -the C++ API, the ODS framework also provides a simplified mechanism -for defining such interfaces which removes much of the boilerplate -necessary. +[Interfaces](Interfaces.md#attribute-operation-type-interfaces) allow +attributes, operations, and types to expose method calls without the caller +needing to know the derived type. Operation interfaces defined in C++ can be +accessed in the ODS framework via the `OpInterfaceTrait` class.
Aside from using +pre-existing interfaces in the C++ API, the ODS framework also provides a +simplified mechanism for defining such interfaces that removes much of the +necessary boilerplate. -Providing a definition of the `OpInterface` class will auto-generate the C++ -classes for the interface. An `OpInterface` includes a name, for the C++ class, -a description, and a list of interface methods. +Providing a definition of the `AttrInterface`, `OpInterface`, or `TypeInterface` +class will auto-generate the C++ classes for the interface. An interface +includes a name for the C++ class, a description, and a list of interface +methods. ```tablegen def MyInterface : OpInterface<"MyInterface"> { @@ -450,10 +450,11 @@ ]; } -// Interfaces can optionally be wrapped inside DeclareOpInterfaceMethods. This -// would result in autogenerating declarations for members `foo`, `bar` and -// `fooStatic`. Methods with bodies are not declared inside the op -// declaration but instead handled by the op interface trait directly. +// Operation interfaces can optionally be wrapped inside +// DeclareOpInterfaceMethods. This would result in autogenerating declarations +// for members `foo`, `bar` and `fooStatic`. Methods with bodies are not +// declared inside the op declaration but instead handled by the op interface +// trait directly. def OpWithInferTypeInterfaceOp : Op<... [DeclareOpInterfaceMethods]> { ... } @@ -465,9 +466,9 @@ [DeclareOpInterfaceMethods]> { ... } ``` -A verification method can also be specified on the `OpInterface` by setting -`verify`. Setting `verify` results in the generated trait having a `verifyTrait` -method that is applied to all operations implementing the trait. +Operation interfaces may also provide a verification method on `OpInterface` by +setting `verify`. Setting `verify` results in the generated trait having a +`verifyTrait` method that is applied to all operations implementing the trait.
### Builder methods diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -1,24 +1,28 @@ -# Operation Traits +# Traits [TOC] -MLIR allows for a truly open operation ecosystem, as any dialect may define -operations that suit a specific level of abstraction. `Traits` are a mechanism -which abstracts implementation details and properties that are common -across many different operations. `Traits` may be used to specify special -properties and constraints of the operation, including whether the operation has -side effects or whether its output has the same type as the input. Some examples -of traits are `Commutative`, `SingleResult`, `Terminator`, etc. See the more -[comprehensive list](#trait-list) below for more examples of what is possible. +MLIR allows for a truly open ecosystem, as any dialect may define attributes, +operations, and types that suit a specific level of abstraction. `Traits` are a +mechanism which abstracts implementation details and properties that are common +across many different attributes/operations/types/etc. `Traits` may be used to +specify special properties and constraints of the object, including whether an +operation has side effects or whether its output has the same type as the input. +Some examples of operation traits are `Commutative`, `SingleResult`, +`Terminator`, etc. See the more comprehensive list of +[operation traits](#operation-traits-list) below for more examples of what is +possible. ## Defining a Trait -Traits may be defined in C++ by inheriting from the -`OpTrait::TraitBase` class. This base class takes as -template parameters: +Traits may be defined in C++ by inheriting from the `TraitBase` class for the specific IR construct. For attributes, this is +`AttributeTrait::TraitBase`. For operations, this is `OpTrait::TraitBase`. For +types, this is `TypeTrait::TraitBase`.
This base class takes as template +parameters: * ConcreteType - - The concrete operation type that this trait was attached to. + - The concrete class type that this trait was attached to. * TraitType - The type of the trait class that is being defined, for use with the [`Curiously Recurring Template Pattern`](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern). @@ -28,11 +32,11 @@ ```c++ template -class MyTrait : public OpTrait::TraitBase { +class MyTrait : public TraitBase { }; ``` -Derived traits may also provide a `verifyTrait` hook, that is called when +Operation traits may also provide a `verifyTrait` hook that is called when verifying the concrete operation. The trait verifiers will currently always be invoked before the main `Op::verify`. @@ -57,15 +61,15 @@ The above demonstrates the definition of a simple self-contained trait. It is also often useful to provide some static parameters to the trait to control its behavior. Given that the definition of the trait class is rigid, i.e. we must -have a single template argument for the concrete operation, the templates for -the parameters will need to be split out. An example is shown below: +have a single template argument for the concrete object, the templates for the +parameters will need to be split out. An example is shown below: ```c++ template class MyParametricTrait { public: template - class Impl : public OpTrait::TraitBase { + class Impl : public TraitBase { // Inside of 'Impl' we have full access to the template parameters // specified above.
}; @@ -74,19 +78,28 @@ ## Attaching a Trait -Traits may be used when defining a derived operation type, by simply adding the -name of the trait class to the `Op` class after the concrete operation type: +Traits may be used when defining a derived object type by simply appending the +name of the trait class to the template parameter list of the base object class: ```c++ +/// Here we define 'MyAttr' along with the 'MyTrait' and `MyParametric trait +/// classes we defined previously. +class MyAttr : public Attribute::AttrBase::Impl> {}; /// Here we define 'MyOp' along with the 'MyTrait' and `MyParametric trait /// classes we defined previously. class MyOp : public Op::Impl> {}; +/// Here we define 'MyType' along with the 'MyTrait' and `MyParametric trait +/// classes we defined previously. +class MyType : public Type::TypeBase::Impl> {}; ``` -To use a trait in the [ODS](OpDefinitions.md) framework, we need to provide a -definition of the trait class. This can be done using the `NativeOpTrait` and -`ParamNativeOpTrait` classes. `ParamNativeOpTrait` provides a mechanism in which -to specify arguments to a parametric trait class with an internal `Impl`. +### Attaching Operation Traits in ODS + +To use an operation trait in the [ODS](OpDefinitions.md) framework, we need to +provide a definition of the trait class. This can be done using the +`NativeOpTrait` and `ParamNativeOpTrait` classes. `ParamNativeOpTrait` provides +a mechanism in which to specify arguments to a parametric trait class with an +internal `Impl`. ```tablegen // The argument is the c++ trait class name. @@ -110,14 +123,14 @@ ## Using a Trait Traits may be used to provide additional methods, static fields, or other -information directly on the concrete operation. `Traits` internally become -`Base` classes of the concrete operation, so all of these are directly -accessible. To expose this information opaquely to transformations and analyses, +information directly on the concrete object.
`Traits` internally become `Base` +classes of the concrete object, so all of these are directly accessible. To +expose this information opaquely to transformations and analyses, [`interfaces`](Interfaces.md) may be used. -To query if a specific operation contains a specific trait, the `hasTrait<>` -method may be used. This takes as a template parameter the trait class, which is -the same as the one passed when attaching the trait to an operation. +To query if a specific object contains a specific trait, the `hasTrait<>` method +may be used. This takes as a template parameter the trait class, which is the +same as the one passed when attaching the trait to an operation. ```c++ Operation *op = ..; @@ -125,7 +138,7 @@ ...; ``` -## Trait List +## Operation Traits List MLIR provides a suite of traits that provide various functionalities that are common across many different operations. Below is a list of some key traits that diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -199,11 +199,11 @@ "Operation *", "clone", (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{ BlockAndValueMapping map; - unsigned numRegions = op.getOperation()->getNumRegions(); - Operation *res = create(b, loc, operands, op.getAttrs()); + unsigned numRegions = $_op.getOperation()->getNumRegions(); + Operation *res = create(b, loc, operands, $_op.getAttrs()); assert(res->getNumRegions() == numRegions && "inconsistent # regions"); for (unsigned ridx = 0; ridx < numRegions; ++ridx) - op.getOperation()->getRegion(ridx).cloneInto( + $_op.getOperation()->getRegion(ridx).cloneInto( &res->getRegion(ridx), map); return res; }] diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td @@ -1742,8 +1742,8 @@ } // This class represents a single, optionally static, interface method. -// Note: non-static interface methods have an implicit 'op' parameter -// corresponding to an instance of the derived operation. +// Note: non-static interface methods have an implicit parameter, either +// $_op/$_attr/$_type corresponding to an instance of the derived value. class InterfaceMethod { @@ -1773,8 +1773,8 @@ : InterfaceMethod; -// OpInterface represents an interface regarding an op. -class OpInterface : OpInterfaceTrait { +// Interface represents a base interface. +class Interface { // A human-readable description of what this interface does. string description = ""; @@ -1789,6 +1789,23 @@ code extraClassDeclaration = ""; } +// AttrInterface represents an interface registered to an attribute. +class AttrInterface : Interface { + // An optional code block containing extra declarations to place in the + // interface trait declaration. + code extraTraitClassDeclaration = ""; +} + +// OpInterface represents an interface registered to an operation. +class OpInterface : Interface, OpInterfaceTrait; + +// TypeInterface represents an interface registered to a type. +class TypeInterface : Interface { + // An optional code block containing extra declarations to place in the + // interface trait declaration. + code extraTraitClassDeclaration = ""; +} + // Whether to declare the op interface methods in the op's header. This class // simply wraps an OpInterface but is used to indicate that the method // declarations should be generated. This class takes an optional set of methods diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -33,7 +33,7 @@ "StringRef", "getName", (ins), [{ // Don't rely on the trait implementation as optional symbol operations // may override this. 
- return mlir::SymbolTable::getSymbolName(op); + return mlir::SymbolTable::getSymbolName($_op); }], /*defaultImplementation=*/[{ return mlir::SymbolTable::getSymbolName(this->getOperation()); }] diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -51,9 +51,9 @@ }], "Operation *", "resolveCallable", (ins), [{ // If the callable isn't a value, lookup the symbol reference. - CallInterfaceCallable callable = op.getCallableForCallee(); + CallInterfaceCallable callable = $_op.getCallableForCallee(); if (auto symbolRef = callable.dyn_cast()) - return SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef); return callable.get().getDefiningOp(); }] >, diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -55,10 +55,10 @@ }], "Optional", "getSuccessorBlockArgument", (ins "unsigned":$operandIndex), [{ - Operation *opaqueOp = op; + Operation *opaqueOp = $_op; for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) { if (Optional arg = detail::getBranchSuccessorArgument( - op.getSuccessorOperands(i), operandIndex, + $_op.getSuccessorOperands(i), operandIndex, opaqueOp->getSuccessor(i))) return arg; } diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td @@ -61,7 +61,7 @@ (ins "Value":$value, "SmallVectorImpl> &":$effects), [{ - op.getEffects(effects); + $_op.getEffects(effects); llvm::erase_if(effects, [&](auto &it) { return it.getValue() != value; }); @@ -75,7 +75,7 @@ 
(ins "SideEffects::Resource *":$resource, "SmallVectorImpl> &":$effects), [{ - op.getEffects(effects); + $_op.getEffects(effects); llvm::erase_if(effects, [&](auto &it) { return it.getResource() != resource; }); diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/Interfaces.h rename from mlir/include/mlir/TableGen/OpInterfaces.h rename to mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/OpInterfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -1,17 +1,13 @@ -//===- OpInterfaces.h - OpInterfaces wrapper class --------------*- C++ -*-===// +//===- Interfaces.h - Interface wrapper classes -----------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// OpInterfaces wrapper to simplify using TableGen OpInterfaces. -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_TABLEGEN_OPINTERFACES_H_ -#define MLIR_TABLEGEN_OPINTERFACES_H_ +#ifndef MLIR_TABLEGEN_INTERFACES_H_ +#define MLIR_TABLEGEN_INTERFACES_H_ #include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" @@ -25,9 +21,9 @@ namespace mlir { namespace tblgen { -// Wrapper class with helper methods for accessing OpInterfaceMethod defined +// Wrapper class with helper methods for accessing InterfaceMethod defined // in TableGen. -class OpInterfaceMethod { +class InterfaceMethod { public: // This struct represents a single method argument. struct Argument { @@ -35,7 +31,7 @@ StringRef name; }; - explicit OpInterfaceMethod(const llvm::Record *def); + explicit InterfaceMethod(const llvm::Record *def); // Return the return type of this method. 
StringRef getReturnType() const; @@ -68,20 +64,20 @@ }; //===----------------------------------------------------------------------===// -// OpInterface +// Interface //===----------------------------------------------------------------------===// -// Wrapper class with helper methods for accessing OpInterfaces defined in +// Wrapper class with helper methods for accessing Interfaces defined in // TableGen. -class OpInterface { +class Interface { public: - explicit OpInterface(const llvm::Record *def); + explicit Interface(const llvm::Record *def); // Return the name of this interface. StringRef getName() const; // Return the methods of this interface. - ArrayRef getMethods() const; + ArrayRef getMethods() const; // Return the description of this method if it has one. llvm::Optional getDescription() const; @@ -95,15 +91,36 @@ // Return the verify method body if it has one. llvm::Optional getVerify() const; + // Returns the TableGen definition this interface was constructed from. + const llvm::Record &getDef() const { return *def; } + private: // The TableGen definition of this interface. const llvm::Record *def; // The methods of this interface. - SmallVector methods; + SmallVector methods; +}; + +// An interface that is registered to an Attribute. +struct AttrInterface : public Interface { + using Interface::Interface; + + static bool classof(const Interface *interface); }; +// An interface that is registered to an Operation. +struct OpInterface : public Interface { + using Interface::Interface; + static bool classof(const Interface *interface); +}; +// An interface that is registered to a Type.
+struct TypeInterface : public Interface { + using Interface::Interface; + + static bool classof(const Interface *interface); +}; } // end namespace tblgen } // end namespace mlir -#endif // MLIR_TABLEGEN_OPINTERFACES_H_ +#endif // MLIR_TABLEGEN_INTERFACES_H_ diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp rename from mlir/lib/TableGen/OpInterfaces.cpp rename to mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/OpInterfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -1,16 +1,12 @@ -//===- OpInterfaces.cpp - OpInterfaces class ------------------------------===// +//===- Interfaces.cpp - Interface classes ---------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// OpInterfaces wrapper to simplify using TableGen OpInterfaces. 
-// -//===----------------------------------------------------------------------===// -#include "mlir/TableGen/OpInterfaces.h" +#include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" @@ -19,7 +15,11 @@ using namespace mlir; using namespace mlir::tblgen; -OpInterfaceMethod::OpInterfaceMethod(const llvm::Record *def) : def(def) { +//===----------------------------------------------------------------------===// +// InterfaceMethod +//===----------------------------------------------------------------------===// + +InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) { llvm::DagInit *args = def->getValueAsDag("arguments"); for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { arguments.push_back( @@ -28,78 +28,112 @@ } } -StringRef OpInterfaceMethod::getReturnType() const { +StringRef InterfaceMethod::getReturnType() const { return def->getValueAsString("returnType"); } // Return the name of this method. -StringRef OpInterfaceMethod::getName() const { +StringRef InterfaceMethod::getName() const { return def->getValueAsString("name"); } // Return if this method is static. -bool OpInterfaceMethod::isStatic() const { +bool InterfaceMethod::isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); } // Return the body for this method if it has one. -llvm::Optional OpInterfaceMethod::getBody() const { +llvm::Optional InterfaceMethod::getBody() const { auto value = def->getValueAsString("body"); return value.empty() ? llvm::Optional() : value; } // Return the default implementation for this method if it has one. -llvm::Optional OpInterfaceMethod::getDefaultImplementation() const { +llvm::Optional InterfaceMethod::getDefaultImplementation() const { auto value = def->getValueAsString("defaultBody"); return value.empty() ? llvm::Optional() : value; } // Return the description of this method if it has one. 
-llvm::Optional OpInterfaceMethod::getDescription() const { +llvm::Optional InterfaceMethod::getDescription() const { auto value = def->getValueAsString("description"); return value.empty() ? llvm::Optional() : value; } -ArrayRef OpInterfaceMethod::getArguments() const { +ArrayRef InterfaceMethod::getArguments() const { return arguments; } -bool OpInterfaceMethod::arg_empty() const { return arguments.empty(); } +bool InterfaceMethod::arg_empty() const { return arguments.empty(); } + +//===----------------------------------------------------------------------===// +// Interface +//===----------------------------------------------------------------------===// + +Interface::Interface(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Interface") && + "must be subclass of TableGen 'Interface' class"); -OpInterface::OpInterface(const llvm::Record *def) : def(def) { auto *listInit = dyn_cast(def->getValueInit("methods")); for (llvm::Init *init : listInit->getValues()) methods.emplace_back(cast(init)->getDef()); } // Return the name of this interface. -StringRef OpInterface::getName() const { +StringRef Interface::getName() const { return def->getValueAsString("cppClassName"); } // Return the methods of this interface. -ArrayRef OpInterface::getMethods() const { return methods; } +ArrayRef Interface::getMethods() const { return methods; } // Return the description of this method if it has one. -llvm::Optional OpInterface::getDescription() const { +llvm::Optional Interface::getDescription() const { auto value = def->getValueAsString("description"); return value.empty() ? llvm::Optional() : value; } // Return the interfaces extra class declaration code. -llvm::Optional OpInterface::getExtraClassDeclaration() const { +llvm::Optional Interface::getExtraClassDeclaration() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? llvm::Optional() : value; } // Return the traits extra class declaration code. 
-llvm::Optional OpInterface::getExtraTraitClassDeclaration() const { +llvm::Optional Interface::getExtraTraitClassDeclaration() const { auto value = def->getValueAsString("extraTraitClassDeclaration"); return value.empty() ? llvm::Optional() : value; } // Return the body for this method if it has one. -llvm::Optional OpInterface::getVerify() const { +llvm::Optional Interface::getVerify() const { + // Only OpInterface supports the verify method. + if (!isa(this)) + return llvm::None; auto value = def->getValueAsString("verify"); return value.empty() ? llvm::Optional() : value; } + +//===----------------------------------------------------------------------===// +// AttrInterface +//===----------------------------------------------------------------------===// + +bool AttrInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("AttrInterface"); +} + +//===----------------------------------------------------------------------===// +// OpInterface +//===----------------------------------------------------------------------===// + +bool OpInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("OpInterface"); +} + +//===----------------------------------------------------------------------===// +// TypeInterface +//===----------------------------------------------------------------------===// + +bool TypeInterface::classof(const Interface *interface) { + return interface->getDef().isSubClassOf("TypeInterface"); +} diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/OpTrait.h" -#include "mlir/TableGen/OpInterfaces.h" +#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Predicate.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" diff --git 
a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -3,6 +3,11 @@ TestPatterns.cpp ) +set(LLVM_TARGET_DEFINITIONS TestInterfaces.td) +mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRTestInterfaceIncGen) + set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) @@ -22,6 +27,7 @@ EXCLUDE_FROM_LIBMLIR DEPENDS + MLIRTestInterfaceIncGen MLIRTestOpsIncGen LINK_LIBS PUBLIC diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" @@ -135,9 +137,21 @@ >(); addInterfaces(); + addTypes(); allowUnknownOperations(); } +Type TestDialect::parseType(DialectAsmParser &parser) const { + if (failed(parser.parseKeyword("test_type"))) + return Type(); + return TestType::get(getContext()); +} + +void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { + assert(type.isa() && "unexpected type"); + printer << "test_type"; +} + LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") @@ -527,6 +541,7 @@ #include "TestOpEnums.cpp.inc" #include "TestOpStructs.cpp.inc" +#include "TestTypeInterfaces.cpp.inc" #define GET_OP_CLASSES #include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td 
b/mlir/test/lib/Dialect/Test/TestInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -0,0 +1,46 @@ +//===-- TestInterfaces.td - Test dialect interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_INTERFACES +#define TEST_INTERFACES + +include "mlir/IR/OpBase.td" + +// A type interface used to test the ODS generation of type interfaces. +def TestTypeInterface : TypeInterface<"TestTypeInterface"> { + let methods = [ + InterfaceMethod<"Prints the type name.", + "void", "printTypeA", (ins "Location":$loc), [{ + emitRemark(loc) << $_type << " - TestA"; + }] + >, + InterfaceMethod<"Prints the type name.", + "void", "printTypeB", (ins "Location":$loc), + [{}], /*defaultImplementation=*/[{ + emitRemark(loc) << $_type << " - TestB"; + }] + >, + InterfaceMethod<"Prints the type name.", + "void", "printTypeC", (ins "Location":$loc) + >, + ]; + let extraClassDeclaration = [{ + /// Prints the type name. + void printTypeD(Location loc) const { + emitRemark(loc) << *this << " - TestD"; + } + }]; + let extraTraitClassDeclaration = [{ + /// Prints the type name. 
+ void printTypeE(Location loc) const { + emitRemark(loc) << $_type << " - TestE"; + } + }]; +} + +#endif // TEST_INTERFACES diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -90,6 +90,10 @@ ); } +def TEST_TestType : DialectType<Test_Dialect, CPred<"$_self.isa<::mlir::TestType>()">, "test">, + BuildableType<"$_builder.getType<::mlir::TestType>()">; + //===----------------------------------------------------------------------===// // Test Symbols //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -0,0 +1,44 @@ +//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTTYPES_H +#define MLIR_TESTTYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +#include "TestTypeInterfaces.h.inc" + +/// This class is a simple test type that uses a generated interface. 
+struct TestType : public Type::TypeBase { + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE; + } + + static TestType get(MLIRContext *context) { + return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE); + } + + /// Provide a definition for the necessary interface methods.
+ void printTypeC(Location loc) const { + emitRemark(loc) << *this << " - TestC"; + } +}; +} // end namespace mlir + +#endif // MLIR_TESTTYPES_H diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR TestFunc.cpp + TestInterfaces.cpp TestMatchers.cpp TestSideEffects.cpp TestSymbolUses.cpp diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestInterfaces.cpp @@ -0,0 +1,41 @@ +//===- TestInterfaces.cpp - Test interface generation and application -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// This test checks various aspects of Type interface generation and +/// application. 
+struct TestTypeInterfaces + : public PassWrapper<TestTypeInterfaces, OperationPass<ModuleOp>> { + void runOnOperation() override { + getOperation().walk([](Operation *op) { + for (Type type : op->getResultTypes()) { + if (auto testInterface = type.dyn_cast<TestTypeInterface>()) { + testInterface.printTypeA(op->getLoc()); + testInterface.printTypeB(op->getLoc()); + testInterface.printTypeC(op->getLoc()); + testInterface.printTypeD(op->getLoc()); + } + if (auto testType = type.dyn_cast<TestType>()) + testType.printTypeE(op->getLoc()); + } + }); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestInterfaces() { + PassRegistration<TestTypeInterfaces> pass("test-type-interfaces", + "Test type interface support."); +} +} // namespace mlir diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/interfaces.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt -test-type-interfaces -allow-unregistered-dialect -verify-diagnostics %s + +// expected-remark@below {{'!test.test_type' - TestA}} +// expected-remark@below {{'!test.test_type' - TestB}} +// expected-remark@below {{'!test.test_type' - TestC}} +// expected-remark@below {{'!test.test_type' - TestD}} +// expected-remark@below {{'!test.test_type' - TestE}} +%foo0 = "foo.test"() : () -> (!test.test_type) + +// Type without the test interface.
+%foo1 = "foo.test"() : () -> (i32) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -51,6 +51,7 @@ void registerTestExpandTanhPass(); void registerTestFunc(); void registerTestGpuMemoryPromotionPass(); +void registerTestInterfaces(); void registerTestLinalgHoisting(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); @@ -125,6 +126,7 @@ registerTestFunc(); registerTestExpandTanhPass(); registerTestGpuMemoryPromotionPass(); + registerTestInterfaces(); registerTestLinalgHoisting(); registerTestLinalgTransforms(); registerTestLivenessPass(); diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -12,8 +12,8 @@ #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/Sequence.h" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -14,8 +14,8 @@ #include "OpFormatGen.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/SideEffects.h" @@ -1444,7 +1444,7 @@ alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), alwaysDeclaredMethodsVec.end()); - for (const OpInterfaceMethod &method : interface.getMethods()) { + for (const InterfaceMethod &method : interface.getMethods()) { // Don't declare if the method has a body. 
if (method.getBody()) continue; @@ -1457,7 +1457,7 @@ std::string args; llvm::raw_string_ostream os(args); interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { + [&](const InterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; }); opClass.newMethod(method.getReturnType(), method.getName(), os.str(), diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -10,8 +10,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/OpClass.h" -#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/MapVector.h" diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -13,7 +13,7 @@ #include "DocGenUtilities.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" -#include "mlir/TableGen/OpInterfaces.h" +#include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -22,70 +22,160 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -using namespace llvm; using namespace mlir; +using mlir::tblgen::Interface; +using mlir::tblgen::InterfaceMethod; using mlir::tblgen::OpInterface; -using mlir::tblgen::OpInterfaceMethod; -// Emit the method name and argument list for the given method. If -// 'addOperationArg' is true, then an Operation* argument is added to the -// beginning of the argument list. 
-static void emitMethodNameAndArgs(const OpInterfaceMethod &method, - raw_ostream &os, bool addOperationArg) { +/// Emit a string corresponding to a C++ type, followed by a space if necessary. +static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { + type = type.trim(); + os << type; + if (type.back() != '&' && type.back() != '*') + os << " "; + return os; +} + +/// Emit the method name and argument list for the given method. If 'addThisArg' +/// is true, then an argument is added to the beginning of the argument list for +/// the concrete value. +static void emitMethodNameAndArgs(const InterfaceMethod &method, + raw_ostream &os, StringRef valueType, + bool addThisArg, bool addConst) { os << method.getName() << '('; - if (addOperationArg) - os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", "); + if (addThisArg) + emitCPPType(valueType, os) + << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); llvm::interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { + [&](const InterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; }); os << ')'; + if (addConst) + os << " const"; } -// Get an array of all OpInterface definitions but exclude those subclassing -// "DeclareOpInterfaceMethods". -static std::vector<Record *> -getAllOpInterfaceDefinitions(const RecordKeeper &recordKeeper) { - std::vector<Record *> defs = +/// Get an array of all OpInterface definitions but exclude those subclassing +/// "DeclareOpInterfaceMethods". 
+static std::vector<llvm::Record *> +getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { + std::vector<llvm::Record *> defs = recordKeeper.getAllDerivedDefinitions("OpInterface"); - llvm::erase_if(defs, [](const Record *def) { + llvm::erase_if(defs, [](const llvm::Record *def) { return def->isSubClassOf("DeclareOpInterfaceMethods"); }); return defs; } +namespace { +/// This struct is the base generator used when processing tablegen interfaces.
+class InterfaceGenerator { +public: + bool emitInterfaceDefs(); + bool emitInterfaceDecls(); + bool emitInterfaceDocs(); + +protected: + InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os) + : defs(std::move(defs)), os(os) {} + + void emitConceptDecl(Interface &interface); + void emitModelDecl(Interface &interface); + void emitTraitDecl(Interface &interface, StringRef interfaceName, + StringRef interfaceTraitsName); + void emitInterfaceDecl(Interface interface); + + /// The set of interface records to emit. + std::vector<llvm::Record *> defs; + /// The stream to emit to. + raw_ostream &os; + /// The C++ value type of the interface, e.g. Operation*. + StringRef valueType; + /// The C++ base interface type. + StringRef interfaceBaseType; + /// The name of the typename for the value template. + StringRef valueTemplate; + /// The format context to use for methods. + tblgen::FmtContext nonStaticMethodFmt; + tblgen::FmtContext traitMethodFmt; +}; + +/// A specialized generator for attribute interfaces. +struct AttrInterfaceGenerator : public InterfaceGenerator { + AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), + os) { + valueType = "Attribute"; + interfaceBaseType = "AttrInterface"; + valueTemplate = "ConcreteAttr"; + StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())"; + nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode); + traitMethodFmt.addSubst("_attr", + "(*static_cast<const ConcreteAttr *>(this))"); + } +}; +/// A specialized generator for operation interfaces.
+struct OpInterfaceGenerator : public InterfaceGenerator { + OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { + valueType = "Operation *"; + interfaceBaseType = "OpInterface"; + valueTemplate = "ConcreteOp"; + StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))"; + nonStaticMethodFmt.withOp(castCode).withSelf(castCode); + traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))"); + } +}; +/// A specialized generator for type interfaces. +struct TypeInterfaceGenerator : public InterfaceGenerator { + TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), + os) { + valueType = "Type"; + interfaceBaseType = "TypeInterface"; + valueTemplate = "ConcreteType"; + StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())"; + nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode); + traitMethodFmt.addSubst("_type", + "(*static_cast<const ConcreteType *>(this))"); + } +}; +} // end anonymous namespace + //===----------------------------------------------------------------------===// // GEN: Interface definitions //===----------------------------------------------------------------------===// -static void emitInterfaceDef(OpInterface &interface, raw_ostream &os) { +static void emitInterfaceDef(Interface interface, StringRef valueType, + raw_ostream &os) { StringRef interfaceName = interface.getName(); // Insert the method definitions. + bool isOpInterface = isa<OpInterface>(interface); for (auto &method : interface.getMethods()) { - os << method.getReturnType() << " " << interfaceName << "::"; - emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); + emitCPPType(method.getReturnType(), os) << interfaceName << "::"; + emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); // Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '('; - if (!method.isStatic()) - os << "getOperation()" << (method.arg_empty() ? "" : ", "); + if (!method.isStatic()) { + os << (isOpInterface ? "getOperation()" : "*this"); + os << (method.arg_empty() ? "" : ", "); + } llvm::interleaveComma( method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); + [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; } } -static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("Operation Interface Definitions", os); +bool InterfaceGenerator::emitInterfaceDefs() { + llvm::emitSourceFileHeader("Interface Definitions", os); - for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) { - OpInterface interface(def); - emitInterfaceDef(interface, os); - } + for (const auto *def : defs) + emitInterfaceDef(Interface(def), valueType, os); return false; } @@ -93,122 +183,135 @@ // GEN: Interface declarations //===----------------------------------------------------------------------===// -static void emitConceptDecl(OpInterface &interface, raw_ostream &os) { +void InterfaceGenerator::emitConceptDecl(Interface &interface) { os << " class Concept {\n" << " public:\n" << " virtual ~Concept() = default;\n"; // Insert each of the pure virtual concept methods. 
for (auto &method : interface.getMethods()) { - os << " virtual " << method.getReturnType() << " "; - emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic()); + os << " virtual "; + emitCPPType(method.getReturnType(), os); + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), /*addConst=*/true); os << " = 0;\n"; } os << " };\n"; } -static void emitModelDecl(OpInterface &interface, raw_ostream &os) { - os << " template<typename ConcreteOp>\n"; - os << " class Model : public Concept {\npublic:\n"; +void InterfaceGenerator::emitModelDecl(Interface &interface) { + os << " template<typename " << valueTemplate << ">\n"; + os << " class Model : public Concept {\n public:\n"; // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { - os << " " << method.getReturnType() << " "; - emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic()); - os << " final {\n"; - - // Provide a definition of the concrete op if this is non static. - if (!method.isStatic()) { - os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n" - << " (void)op;\n"; - } + emitCPPType(method.getReturnType(), os << " "); + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), /*addConst=*/true); + os << " final {\n "; // Check for a provided body to the function. - if (auto body = method.getBody()) { - os << body << "\n }\n"; + if (Optional<StringRef> body = method.getBody()) { + if (method.isStatic()) + os << body->trim(); + else + os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); + os << "\n }\n"; continue; } // Forward to the method on the concrete operation type. - os << " return " << (method.isStatic() ? "ConcreteOp::" : "op."); + if (method.isStatic()) + os << "return " << valueTemplate << "::"; + else + os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); // Add the arguments to the call.
os << method.getName() << '('; llvm::interleaveComma( method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { os << arg.name; }); + [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n }\n"; } os << " };\n"; } -static void emitTraitDecl(OpInterface &interface, raw_ostream &os, - StringRef interfaceName, - StringRef interfaceTraitsName) { - os << " template <typename ConcreteOp>\n " - << llvm::formatv("struct {0}Trait : public OpInterface<{0}," - " detail::{1}>::Trait<ConcreteOp> {{\n", - interfaceName, interfaceTraitsName); +void InterfaceGenerator::emitTraitDecl(Interface &interface, + StringRef interfaceName, + StringRef interfaceTraitsName) { + os << llvm::formatv(" template <typename {3}>\n" + " struct {0}Trait : public {2}<{0}," + " detail::{1}>::Trait<{3}> {{\n", + interfaceName, interfaceTraitsName, interfaceBaseType, + valueTemplate); // Insert the default implementation for any methods. + bool isOpInterface = isa<OpInterface>(interface); for (auto &method : interface.getMethods()) { // Flag interface methods named verifyTrait. if (method.getName() == "verifyTrait") PrintFatalError( formatv("'verifyTrait' method cannot be specified as interface " - "method for '{0}'; set 'verify' on OpInterfaceTrait instead", + "method for '{0}'; use the 'verify' field instead", interfaceName)); auto defaultImpl = method.getDefaultImplementation(); if (!defaultImpl) continue; - os << " " << (method.isStatic() ? "static " : "") << method.getReturnType() - << " "; - emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); - os << " {\n" << defaultImpl.getValue() << " }\n"; + os << " " << (method.isStatic() ?
"static " : ""); + emitCPPType(method.getReturnType(), os); + emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); + os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) + << "\n }\n"; } - tblgen::FmtContext traitCtx; - traitCtx.withOp("op"); if (auto verify = interface.getVerify()) { - os << " static LogicalResult verifyTrait(Operation* op) {\n" - << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n"; + assert(isa<OpInterface>(interface) && "only OpInterfaces support 'verify'"); + + tblgen::FmtContext verifyCtx; + verifyCtx.withOp("op"); + os << " static LogicalResult verifyTrait(Operation *op) {\n " + << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) - os << extraTraitDecls << "\n"; + os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; // Emit a utility wrapper trait class. - os << " template <typename ConcreteOp>\n " - << llvm::formatv("struct Trait : public {0}Trait<ConcreteOp> {{};\n", - interfaceName); + os << llvm::formatv(" template <typename {1}>\n" + " struct Trait : public {0}Trait<{1}> {{};\n", + interfaceName, valueTemplate); } -static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) { +void InterfaceGenerator::emitInterfaceDecl(Interface interface) { StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" << "struct " << interfaceTraitsName << " {\n"; - emitConceptDecl(interface, os); - emitModelDecl(interface, os); + emitConceptDecl(interface); + emitModelDecl(interface); os << "};\n} // end namespace detail\n"; // Emit the main interface class declaration.
- os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" + os << llvm::formatv("class {0} : public {3}<{1}, detail::{2}> {\n" "public:\n" - " using OpInterface<{1}, detail::{2}>::OpInterface;\n", - interfaceName, interfaceName, interfaceTraitsName); + " using {3}<{1}, detail::{2}>::{3};\n", + interfaceName, interfaceName, interfaceTraitsName, + interfaceBaseType); // Emit the derived trait for the interface. - emitTraitDecl(interface, os, interfaceName, interfaceTraitsName); + emitTraitDecl(interface, interfaceName, interfaceTraitsName); // Insert the method declarations. + bool isOpInterface = isa<OpInterface>(interface); for (auto &method : interface.getMethods()) { - os << " " << method.getReturnType() << " "; - emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); + emitCPPType(method.getReturnType(), os << " "); + emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); os << ";\n"; } @@ -219,14 +322,11 @@ os << "};\n"; } -static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("Operation Interface Declarations", os); +bool InterfaceGenerator::emitInterfaceDecls() { + llvm::emitSourceFileHeader("Interface Declarations", os); - for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) { - OpInterface interface(def); - emitInterfaceDecl(interface, os); - } + for (const auto *def : defs) + emitInterfaceDecl(Interface(def)); return false; } @@ -234,17 +334,9 @@ // GEN: Interface documentation //===----------------------------------------------------------------------===// -/// Emit a string corresponding to a C++ type, followed by a space if necessary.
-static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { - type = type.trim(); - os << type; - if (type.back() != '&' && type.back() != '*') - os << " "; - return os; -} - -static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) { - OpInterface interface(&interfaceDef); +static void emitInterfaceDoc(const llvm::Record &interfaceDef, + raw_ostream &os) { + Interface interface(&interfaceDef); // Emit the interface name followed by the description. os << "## " << interface.getName() << " (" << interfaceDef.getName() << ")"; @@ -262,7 +354,7 @@ os << "static "; emitCPPType(method.getReturnType(), os) << method.getName() << '('; llvm::interleaveComma(method.getArguments(), os, - [&](const OpInterfaceMethod::Argument &arg) { + [&](const InterfaceMethod::Argument &arg) { emitCPPType(arg.type, os) << arg.name; }); os << ");\n```\n"; @@ -271,19 +363,17 @@ if (auto description = method.getDescription()) mlir::tblgen::emitDescription(*description, os); - // If the body is not provided, this method must be provided by the - // operation. + // If the body is not provided, this method must be provided by the user. if (!method.getBody()) - os << "\nNOTE: This method *must* be implemented by the operation.\n\n"; + os << "\nNOTE: This method *must* be implemented by the user.\n\n"; } } -static bool emitInterfaceDocs(const RecordKeeper &recordKeeper, - raw_ostream &os) { +bool InterfaceGenerator::emitInterfaceDocs() { os << "\n"; - os << "# Operation Interface definition\n"; + os << "# " << interfaceBaseType << " definitions\n"; - for (const auto *def : getAllOpInterfaceDefinitions(recordKeeper)) + for (const auto *def : defs) emitInterfaceDoc(*def, os); return false; } @@ -292,26 +382,31 @@ // GEN: Interface registration hooks //===----------------------------------------------------------------------===// -// Registers the operation interface generator to mlir-tblgen. 
-static mlir::GenRegistration - genInterfaceDecls("gen-op-interface-decls", - "Generate op interface declarations", - [](const RecordKeeper &records, raw_ostream &os) { - return emitInterfaceDecls(records, os); - }); - -// Registers the operation interface generator to mlir-tblgen. -static mlir::GenRegistration - genInterfaceDefs("gen-op-interface-defs", - "Generate op interface definitions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitInterfaceDefs(records, os); - }); - -// Registers the operation interface document generator to mlir-tblgen. -static mlir::GenRegistration - genInterfaceDocs("gen-op-interface-doc", - "Generate op interface documentation", - [](const RecordKeeper &records, raw_ostream &os) { - return emitInterfaceDocs(records, os); - }); +namespace { +template <typename GeneratorT> +struct InterfaceGenRegistration { + InterfaceGenRegistration(StringRef genArg) + : genDeclArg(("gen-" + genArg + "-interface-decls").str()), + genDefArg(("gen-" + genArg + "-interface-defs").str()), + genDocArg(("gen-" + genArg + "-interface-docs").str()), + genDecls(genDeclArg, "Generate interface declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return GeneratorT(records, os).emitInterfaceDecls(); + }), + genDefs(genDefArg, "Generate interface definitions", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return GeneratorT(records, os).emitInterfaceDefs(); + }), + genDocs(genDocArg, "Generate interface documentation", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return GeneratorT(records, os).emitInterfaceDocs(); + }) {} + + std::string genDeclArg, genDefArg, genDocArg; + mlir::GenRegistration genDecls, genDefs, genDocs; +}; +} // end anonymous namespace + +static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr"); +static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op"); +static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type"); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- 
a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -205,7 +205,8 @@ << " public:\n" << " virtual ~Concept() = default;\n" << " virtual " << availability.getQueryFnRetType() << " " - << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n" + << availability.getQueryFnName() + << "(Operation *tblgen_opaque_op) const = 0;\n" << " };\n"; } @@ -215,7 +216,7 @@ << " public:\n" << " " << availability.getQueryFnRetType() << " " << availability.getQueryFnName() - << "(Operation *tblgen_opaque_op) final {\n" + << "(Operation *tblgen_opaque_op) const final {\n" << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" << " (void)op;\n" // Forward to the method on the concrete operation type.