diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -41,7 +41,10 @@
     return dyn_cast_or_null<T>(lookup(name));
   }
 
-  /// Erase the given symbol from the table.
+  /// Remove the given symbol from the table, without deleting it.
+  void remove(Operation *op);
+
+  /// Erase the given symbol from the table and delete the operation.
   void erase(Operation *symbol);
 
   /// Insert a new symbol into the table, and rename it as necessary to avoid
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -146,19 +146,21 @@
   return symbolTable.lookup(name);
 }
 
-/// Erase the given symbol from the table.
-void SymbolTable::erase(Operation *symbol) {
-  StringAttr name = getNameIfSymbol(symbol);
+void SymbolTable::remove(Operation *op) {
+  StringAttr name = getNameIfSymbol(op);
   assert(name && "expected valid 'name' attribute");
-  assert(symbol->getParentOp() == symbolTableOp &&
+  assert(op->getParentOp() == symbolTableOp &&
          "expected this operation to be inside of the operation with this "
          "SymbolTable");
 
   auto it = symbolTable.find(name);
-  if (it != symbolTable.end() && it->second == symbol) {
+  if (it != symbolTable.end() && it->second == op)
     symbolTable.erase(it);
-    symbol->erase();
-  }
+}
+
+void SymbolTable::erase(Operation *symbol) {
+  remove(symbol);
+  symbol->erase();
 }
 
 // TODO: Consider if this should be renamed to something like insertOrUpdate