diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
 
 namespace mlir {
@@ -260,6 +261,43 @@
   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
 };
 
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+/// This class represents a map of symbols to users, and provides efficient
+/// implementations of symbol queries related to users; such as collecting the
+/// users of a symbol, replacing all uses, etc.
+class SymbolUserMap {
+public:
+  /// Collect all symbol uses for operations nested under 'limit'. A reference
+  /// to the provided symbol table collection is kept by the user map to ensure
+  /// efficient lookups, thus the lifetime should extend beyond that of this
+  /// map.
+  SymbolUserMap(SymbolTableCollection &symbolTable, Operation *limit);
+
+  /// Return the users of the provided symbol operation.
+  ArrayRef<Operation *> getUsers(Operation *symbol) const {
+    auto it = symbolToUsers.find(symbol);
+    return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None;
+  }
+
+  /// Return true if the given symbol has no uses.
+  bool use_empty(Operation *symbol) const {
+    return !symbolToUsers.count(symbol);
+  }
+
+  /// Replace all of the uses of the given symbol with `newSymbolName`.
+  void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
+
+private:
+  /// A reference to the symbol table used to construct this map.
+  SymbolTableCollection &symbolTable;
+
+  /// A map of symbol operations to symbol users.
+  DenseMap<Operation *, SetVector<Operation *>> symbolToUsers;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -1000,6 +1000,66 @@
   return *it.first->second;
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolUserMap
+//===----------------------------------------------------------------------===//
+
+/// Collect all symbol uses for operations nested under 'limit'.
+SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
+                             Operation *limit)
+    : symbolTable(symbolTable) {
+  // Walk each symbol table and record the users of each referenced symbol.
+  SmallVector<Operation *, 4> symbols;
+  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
+    for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
+      auto symbolUses = SymbolTable::getSymbolUses(&nestedOp);
+      assert(symbolUses && "expected uses to be valid");
+
+      for (const SymbolTable::SymbolUse &use : *symbolUses) {
+        symbols.clear();
+        (void)symbolTable.lookupSymbolIn(symbolTableOp, use.getSymbolRef(),
+                                         symbols);
+        for (Operation *symbolOp : symbols)
+          symbolToUsers[symbolOp].insert(use.getUser());
+      }
+    }
+  };
+  SymbolTable::walkSymbolTables(limit, /*allSymUsesVisible=*/!limit->getBlock(),
+                                walkFn);
+}
+
+/// Replace all of the uses of the given symbol with `newSymbolName`.
+void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
+                                       StringRef newSymbolName) {
+  auto it = symbolToUsers.find(symbol);
+  if (it == symbolToUsers.end())
+    return;
+
+  // Replace the uses within the users of `symbol`. The map itself is not
+  // mutated by this loop, so iterating the user set directly is safe.
+  for (Operation *user : it->second)
+    (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
+
+  // Move the current users of `symbol` to the new symbol if it is in the
+  // symbol table; otherwise they stay recorded under `symbol`.
+  Operation *newSymbol =
+      symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
+  if (!newSymbol || newSymbol == symbol)
+    return;
+
+  // Transfer over the users to the new symbol. `try_emplace` may grow the map,
+  // which invalidates all iterators and references into it, so the entry for
+  // the old symbol is only looked up (again) after the insertion.
+  auto newIt = symbolToUsers.try_emplace(newSymbol);
+  auto oldIt = symbolToUsers.find(symbol);
+  assert(oldIt != symbolToUsers.end() && "missing old users list");
+  if (newIt.second)
+    newIt.first->second = std::move(oldIt->second);
+  else
+    newIt.first->second.set_union(oldIt->second);
+  symbolToUsers.erase(oldIt);
+}
+
 //===----------------------------------------------------------------------===//
 // Visibility parsing implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -90,16 +90,20 @@
 struct SymbolReplacementPass
     : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
   void runOnOperation() override {
-    auto module = getOperation();
+    ModuleOp module = getOperation();
+
+    // Don't try to replace if we can't collect symbol uses.
+    if (!SymbolTable::getSymbolUses(&module.getBodyRegion()))
+      return;
 
-    // Walk nested functions and modules.
+    SymbolTableCollection symbolTable;
+    SymbolUserMap symbolUsers(symbolTable, module);
     module.getBodyRegion().walk([&](Operation *nestedOp) {
       StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
       if (!newName)
         return;
-      if (succeeded(SymbolTable::replaceAllSymbolUses(
-              nestedOp, newName.getValue(), &module.getBodyRegion())))
-        SymbolTable::setSymbolName(nestedOp, newName.getValue());
+      symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue());
+      SymbolTable::setSymbolName(nestedOp, newName.getValue());
     });
   }
 };