diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -49,8 +49,9 @@ /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. Also insert at the specified location in the body of the /// associated operation if it is not already there. It is asserted that the - /// symbol is not inside another operation. - void insert(Operation *symbol, Block::iterator insertPt = {}); + /// symbol is not inside another operation. Return the symbol's name after + /// insertion, which differs from the original if it had to be renamed. + StringAttr insert(Operation *symbol, Block::iterator insertPt = {}); /// Return the name of the attribute used for symbol names. static StringRef getSymbolAttrName() { return "sym_name"; } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -151,8 +151,9 @@ // TODO: Consider if this should be renamed to something like insertOrUpdate /// Insert a new symbol into the table and associated operation if not already -/// there and rename it as necessary to avoid collisions. -void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { +/// there and rename it as necessary to avoid collisions. Return the symbol's +/// name after insertion, which differs from the original if it was renamed. +StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { // The symbol cannot be the child of another op and must be the child of the // symbolTableOp after this. // @@ -180,10 +181,10 @@ // detected. StringAttr name = getSymbolName(symbol); if (symbolTable.insert({name, symbol}).second) - return; + return name; // If the symbol was already in the table, also return. if (symbolTable.lookup(name) == symbol) - return; + return name; // If a conflict was detected, then the symbol will not have been added to // the symbol table.
Try suffixes until we get to a unique name that works. SmallString<128> nameBuffer(name.getValue()); @@ -199,6 +200,7 @@ } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol}) .second); setSymbolName(symbol, nameBuffer); + return getSymbolName(symbol); } /// Returns the name of the given symbol operation.