diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -393,6 +393,18 @@ These declarations are _not_ implicitly visible in default implementations of interface methods, but static declarations may be accessed with full name qualification. +* Extra Shared Class Declarations (Optional: `extraSharedClassDeclaration`) + - Additional C++ code that is injected into the declarations of both the + interface and trait class. This allows for defining methods and more + that are exposed on both the interface and trait class, e.g. to inject + utilities on both the interface and the derived entity implementing the + interface (e.g. attribute, operation, etc.). + - In non-static methods, `$_attr`/`$_op`/`$_type` + (depending on the type of interface) may be used to refer to an + instance of the IR entity. In the interface declaration, the type of + the instance is the interface class. In the trait declaration, the + type of the instance is the concrete entity class + (e.g. `IntegerAttr`, `FuncOp`, etc.). `OpInterface` classes may additionally contain the following: diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2226,6 +2226,10 @@ // An optional code block containing extra declarations to place in the // interface declaration. code extraClassDeclaration = ""; + + // An optional code block containing extra declarations to place in both + // the interface and trait declaration. + code extraSharedClassDeclaration = ""; } // AttrInterface represents an interface registered to an attribute. diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -91,6 +91,10 @@ // Return the traits extra class declaration code.
llvm::Optional<StringRef> getExtraTraitClassDeclaration() const; + // Return the extra class declaration code shared between the interface and + // trait classes. + llvm::Optional<StringRef> getExtraSharedClassDeclaration() const; + // Return the verify method body if it has one. llvm::Optional<StringRef> getVerify() const; diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -110,6 +110,12 @@ return value.empty() ? llvm::Optional<StringRef>() : value; } +// Return the extra declarations shared by the interface and trait classes. +llvm::Optional<StringRef> Interface::getExtraSharedClassDeclaration() const { + auto value = def->getValueAsString("extraSharedClassDeclaration"); + return value.empty() ? llvm::Optional<StringRef>() : value; +} + // Return the body for this method if it has one. llvm::Optional<StringRef> Interface::getVerify() const { // Only OpInterface supports the verify method. diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -3,6 +3,24 @@ include "mlir/IR/OpBase.td" +def ExtraSharedDeclsInterface : OpInterface<"ExtraSharedDeclsInterface"> { + let extraSharedClassDeclaration = [{ + bool sharedMethodDeclaration() { + return $_op.someOtherMethod(); + } + }]; +} + +// DECL: class ExtraSharedDeclsInterface +// DECL: bool sharedMethodDeclaration() { +// DECL-NEXT: return (*this).someOtherMethod(); +// DECL-NEXT: } + +// DECL: struct ExtraSharedDeclsInterfaceTrait +// DECL: bool sharedMethodDeclaration() { +// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod(); +// DECL-NEXT: } + def TestOpInterface : OpInterface<"TestOpInterface"> { let description = [{some op interface description}]; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -104,6 +104,7 @@ /// The format
context to use for methods. tblgen::FmtContext nonStaticMethodFmt; tblgen::FmtContext traitMethodFmt; + tblgen::FmtContext extraDeclsFmt; }; /// A specialized generator for attribute interfaces. @@ -118,6 +119,7 @@ nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode); traitMethodFmt.addSubst("_attr", "(*static_cast<ConcreteAttr *>(this))"); + extraDeclsFmt.addSubst("_attr", "(*this)"); } }; /// A specialized generator for operation interfaces. @@ -132,6 +134,7 @@ .withOp(castCode) .withSelf(castCode); traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))"); + extraDeclsFmt.withOp("(*this)"); } }; /// A specialized generator for type interfaces. @@ -146,6 +149,7 @@ nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode); traitMethodFmt.addSubst("_type", "(*static_cast<ConcreteType *>(this))"); + extraDeclsFmt.addSubst("_type", "(*this)"); } }; } // namespace @@ -415,6 +419,8 @@ } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; + if (auto extraSharedDecls = interface.getExtraSharedClassDeclaration()) + os << tblgen::tgfmt(*extraSharedDecls, &traitMethodFmt) << "\n"; os << " };\n"; } @@ -469,6 +475,9 @@ // Emit any extra declarations. if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration()) os << *extraDecls << "\n"; + if (Optional<StringRef> extraDecls = + interface.getExtraSharedClassDeclaration()) + os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt) << "\n"; os << "};\n";