diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1393,7 +1393,7 @@
   /// Inherit the base class constructor.
   using InterfaceBase::InterfaceBase;
 
-private:
+protected:
   /// Returns the impl interface instance for the given operation.
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     // Access the raw interface from the abstract operation.
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -145,8 +145,11 @@
   let extraClassDeclaration = [{
     /// Custom classof that handles the case where the symbol is optional.
     static bool classof(Operation *op) {
-      return Base::classof(op)
-        && op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
+      auto *opConcept = getInterfaceFor(op);
+      if (!opConcept)
+        return false;
+      return !opConcept->isOptionalSymbol(op) ||
+             op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
     }
   }];
 
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -210,6 +210,42 @@
   unsigned uniquingCounter = 0;
 };
 
+//===----------------------------------------------------------------------===//
+// SymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// This class represents a collection of `SymbolTable`s. This simplifies
+/// certain algorithms that run recursively on nested symbol tables. Symbol
+/// tables are constructed lazily to reduce the upfront cost of constructing
+/// unnecessary tables.
+class SymbolTableCollection {
+public:
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
+  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
+  /// Typed variant of 'lookupSymbolIn', returning null if the symbol could not
+  /// be resolved or is not of type `T`.
+  template <typename T, typename NameT>
+  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
+    return dyn_cast_or_null<T>(
+        lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
+  }
+  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
+  /// by a given SymbolRefAttr when resolved within the provided symbol table
+  /// operation. Returns failure if any of the nested references could not be
+  /// resolved.
+  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
+                               SmallVectorImpl<Operation *> &symbols);
+
+  /// Lookup, or create, a symbol table for an operation.
+  SymbolTable &getSymbolTable(Operation *op);
+
+private:
+  /// The symbol tables constructed so far, keyed by their symbol table op.
+  DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -258,13 +258,16 @@
   return resolvedSymbols.back();
 }
 
-LogicalResult
-SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
-                            SmallVectorImpl<Operation *> &symbols) {
+/// Shared implementation of `lookupSymbolIn`. `lookupSymbolFn` is used to
+/// resolve each flat name within its parent symbol table operation.
+static LogicalResult lookupSymbolInImpl(
+    Operation *symbolTableOp, SymbolRefAttr symbol,
+    SmallVectorImpl<Operation *> &symbols,
+    function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Lookup the root reference for this symbol.
-  symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
+  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
   if (!symbolTableOp)
     return failure();
   symbols.push_back(symbolTableOp);
@@ -281,15 +284,24 @@
   // Otherwise, lookup each of the nested non-leaf references and ensure that
   // each corresponds to a valid symbol table.
   for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
-    symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
+    symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue());
     if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
       return failure();
     symbols.push_back(symbolTableOp);
   }
-  symbols.push_back(lookupSymbolIn(symbolTableOp, symbol.getLeafReference()));
+  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
   return success(symbols.back());
 }
 
+LogicalResult
+SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
+                            SmallVectorImpl<Operation *> &symbols) {
+  auto lookupFn = [](Operation *tableOp, StringRef name) {
+    return lookupSymbolIn(tableOp, name);
+  };
+  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
+}
+
 /// Returns the operation registered with the given symbol name within the
 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
 /// nullptr if no valid symbol was found.
@@ -887,6 +899,52 @@
   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// Look up `symbol` within the lazily constructed, cached symbol table of
+/// `symbolTableOp`, returning null if no such symbol exists.
+Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                 StringRef symbol) {
+  return getSymbolTable(symbolTableOp).lookup(symbol);
+}
+
+/// Resolve the (possibly nested) reference `name` within `symbolTableOp`,
+/// returning the leaf symbol, or null if any part could not be resolved.
+Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                 SymbolRefAttr name) {
+  SmallVector<Operation *, 4> symbols;
+  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
+    return nullptr;
+  return symbols.back();
+}
+
+/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
+/// a given SymbolRefAttr. Returns failure if any of the nested references could
+/// not be resolved.
+LogicalResult
+SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                      SymbolRefAttr name,
+                                      SmallVectorImpl<Operation *> &symbols) {
+  // Resolve every component of the reference through this collection so that
+  // each symbol table along the way is built at most once.
+  auto lookupFn = [this](Operation *tableOp, StringRef symbol) {
+    return lookupSymbolIn(tableOp, symbol);
+  };
+  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
+}
+
+/// Lookup, or create, a symbol table for an operation.
+SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
+  // A single `try_emplace` both looks up and reserves the entry; the table
+  // itself is only constructed the first time `op` is seen.
+  auto it = symbolTables.try_emplace(op, nullptr);
+  if (it.second)
+    it.first->second = std::make_unique<SymbolTable>(op);
+  return *it.first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // Symbol Interfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -24,6 +24,7 @@
   /// `symbolTableIsHidden` is true if this symbol table is known to be
   /// unaccessible from operations in its parent regions.
   LogicalResult computeLiveness(Operation *symbolTableOp,
+                                SymbolTableCollection &symbolTables,
                                 bool symbolTableIsHidden,
                                 DenseSet<Operation *> &liveSymbols);
 };
@@ -49,7 +50,10 @@
 
   // Compute the set of live symbols within the symbol table.
   DenseSet<Operation *> liveSymbols;
-  if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols)))
+  // Shared by the recursive walk; each symbol table is built at most once.
+  SymbolTableCollection symbolTables;
+  if (failed(computeLiveness(symbolTableOp, symbolTables, symbolTableIsHidden,
+                             liveSymbols)))
     return signalPassFailure();
 
   // After computing the liveness, delete all of the symbols that were found to
@@ -71,6 +75,7 @@
 /// `symbolTableIsHidden` is true if this symbol table is known to be
 /// unaccessible from operations in its parent regions.
 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
+                                         SymbolTableCollection &symbolTables,
                                          bool symbolTableIsHidden,
                                          DenseSet<Operation *> &liveSymbols) {
   // A worklist of live operations to propagate uses from.
@@ -104,7 +109,7 @@
       // symbol, or if it is a private symbol.
       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
-      if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
+      if (failed(computeLiveness(op, symbolTables, symIsHidden, liveSymbols)))
         return failure();
     }
 
@@ -120,7 +125,7 @@
     for (const SymbolTable::SymbolUse &use : *uses) {
       // Lookup the symbols referenced by this use.
       resolvedSymbols.clear();
-      if (failed(SymbolTable::lookupSymbolIn(
+      if (failed(symbolTables.lookupSymbolIn(
               op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
         return use.getUser()->emitError()
                << "unable to resolve reference to symbol "