diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Support/RWMutex.h"
 
 namespace mlir {
 
@@ -281,10 +282,73 @@
   SymbolTable &getSymbolTable(Operation *op);
 
 private:
+  friend class LockedSymbolTableCollection;
+
   /// The constructed symbol tables nested within this table.
   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
 };
 
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// This class implements a lock-based shared wrapper around a symbol table
+/// collection that allows shared access to the collection of symbol tables.
+/// This class does not protect shared access to individual symbol tables.
+/// `SymbolTableCollection` lazily instantiates `SymbolTable` instances for
+/// symbol table operations, making read operations not thread-safe. This class
+/// provides a thread-safe `lookupSymbolIn` implementation by synchronizing the
+/// lazy `SymbolTable` lookup.
+///
+/// NOTE: the overloads below are not `virtual`/`override`. Assuming the
+/// matching `SymbolTableCollection` members are non-virtual (as
+/// `getSymbolTable` is), they hide rather than override them, so calls made
+/// through a `SymbolTableCollection &` bypass both the lock and `collection`.
+class LockedSymbolTableCollection : public SymbolTableCollection {
+public:
+  /// Wrap `collection`, which is held by reference and must outlive this
+  /// object.
+  explicit LockedSymbolTableCollection(SymbolTableCollection &collection)
+      : collection(collection) {}
+
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, FlatSymbolRefAttr symbol);
+  /// Look up a potentially nested symbol within the specified symbol table
+  /// operation, returning null if no such symbol exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
+
+  /// Lookup a symbol of a particular kind within the specified symbol table,
+  /// returning null if the symbol was not found.
+  template <typename T, typename NameT>
+  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
+    return dyn_cast_or_null<T>(
+        lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
+  }
+
+  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
+  /// by a given SymbolRefAttr when resolved within the provided symbol table
+  /// operation. Returns failure if any of the nested references could not be
+  /// resolved.
+  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
+                               SmallVectorImpl<Operation *> &symbols);
+
+private:
+  /// Get the symbol table for the symbol table operation, constructing if it
+  /// does not exist. This function provides thread safety over `collection`
+  /// by locking when performing the lookup and when inserting
+  /// lazily-constructed symbol tables.
+  SymbolTable &getSymbolTable(Operation *symbolTableOp);
+
+  /// The symbol tables to manage.
+  SymbolTableCollection &collection;
+  /// The mutex protecting access to the symbol table collection.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolUserMap
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -916,6 +916,63 @@
   return *it.first->second;
 }
 
+//===----------------------------------------------------------------------===//
+// LockedSymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                       StringAttr symbol) {
+  return getSymbolTable(symbolTableOp).lookup(symbol);
+}
+
+Operation *
+LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                            FlatSymbolRefAttr symbol) {
+  return lookupSymbolIn(symbolTableOp, symbol.getAttr());
+}
+
+Operation *LockedSymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                       SymbolRefAttr name) {
+  SmallVector<Operation *> symbols;
+  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
+    return nullptr;
+  return symbols.back();
+}
+
+LogicalResult LockedSymbolTableCollection::lookupSymbolIn(
+    Operation *symbolTableOp, SymbolRefAttr name,
+    SmallVectorImpl<Operation *> &symbols) {
+  // Route the per-name lookups through the locked StringAttr overload.
+  auto lookupFn = [this](Operation *tableOp, StringAttr symbol) {
+    return lookupSymbolIn(tableOp, symbol);
+  };
+  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
+}
+
+SymbolTable &
+LockedSymbolTableCollection::getSymbolTable(Operation *symbolTableOp) {
+  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
+  // Try to find an existing symbol table.
+  {
+    llvm::sys::SmartScopedReader<true> lock(mutex);
+    auto it = collection.symbolTables.find(symbolTableOp);
+    if (it != collection.symbolTables.end())
+      return *it->second;
+  }
+  // Create a symbol table for the operation. Perform construction outside of
+  // the critical section.
+  auto symbolTable = std::make_unique<SymbolTable>(symbolTableOp);
+  // Insert the constructed symbol table. If another thread inserted a table
+  // for this operation in the meantime, `insert` keeps that entry and the
+  // table built here is discarded. The returned reference stays valid after
+  // the lock is released because the map owns each table through a
+  // `std::unique_ptr`, so rehashing moves only the pointer.
+  llvm::sys::SmartScopedWriter<true> lock(mutex);
+  return *collection.symbolTables
+              .insert({symbolTableOp, std::move(symbolTable)})
+              .first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolUserMap
 //===----------------------------------------------------------------------===//