diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -1066,10 +1066,9 @@ auto it = symbolToUsers.find(symbol); if (it == symbolToUsers.end()) return; - SetVector &users = it->second; // Replace the uses within the users of `symbol`. - for (Operation *user : users) + for (Operation *user : it->second) (void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user); // Move the current users of `symbol` to the new symbol if it is in the @@ -1077,13 +1076,16 @@ Operation *newSymbol = symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName); if (newSymbol != symbol) { - // Transfer over the users to the new symbol. - auto newIt = symbolToUsers.find(newSymbol); - if (newIt == symbolToUsers.end()) - symbolToUsers.try_emplace(newSymbol, std::move(users)); + // Transfer over the users to the new symbol. The reference to the old one + // is fetched again as the iterator is invalidated during the insertion. + auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector{}); + auto oldIt = symbolToUsers.find(symbol); + assert(oldIt != symbolToUsers.end() && "missing old users list"); + if (newIt.second) + newIt.first->second = std::move(oldIt->second); else - newIt->second.set_union(users); - symbolToUsers.erase(symbol); + newIt.first->second.set_union(oldIt->second); + symbolToUsers.erase(oldIt); } }