diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1503,6 +1503,10 @@
 
   // The list of methods defined by this interface.
   list<InterfaceMethod> methods = [];
+
+  // An optional code block containing extra declarations to place verbatim in
+  // the generated interface class, after the interface method declarations.
+  code extraClassDeclaration = "";
 }
 
 // Whether to declare the op interface methods in the op's header. This class
diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h
--- a/mlir/include/mlir/TableGen/OpInterfaces.h
+++ b/mlir/include/mlir/TableGen/OpInterfaces.h
@@ -86,6 +86,9 @@
   // Return the description of this method if it has one.
   llvm::Optional<StringRef> getDescription() const;
 
+  // Return the interface's extra class declaration code if it has any.
+  llvm::Optional<StringRef> getExtraClassDeclaration() const;
+
   // Return the verify method body if it has one.
   llvm::Optional<StringRef> getVerify() const;
 
diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp
--- a/mlir/lib/TableGen/OpInterfaces.cpp
+++ b/mlir/lib/TableGen/OpInterfaces.cpp
@@ -86,6 +86,12 @@
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+// Return the interface's extra class declaration code if it has any.
+llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
+  auto value = def->getValueAsString("extraClassDeclaration");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
 // Return the body for this method if it has one.
 llvm::Optional<StringRef> OpInterface::getVerify() const {
   auto value = def->getValueAsString("verify");
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -205,6 +205,11 @@
     emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
     os << ";\n";
   }
+
+  // Emit any extra declarations.
+  if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
+    os << *extraDecls << "\n";
+
   os << "};\n";
 }
 