diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -27,6 +27,15 @@
       symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
   return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
 }
+
+/// Overload of 'getNameIfSymbol' that takes the uniqued symbol-name
+/// identifier directly, so callers that inspect many operations can compute
+/// it once instead of once per lookup.
+static Optional<StringRef> getNameIfSymbol(Operation *symbol,
+                                           Identifier symbolAttrNameId) {
+  auto nameAttr = symbol->getAttrOfType<StringAttr>(symbolAttrNameId);
+  return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
+}
 
 /// Computes the nested symbol reference attribute for the symbol 'symbolName'
 /// that are usable within the symbol table operations from 'symbol' as far up
@@ -49,12 +58,15 @@
 
   // Collect references until 'symbolTableOp' reaches 'within'.
   SmallVector<SymbolRefAttr, 1> nestedRefs(1, leafRef);
+  Identifier symbolNameId =
+      Identifier::get(SymbolTable::getSymbolAttrName(), ctx);
   do {
     // Each parent of 'symbol' should define a symbol table.
     if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
       return failure();
     // Each parent of 'symbol' should also be a symbol.
-    Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp);
+    Optional<StringRef> symbolTableName =
+        getNameIfSymbol(symbolTableOp, symbolNameId);
     if (!symbolTableName)
       return failure();
     results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
@@ -106,8 +118,10 @@
   assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
          "expected operation to have a single block");
 
+  Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
+                                            symbolTableOp->getContext());
   for (auto &op : symbolTableOp->getRegion(0).front()) {
-    Optional<StringRef> name = getNameIfSymbol(&op);
+    Optional<StringRef> name = getNameIfSymbol(&op, symbolNameId);
     if (!name)
       continue;
 
@@ -269,8 +283,10 @@
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Look for a symbol with the given name.
+  Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
+                                            symbolTableOp->getContext());
   for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
-    if (getNameIfSymbol(&op) == symbol)
+    if (getNameIfSymbol(&op, symbolNameId) == symbol)
       return &op;
   return nullptr;
 }