diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -38,7 +38,8 @@ /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. Also insert at the specified location in the body of the - /// associated operation. + /// associated operation if it is not already there. It is asserted that the + /// symbol is not inside another operation. void insert(Operation *symbol, Block::iterator insertPt = {}); /// Return the name of the attribute used for symbol names. diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -151,23 +151,35 @@ } } -/// Insert a new symbol into the table and associated operation, and rename it -/// as necessary to avoid collisions. +// TODO: Consider if this should be renamed to something like insertOrUpdate +/// Insert a new symbol into the table and associated operation if not already +/// there and rename it as necessary to avoid collisions. void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { - auto &body = symbolTableOp->getRegion(0).front(); - if (insertPt == Block::iterator() || insertPt == body.end()) - insertPt = Block::iterator(body.getTerminator()); - - assert(insertPt->getParentOp() == symbolTableOp && - "expected insertPt to be in the associated module operation"); - - body.getOperations().insert(insertPt, symbol); + // The symbol cannot be the child of another op and must be the child of the + // symbolTableOp after this. + // + // TODO: consider if SymbolTable's constructor should behave the same. + if (!symbol->getParentOp()) { + auto &body = symbolTableOp->getRegion(0).front(); + if (insertPt == Block::iterator() || insertPt == body.end()) + insertPt = Block::iterator(body.getTerminator()); + + assert(insertPt->getParentOp() == symbolTableOp && + "expected insertPt to be in the associated module operation"); + + body.getOperations().insert(insertPt, symbol); + } + assert(symbol->getParentOp() == symbolTableOp && + "symbol is already inserted in another op"); // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. StringRef name = getSymbolName(symbol); if (symbolTable.insert({name, symbol}).second) return; + // If the symbol was already in the table, also return. + if (symbolTable.lookup(name) == symbol) + return; // If a conflict was detected, then the symbol will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. SmallString<128> nameBuffer(name);