diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2049,6 +2049,12 @@
   // An optional code block containing extra declarations to place in both
   // the interface and trait declaration.
   code extraSharedClassDeclaration = "";
+
+  // An optional code block for adding additional "classof" logic. This can
+  // be used to better enable "optional" interfaces, where an entity only
+  // implements the interface if some dynamic characteristic holds. `$_self`
+  // may be used to refer to an instance of this interface being checked.
+  code extraClassOf = "";
 }
 
 // AttrInterface represents an interface registered to an attribute.
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -174,28 +174,7 @@
     return success();
   }];
 
-  let extraClassDeclaration = [{
-    /// Convenience version of `getNameAttr` that returns a StringRef.
-    StringRef getName() {
-      return getNameAttr().getValue();
-    }
-
-    /// Convenience version of `setName` that take a StringRef.
-    void setName(StringRef name) {
-      setName(StringAttr::get(this->getContext(), name));
-    }
-
-    /// Custom classof that handles the case where the symbol is optional.
-    static bool classof(Operation *op) {
-      auto *opConcept = getInterfaceFor(op);
-      if (!opConcept)
-        return false;
-      return !opConcept->isOptionalSymbol(opConcept, op) ||
-             op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
-    }
-  }];
-
-  let extraTraitClassDeclaration = [{
+  let extraSharedClassDeclaration = [{
     using Visibility = mlir::SymbolTable::Visibility;
 
     /// Convenience version of `getNameAttr` that returns a StringRef.
@@ -208,6 +187,12 @@
       setName(StringAttr::get($_op->getContext(), name));
     }
   }];
+
+  // Add additional classof checks to properly handle "optional" symbols.
+  let extraClassOf = [{
+    return !$_self.isOptionalSymbol() ||
+           $_self->getAttr(::mlir::SymbolTable::getSymbolAttrName());
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -110,6 +110,13 @@
            "expected value to provide interface instance");
   }
 
+  /// Constructor for a known concept. When `t` is non-null, `conceptImpl` must
+  /// be the concept that `getInterfaceFor(t)` returns; this is asserted.
+  Interface(ValueT t, Concept *conceptImpl)
+      : BaseType(t), conceptImpl(conceptImpl) {
+    assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
+  }
+
   /// Constructor for DenseMapInfo's empty key and tombstone key.
   Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
 
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -95,6 +95,9 @@
   // trait classes.
   std::optional<StringRef> getExtraSharedClassDeclaration() const;
 
+  // Return the extra classof method code.
+  std::optional<StringRef> getExtraClassOf() const;
+
   // Return the verify method body if it has one.
   std::optional<StringRef> getVerify() const;
 
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -116,6 +116,11 @@
   return value.empty() ? std::optional<StringRef>() : value;
 }
 
+std::optional<StringRef> Interface::getExtraClassOf() const {
+  auto value = def->getValueAsString("extraClassOf");
+  return value.empty() ? std::optional<StringRef>() : value;
+}
+
 // Return the body for this method if it has one.
 std::optional<StringRef> Interface::getVerify() const {
   // Only OpInterface supports the verify method.
diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -3,6 +3,20 @@
 
 include "mlir/IR/OpBase.td"
 
+def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
+  let extraClassOf = "return $_self.someOtherMethod();";
+}
+
+// DECL: class ExtraClassOfInterface
+// DECL: static bool classof(::mlir::Operation * base) {
+// DECL-NEXT: auto *interfaceConcept = getInterfaceFor(base);
+// DECL-NEXT: if (!interfaceConcept)
+// DECL-NEXT: return false;
+// DECL-NEXT: ExtraClassOfInterface iface(base, interfaceConcept);
+// DECL-NEXT: (void)iface;
+// DECL-NEXT: return iface.someOtherMethod();
+// DECL-NEXT: }
+
 def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
   let extraSharedClassDeclaration = [{
     bool sharedMethodDeclaration() {
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -508,6 +508,21 @@
           interface.getExtraSharedClassDeclaration())
     os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
 
+  // Emit classof code if necessary. The generated local is deliberately not
+  // named `concept`, which is a reserved keyword as of C++20.
+  if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
+    auto extraClassOfFmt = tblgen::FmtContext();
+    extraClassOfFmt.withSelf("iface");
+    os << " static bool classof(" << valueType << " base) {\n"
+       << " auto *interfaceConcept = getInterfaceFor(base);\n"
+          " if (!interfaceConcept)\n"
+          " return false;\n"
+       << llvm::formatv(" {0} iface(base, interfaceConcept);\n", interfaceName)
+       << " (void)iface;\n"
+       << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
+       << "\n }\n";
+  }
+
   os << "};\n";
 
   os << "namespace detail {\n";