diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -14,6 +14,7 @@ #ifndef FIR_DIALECT_FIR_OPS #define FIR_DIALECT_FIR_OPS +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffects.td" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/GPU/GPUBase.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffects.td" // Type constraint accepting standard integers, indices and wrapped LLVM integer diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -14,6 +14,7 @@ #define LLVMIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffects.td" diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -16,6 +16,7 @@ #define SPIRV_STRUCTURE_OPS include "mlir/Dialect/SPIRV/SPIRVBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/SideEffects.td" diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -2,3 +2,8 @@ mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls) 
mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIROpAsmInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td) +mlir_tablegen(SymbolInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(SymbolInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRSymbolInterfacesIncGen) diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -30,11 +30,10 @@ /// implicitly capture global values, and all external references must use /// Function arguments or attributes that establish a symbolic connection(e.g. /// symbols referenced by name via a string attribute). -class FuncOp - : public Op { +class FuncOp : public Op { public: using Op::Op; using Op::print; diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -31,7 +31,8 @@ : public Op< ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult, OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable, - OpTrait::SingleBlockImplicitTerminator::Impl> { + OpTrait::SingleBlockImplicitTerminator::Impl, + SymbolOpInterface::Trait> { public: using Op::Op; using Op::print; @@ -95,6 +96,13 @@ insertPt = Block::iterator(body->getTerminator()); body->getOperations().insert(insertPt, op); } + + //===--------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===--------------------------------------------------------------------===// + + /// A ModuleOp may optionally define a symbol. 
+ bool isOptionalSymbol() { return true; } }; /// The ModuleTerminatorOp is a special terminator operation for the body of a diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1632,10 +1632,6 @@ // Op has the same operand and result element type (or type itself, if scalar). def SameOperandsAndResultElementType : NativeOpTrait<"SameOperandsAndResultElementType">; -// Op is a symbol. -def Symbol : NativeOpTrait<"Symbol">; -// Op defines a symbol table. -def SymbolTable : NativeOpTrait<"SymbolTable">; // Op is a terminator. def Terminator : NativeOpTrait<"IsTerminator">; @@ -1699,6 +1695,10 @@ // Specify the body of the verification function. `$_op` will be replaced with // the operation being verified. code verify = verifyBody; + + // An optional code block containing extra declarations to place in the + // interface trait declaration. + code extraTraitClassDeclaration = ""; } // This class represents a single, optionally static, interface method. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1353,6 +1353,7 @@ public: using Concept = typename Traits::Concept; template using Model = typename Traits::template Model; + using Base = OpInterface; OpInterface(Operation *op = nullptr) : Op(op), impl(op ? getInterfaceFor(op) : nullptr) { diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -0,0 +1,155 @@ +//===- SymbolInterfaces.td - Interfaces for symbol ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces and traits that can be used to define +// properties of symbol and symbol table operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SYMBOLINTERFACES +#define MLIR_IR_SYMBOLINTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// SymbolOpInterface +//===----------------------------------------------------------------------===// + +def Symbol : OpInterface<"SymbolOpInterface"> { + let description = [{ + This interface describes an operation that may define a `Symbol`. A `Symbol` + is a named operation that resides immediately within a region that defines + a `SymbolTable`. See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) + for more details. + }]; + + let methods = [ + InterfaceMethod<"Returns the name of this symbol.", + "StringRef", "getName", (ins), [{ + // Don't rely on the trait implementation as optional symbol operations + // may override this. 
+ return mlir::SymbolTable::getSymbolName(op); + }], /*defaultImplementation=*/[{ + return mlir::SymbolTable::getSymbolName(this->getOperation()); + }] + >, + InterfaceMethod<"Sets the name of this symbol.", + "void", "setName", (ins "StringRef":$name), [{}], + /*defaultImplementation=*/[{ + this->getOperation()->setAttr( + mlir::SymbolTable::getSymbolAttrName(), + StringAttr::get(name, this->getOperation()->getContext())); + }] + >, + InterfaceMethod<"Gets the visibility of this symbol.", + "mlir::SymbolTable::Visibility", "getVisibility", (ins), [{}], + /*defaultImplementation=*/[{ + return mlir::SymbolTable::getSymbolVisibility(this->getOperation()); + }] + >, + InterfaceMethod<"Sets the visibility of this symbol.", + "void", "setVisibility", (ins "mlir::SymbolTable::Visibility":$vis), [{}], + /*defaultImplementation=*/[{ + mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis); + }] + >, + InterfaceMethod<[{ + Get all of the uses of the current symbol that are nested within the + given operation 'from'. + Note: See mlir::SymbolTable::getSymbolUses for more details. + }], + "Optional<::mlir::SymbolTable::UseRange>", "getSymbolUses", + (ins "Operation *":$from), [{}], + /*defaultImplementation=*/[{ + return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from); + }] + >, + InterfaceMethod<[{ + Returns true if the current symbol is known to have no uses that are nested + within the given operation 'from'. + Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details. + }], + "bool", "symbolKnownUseEmpty", (ins "Operation *":$from), [{}], + /*defaultImplementation=*/[{ + return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), + from); + }] + >, + InterfaceMethod<[{ + Attempt to replace all uses of the current symbol with the provided + symbol 'newSymbol' that are nested within the given operation 'from'. + Note: See mlir::SymbolTable::replaceAllSymbolUses for more details. 
+ }], + "LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol, + "Operation *":$from), [{}], + /*defaultImplementation=*/[{ + return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), + newSymbol, from); + }] + >, + InterfaceMethod<[{ + Returns true if this operation optionally defines a symbol based on the + presence of the symbol name. + }], + "bool", "isOptionalSymbol", (ins), [{}], + /*defaultImplementation=*/[{ return false; }] + >, + InterfaceMethod<[{ + Returns true if this operation can be discarded if it has no remaining + symbol uses. + }], + "bool", "canDiscardOnUseEmpty", (ins), [{}], + /*defaultImplementation=*/[{ + // By default, base this on the visibility alone. A symbol can be + // discarded as long as it is not public. Only public symbols may be + // visible from outside of the IR. + return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + }] + >, + ]; + + let verify = [{ + // If this is an optional symbol, bail out early if possible. + auto concreteOp = cast($_op); + if (concreteOp.isOptionalSymbol()) { + if (!concreteOp.getAttr(::mlir::SymbolTable::getSymbolAttrName())) + return success(); + } + return ::mlir::OpTrait::impl::verifySymbol($_op); + }]; + + let extraClassDeclaration = [{ + using Visibility = mlir::SymbolTable::Visibility; + + /// Custom classof that handles the case where the symbol is optional. + static bool classof(Operation *op) { + return Base::classof(op) + && op->getAttr(::mlir::SymbolTable::getSymbolAttrName()); + } + + /// Returns true if this symbol has nested visibility. + bool isNested() { return getVisibility() == Visibility::Nested; } + /// Returns true if this symbol has private visibility. + bool isPrivate() { return getVisibility() == Visibility::Private; } + /// Returns true if this symbol has public visibility. 
+ bool isPublic() { return getVisibility() == Visibility::Public; } + }]; + + let extraTraitClassDeclaration = [{ + using Visibility = mlir::SymbolTable::Visibility; + }]; +} + +//===----------------------------------------------------------------------===// +// Symbol Traits +//===----------------------------------------------------------------------===// + +// Op defines a symbol table. +def SymbolTable : NativeOpTrait<"SymbolTable">; + +#endif // MLIR_IR_SYMBOLINTERFACES diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -71,9 +71,6 @@ Nested, }; - /// Returns true if the given operation defines a symbol. - static bool isSymbol(Operation *op); - /// Returns the name of the given symbol operation. static StringRef getSymbolName(Operation *symbol); /// Sets the name of the given symbol operation. @@ -229,68 +226,11 @@ } }; -/// A trait used to define a symbol that can be used on operations within a -/// symbol table. Operations using this trait must adhere to the following: -/// * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'. -template -class Symbol : public TraitBase { -public: - using Visibility = mlir::SymbolTable::Visibility; - - static LogicalResult verifyTrait(Operation *op) { - return impl::verifySymbol(op); - } - - /// Returns the name of this symbol. - StringRef getName() { - return this->getOperation() - ->template getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()) - .getValue(); - } - - /// Set the name of this symbol. - void setName(StringRef name) { - this->getOperation()->setAttr( - mlir::SymbolTable::getSymbolAttrName(), - StringAttr::get(name, this->getOperation()->getContext())); - } - - /// Returns the visibility of the current symbol. - Visibility getVisibility() { - return mlir::SymbolTable::getSymbolVisibility(this->getOperation()); - } - - /// Sets the visibility of the current symbol. 
- void setVisibility(Visibility vis) { - mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis); - } - - /// Get all of the uses of the current symbol that are nested within the given - /// operation 'from'. - /// Note: See mlir::SymbolTable::getSymbolUses for more details. - Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) { - return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from); - } - - /// Return if the current symbol is known to have no uses that are nested - /// within the given operation 'from'. - /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details. - bool symbolKnownUseEmpty(Operation *from) { - return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from); - } +} // end namespace OpTrait - /// Attempt to replace all uses of the current symbol with the provided symbol - /// 'newSymbol' that are nested within the given operation 'from'. - /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details. - LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol, - Operation *from) { - return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), - newSymbol, from); - } -}; +/// Include the generated symbol interfaces. +#include "mlir/IR/SymbolInterfaces.h.inc" -} // end namespace OpTrait } // end namespace mlir #endif // MLIR_IR_SYMBOLTABLE_H diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h --- a/mlir/include/mlir/TableGen/OpInterfaces.h +++ b/mlir/include/mlir/TableGen/OpInterfaces.h @@ -89,6 +89,9 @@ // Return the interfaces extra class declaration code. llvm::Optional getExtraClassDeclaration() const; + // Return the interface trait's extra class declaration code. + llvm::Optional getExtraTraitClassDeclaration() const; + // Return the verify method body if it has one. 
llvm::Optional getVerify() const; diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -8,6 +8,7 @@ DEPENDS MLIRCallInterfacesIncGen MLIROpAsmInterfacesIncGen + MLIRSymbolInterfacesIncGen ) target_link_libraries(MLIRIR PUBLIC diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -146,11 +146,6 @@ setSymbolName(symbol, nameBuffer); } -/// Returns true if the given operation defines a symbol. -bool SymbolTable::isSymbol(Operation *op) { - return op->hasTrait() || getNameIfSymbol(op).hasValue(); -} - /// Returns the name of the given symbol operation. StringRef SymbolTable::getSymbolName(Operation *symbol) { Optional name = getNameIfSymbol(symbol); @@ -866,3 +861,10 @@ Region *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } + +//===----------------------------------------------------------------------===// +// Symbol Interfaces +//===----------------------------------------------------------------------===// + +/// Include the generated symbol interfaces. +#include "mlir/IR/SymbolInterfaces.cpp.inc" diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp --- a/mlir/lib/TableGen/OpInterfaces.cpp +++ b/mlir/lib/TableGen/OpInterfaces.cpp @@ -92,6 +92,12 @@ return value.empty() ? llvm::Optional() : value; } +// Return the interface trait's extra class declaration code. +llvm::Optional OpInterface::getExtraTraitClassDeclaration() const { + auto value = def->getValueAsString("extraTraitClassDeclaration"); + return value.empty() ? llvm::Optional() : value; +} + // Return the body for this method if it has one. 
llvm::Optional OpInterface::getVerify() const { auto value = def->getValueAsString("verify"); diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -31,26 +31,6 @@ // Symbol Use Tracking //===----------------------------------------------------------------------===// -/// Returns true if this operation can be discarded if it is a symbol and has no -/// uses. 'allUsesVisible' corresponds to if the parent symbol table is hidden -/// from above. -static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) { - if (!SymbolTable::isSymbol(op)) - return false; - - // TODO: This is essentially the same logic from SymbolDCE. Remove this when - // we have a 'Symbol' interface. - // Private symbols are always initially considered dead. - SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op); - if (visibility == mlir::SymbolTable::Visibility::Private) - return true; - // We only include nested visibility here if all uses are visible. - if (allUsesVisible && visibility == SymbolTable::Visibility::Nested) - return true; - // Otherwise, public symbols are never removable. - return false; -} - /// Walk all of the symbol table operations nested with 'op' along with a /// boolean signifying if the symbols within can be treated as if all uses are /// visible. 
The provided callback is invoked with the symbol table operation, @@ -59,9 +39,8 @@ static void walkSymbolTables(Operation *op, bool allSymUsesVisible, function_ref callback) { if (op->hasTrait()) { - allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) || - SymbolTable::getSymbolVisibility(op) == - SymbolTable::Visibility::Private; + SymbolOpInterface symbol = dyn_cast(op); + allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate(); callback(op, allSymUsesVisible); } else { // Otherwise if 'op' is not a symbol table, any nested symbols are @@ -171,8 +150,11 @@ // If this is a callgraph operation, check to see if it is discardable. if (auto callable = dyn_cast(&op)) { if (auto *node = cg.lookupNode(callable.getCallableRegion())) { - if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible)) + SymbolOpInterface symbol = dyn_cast(&op); + if (symbol && (allUsesVisible || symbol.isPrivate()) && + symbol.canDiscardOnUseEmpty()) { discardableSymNodeUses.try_emplace(node, 0); + } continue; } } @@ -224,7 +206,7 @@ bool CGUseList::isDead(CallGraphNode *node) const { // If the parent operation isn't a symbol, simply check normal SSA deadness. Operation *nodeOp = node->getCallableRegion()->getParentOp(); - if (!SymbolTable::isSymbol(nodeOp)) + if (!isa(nodeOp)) return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty(); // Otherwise, check the number of symbol uses. @@ -235,7 +217,7 @@ bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const { // If this isn't a symbol node, check for side-effects and SSA use count. Operation *nodeOp = node->getCallableRegion()->getParentOp(); - if (!SymbolTable::isSymbol(nodeOp)) + if (!isa(nodeOp)) return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse(); // Otherwise, check the number of symbol uses. 
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -43,10 +43,9 @@ // A flag that signals if the top level symbol table is hidden, i.e. not // accessible from parent scopes. bool symbolTableIsHidden = true; - if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) { - symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) == - SymbolTable::Visibility::Private; - } + SymbolOpInterface symbol = dyn_cast(symbolTableOp); + if (symbolTableOp->getParentOp() && symbol) + symbolTableIsHidden = symbol.isPrivate(); // Compute the set of live symbols within the symbol table. DenseSet liveSymbols; @@ -61,7 +60,7 @@ for (auto &block : nestedSymbolTable->getRegion(0)) { for (Operation &op : llvm::make_early_inc_range(block.without_terminator())) { - if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op)) + if (isa(&op) && !liveSymbols.count(&op)) op.erase(); } } @@ -80,30 +79,16 @@ // Walk the symbols within the current symbol table, marking the symbols that // are known to be live. for (auto &block : symbolTableOp->getRegion(0)) { + // Add all non-symbols or symbols that can't be discarded. for (Operation &op : block.without_terminator()) { - // Always add non symbol operations to the worklist. - if (!SymbolTable::isSymbol(&op)) { + SymbolOpInterface symbol = dyn_cast(&op); + if (!symbol) { worklist.push_back(&op); continue; } - - // Check the visibility to see if this symbol may be referenced - // externally. - SymbolTable::Visibility visibility = - SymbolTable::getSymbolVisibility(&op); - - // Private symbols are always initially considered dead. - if (visibility == mlir::SymbolTable::Visibility::Private) - continue; - // We only include nested visibility here if the symbol table isn't - // hidden. 
- if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested) - continue; - - // TODO(riverriddle) Add hooks here to allow symbols to provide additional - // information, e.g. linkage can be used to drop some symbols that may - // otherwise be considered "live". - if (liveSymbols.insert(&op).second) + bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) && + symbol.canDiscardOnUseEmpty(); + if (!isDiscardable && liveSymbols.insert(&op).second) worklist.push_back(&op); } } @@ -117,10 +102,9 @@ if (op->hasTrait()) { // The internal symbol table is hidden if the parent is, if its not a // symbol, or if it is a private symbol. - bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) || - SymbolTable::getSymbolVisibility(op) == - SymbolTable::Visibility::Private; - if (failed(computeLiveness(op, symbolIsHidden, liveSymbols))) + SymbolOpInterface symbol = dyn_cast(op); + bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); + if (failed(computeLiveness(op, symIsHidden, liveSymbols))) return failure(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -11,6 +11,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffects.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -66,7 +66,7 @@ // Walk nested symbols. 
SmallVector deadFunctions; module.getBodyRegion().walk([&](Operation *nestedOp) { - if (SymbolTable::isSymbol(nestedOp)) + if (isa(nestedOp)) return operateOnSymbol(nestedOp, module, deadFunctions); return WalkResult::advance(); }); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -174,6 +174,8 @@ os << " static LogicalResult verifyTrait(Operation* op) {\n" << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n"; } + if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) + os << *extraTraitDecls << "\n"; os << " };\n"; }