diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2048,6 +2048,13 @@ // An optional code block containing extra declarations to place in both // the interface and trait declaration. code extraSharedClassDeclaration = ""; + + // An optional code block for adding additional "classof" logic. This can + // be used to better enable "optional" interfaces, where an entity only + // implements the interface if some dynamic characteristic holds. + // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the + // entity being checked. + code extraClassOf = ""; } // AttrInterface represents an interface registered to an attribute. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -174,28 +174,7 @@ return success(); }]; - let extraClassDeclaration = [{ - /// Convenience version of `getNameAttr` that returns a StringRef. - StringRef getName() { - return getNameAttr().getValue(); - } - - /// Convenience version of `setName` that take a StringRef. - void setName(StringRef name) { - setName(StringAttr::get(this->getContext(), name)); - } - - /// Custom classof that handles the case where the symbol is optional. - static bool classof(Operation *op) { - auto *opConcept = getInterfaceFor(op); - if (!opConcept) - return false; - return !opConcept->isOptionalSymbol(opConcept, op) || - op->getAttr(::mlir::SymbolTable::getSymbolAttrName()); - } - }]; - - let extraTraitClassDeclaration = [{ + let extraSharedClassDeclaration = [{ using Visibility = mlir::SymbolTable::Visibility; /// Convenience version of `getNameAttr` that returns a StringRef. @@ -208,6 +187,11 @@ setName(StringAttr::get($_op->getContext(), name)); } }]; + + // Add additional classof checks to properly handle "optional" symbols. 
+ let extraClassOf = [{ + return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName()); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -110,6 +110,12 @@ "expected value to provide interface instance"); } + /// Constructor for a known concept. + Interface(ValueT t, Concept *conceptImpl) + : BaseType(t), conceptImpl(conceptImpl) { + assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl); + } + /// Constructor for DenseMapInfo's empty key and tombstone key. Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {} diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h --- a/mlir/include/mlir/TableGen/Format.h +++ b/mlir/include/mlir/TableGen/Format.h @@ -44,7 +44,6 @@ None, Custom, // For custom placeholders Builder, // For the $_builder placeholder - Op, // For the $_op placeholder Self, // For the $_self placeholder }; @@ -58,7 +57,6 @@ // Setters for builtin placeholders FmtContext &withBuilder(Twine subst); - FmtContext &withOp(Twine subst); FmtContext &withSelf(Twine subst); std::optional getSubstFor(PHKind placeholder) const; diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -95,6 +95,9 @@ // trait classes. std::optional getExtraSharedClassDeclaration() const; + // Return the extra classof method code. + std::optional getExtraClassOf() const; + // Return the verify method body if it has one. 
std::optional getVerify() const; diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp --- a/mlir/lib/TableGen/CodeGenHelpers.cpp +++ b/mlir/lib/TableGen/CodeGenHelpers.cpp @@ -190,7 +190,7 @@ const ConstraintMap &constraints, StringRef selfName, const char *const codeTemplate) { FmtContext ctx; - ctx.withOp("*op").withSelf(selfName); + ctx.addSubst("_op", "*op").withSelf(selfName); for (auto &it : constraints) { os << formatv(codeTemplate, it.second, tgfmt(it.first.getConditionTemplate(), &ctx), @@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() { FmtContext ctx; - ctx.withOp("*op").withBuilder("rewriter").withSelf("type"); + ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type"); for (auto &it : typeConstraints) { os << formatv(patternAttrOrTypeConstraintCode, it.second, tgfmt(it.first.getConditionTemplate(), &ctx), @@ -240,9 +240,9 @@ /// because ops use cached identifiers. static bool canUniqueAttrConstraint(Attribute attr) { FmtContext ctx; - auto test = - tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op")) - .str(); + auto test = tgfmt(attr.getConditionTemplate(), + &ctx.withSelf("attr").addSubst("_op", "*op")) + .str(); return !StringRef(test).contains(""); } diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp --- a/mlir/lib/TableGen/Format.cpp +++ b/mlir/lib/TableGen/Format.cpp @@ -38,11 +38,6 @@ return *this; } -FmtContext &FmtContext::withOp(Twine subst) { - builtinSubstMap[PHKind::Op] = subst.str(); - return *this; -} - FmtContext &FmtContext::withSelf(Twine subst) { builtinSubstMap[PHKind::Self] = subst.str(); return *this; @@ -69,7 +64,6 @@ FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) { return StringSwitch(str) .Case("_builder", FmtContext::PHKind::Builder) - .Case("_op", FmtContext::PHKind::Op) .Case("_self", FmtContext::PHKind::Self) .Case("", FmtContext::PHKind::None) .Default(FmtContext::PHKind::Custom); diff 
--git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -116,6 +116,11 @@ return value.empty() ? std::optional() : value; } +std::optional Interface::getExtraClassOf() const { + auto value = def->getValueAsString("extraClassOf"); + return value.empty() ? std::optional() : value; +} + // Return the body for this method if it has one. std::optional Interface::getVerify() const { // Only OpInterface supports the verify method. diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -4,6 +4,17 @@ include "mlir/IR/OpBase.td" +def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> { + let extraClassOf = "return $_op->someOtherMethod();"; +} + +// DECL: class ExtraClassOfInterface +// DECL: static bool classof(::mlir::Operation * base) { +// DECL-NEXT: if (!getInterfaceFor(base)) +// DECL-NEXT: return false; +// DECL-NEXT: return base->someOtherMethod(); +// DECL-NEXT: } + def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> { let extraSharedClassDeclaration = [{ bool sharedMethodDeclaration() { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -819,7 +819,7 @@ formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/true) { - verifyCtx.withOp("(*this->getOperation())"); + verifyCtx.addSubst("_op", "(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); genTraits(); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -108,6 +108,8 @@ 
StringRef interfaceBaseType; /// The name of the typename for the value template. StringRef valueTemplate; + /// The name of the substitution variable for the value. + StringRef substVar; /// The format context to use for methods. tblgen::FmtContext nonStaticMethodFmt; tblgen::FmtContext traitMethodFmt; @@ -121,11 +123,12 @@ valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; valueTemplate = "ConcreteAttr"; + substVar = "_attr"; StringRef castCode = "(tablegen_opaque_val.cast())"; - nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode); - traitMethodFmt.addSubst("_attr", + nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); + traitMethodFmt.addSubst(substVar, "(*static_cast(this))"); - extraDeclsFmt.addSubst("_attr", "(*this)"); + extraDeclsFmt.addSubst(substVar, "(*this)"); } }; /// A specialized generator for operation interfaces. @@ -135,12 +138,13 @@ valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; + substVar = "_op"; StringRef castCode = "(llvm::cast(tablegen_opaque_val))"; nonStaticMethodFmt.addSubst("_this", "impl") - .withOp(castCode) + .addSubst(substVar, castCode) .withSelf(castCode); - traitMethodFmt.withOp("(*static_cast(this))"); - extraDeclsFmt.withOp("(*this)"); + traitMethodFmt.addSubst(substVar, "(*static_cast(this))"); + extraDeclsFmt.addSubst(substVar, "(*this)"); } }; /// A specialized generator for type interfaces. 
@@ -150,11 +154,12 @@ valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType"; + substVar = "_type"; StringRef castCode = "(tablegen_opaque_val.cast())"; - nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode); - traitMethodFmt.addSubst("_type", + nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); + traitMethodFmt.addSubst(substVar, "(*static_cast(this))"); - extraDeclsFmt.addSubst("_type", "(*this)"); + extraDeclsFmt.addSubst(substVar, "(*this)"); } }; } // namespace @@ -434,7 +439,7 @@ assert(isa(interface) && "only OpInterface supports 'verify'"); tblgen::FmtContext verifyCtx; - verifyCtx.withOp("op"); + verifyCtx.addSubst("_op", "op"); os << llvm::formatv( " static ::mlir::LogicalResult {0}(::mlir::Operation *op) ", (interface.verifyWithRegions() ? "verifyRegionTrait" @@ -506,6 +511,17 @@ interface.getExtraSharedClassDeclaration()) os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt); + // Emit classof code if necessary. + if (std::optional extraClassOf = interface.getExtraClassOf()) { + auto extraClassOfFmt = tblgen::FmtContext(); + extraClassOfFmt.addSubst(substVar, "base"); + os << " static bool classof(" << valueType << " base) {\n" + << " if (!getInterfaceFor(base))\n" + " return false;\n" + << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt) + << "\n }\n"; + } + os << "};\n"; os << "namespace detail {\n"; diff --git a/mlir/unittests/TableGen/FormatTest.cpp b/mlir/unittests/TableGen/FormatTest.cpp --- a/mlir/unittests/TableGen/FormatTest.cpp +++ b/mlir/unittests/TableGen/FormatTest.cpp @@ -105,12 +105,6 @@ EXPECT_THAT(result, StrEq("bbb")); } -TEST(FormatTest, PlaceHolderFmtStrWithOp) { - FmtContext ctx; - std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo"))); - EXPECT_THAT(result, StrEq("ooo")); -} - TEST(FormatTest, PlaceHolderMissingCtx) { std::string result = std::string(tgfmt("$_op", nullptr)); EXPECT_THAT(result, StrEq("$_op"));