diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -50,16 +50,27 @@
   // Symbol Utilities
   //===--------------------------------------------------------------------===//
 
+  /// Returns true if the given operation defines a symbol.
+  static bool isSymbol(Operation *op);
+
+  /// Returns the name of the given symbol operation.
+  static StringRef getSymbolName(Operation *symbol);
+  /// Sets the name of the given symbol operation.
+  static void setSymbolName(Operation *symbol, StringRef name);
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.
   static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
+  static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
 
   /// Returns the operation registered with the given symbol name within the
   /// closest parent operation of, or including, 'from' with the
   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
   /// found.
   static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
+  static Operation *lookupNearestSymbolFrom(Operation *from,
+                                            SymbolRefAttr symbol);
 
   /// This class represents a specific symbol use.
   class SymbolUse {
@@ -110,6 +121,7 @@
   /// symbol table, and not the op itself. This function returns None if there
   /// are any unknown operations that may potentially be symbol tables.
   static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
+  static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
 
   /// Return if the given symbol is known to have no uses that are nested
   /// within the given operation 'from'. This does not traverse into any nested
@@ -120,6 +132,7 @@
   /// tables. This doesn't necessarily mean that there are no uses, we just
   /// can't conservatively prove it.
   static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
+  static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
 
   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
   /// provided symbol 'newSymbol' that are nested within the given operation
@@ -132,6 +145,9 @@
   LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
                                                            StringRef newSymbol,
                                                            Operation *from);
+  LLVM_NODISCARD static LogicalResult
+  replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
+                       Operation *from);
 
 private:
   Operation *symbolTableOp;
@@ -207,14 +223,14 @@
   /// operation 'from'.
   /// Note: See mlir::SymbolTable::getSymbolUses for more details.
   Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
-    return ::mlir::SymbolTable::getSymbolUses(getName(), from);
+    return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
   }
 
   /// Return if the current symbol is known to have no uses that are nested
   /// within the given operation 'from'.
   /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
   bool symbolKnownUseEmpty(Operation *from) {
-    return ::mlir::SymbolTable::symbolKnownUseEmpty(getName(), from);
+    return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
   }
 
   /// Attempt to replace all uses of the current symbol with the provided symbol
@@ -222,8 +238,8 @@
   /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
   LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
                                                     Operation *from) {
-    return ::mlir::SymbolTable::replaceAllSymbolUses(getName(), newSymbol,
-                                                     from);
+    return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
+                                                     newSymbol, from);
   }
 };
 
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
 
 using namespace mlir;
@@ -17,6 +19,71 @@
   return !op->getDialect() && op->getNumRegions() == 1;
 }
 
+/// Returns the nearest symbol table from a given operation `from`. Returns
+/// nullptr if no valid parent symbol table could be found.
+static Operation *getNearestSymbolTable(Operation *from) {
+  assert(from && "expected valid operation");
+  if (isPotentiallyUnknownSymbolTable(from))
+    return nullptr;
+
+  while (!from->hasTrait<OpTrait::SymbolTable>()) {
+    from = from->getParentOp();
+
+    // Check that this is a valid op and isn't an unknown symbol table.
+    if (!from || isPotentiallyUnknownSymbolTable(from))
+      return nullptr;
+  }
+  return from;
+}
+
+/// Returns the string name of the given symbol, or None if this is not a
+/// symbol.
+static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
+  auto nameAttr =
+      symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+  return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
+}
+
+/// Computes the nested symbol reference attributes for the symbol 'symbolName'
+/// that are usable within the symbol table operations from 'symbol' as far up
+/// as the given operation 'within', where 'within' is an ancestor of 'symbol'.
+/// Returns success if all references up to 'within' could be computed.
+static LogicalResult
+collectValidReferencesFor(Operation *symbol, StringRef symbolName,
+                          Operation *within,
+                          SmallVectorImpl<SymbolRefAttr> &results) {
+  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
+  MLIRContext *ctx = symbol->getContext();
+
+  auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
+  results.push_back(leafRef);
+
+  // Early exit for when 'within' is the parent of 'symbol'.
+  Operation *symbolTableOp = symbol->getParentOp();
+  if (within == symbolTableOp)
+    return success();
+
+  // Collect references until 'symbolTableOp' reaches 'within'.
+  SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
+  do {
+    // Each parent of 'symbol' should define a symbol table.
+    if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+      return failure();
+    // Each parent of 'symbol' should also be a symbol.
+    Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp);
+    if (!symbolTableName)
+      return failure();
+    results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
+
+    symbolTableOp = symbolTableOp->getParentOp();
+    if (symbolTableOp == within)
+      break;
+    nestedRefs.insert(nestedRefs.begin(),
+                      FlatSymbolRefAttr::get(*symbolTableName, ctx));
+  } while (true);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolTable
 //===----------------------------------------------------------------------===//
@@ -32,11 +99,11 @@
          "expected operation to have a single block");
 
   for (auto &op : symbolTableOp->getRegion(0).front()) {
-    auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
-    if (!nameAttr)
+    Optional<StringRef> name = getNameIfSymbol(&op);
+    if (!name)
       continue;
 
-    auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
+    auto inserted = symbolTable.insert({*name, &op});
     (void)inserted;
     assert(inserted.second &&
            "expected region to contain uniquely named symbol operations");
@@ -51,13 +118,13 @@
 
 /// Erase the given symbol from the table.
 void SymbolTable::erase(Operation *symbol) {
-  auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
-  assert(nameAttr && "expected valid 'name' attribute");
+  Optional<StringRef> name = getNameIfSymbol(symbol);
+  assert(name && "expected valid 'name' attribute");
   assert(symbol->getParentOp() == symbolTableOp &&
          "expected this operation to be inside of the operation with this "
          "SymbolTable");
 
-  auto it = symbolTable.find(nameAttr.getValue());
+  auto it = symbolTable.find(*name);
   if (it != symbolTable.end() && it->second == symbol) {
     symbolTable.erase(it);
     symbol->erase();
@@ -67,9 +134,6 @@
 /// Insert a new symbol into the table and associated operation, and rename it
 /// as necessary to avoid collisions.
 void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
-  auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
-  assert(nameAttr && "expected valid 'name' attribute");
-
   auto &body = symbolTableOp->getRegion(0).front();
   if (insertPt == Block::iterator() || insertPt == body.end())
     insertPt = Block::iterator(body.getTerminator());
@@ -81,12 +145,12 @@
 
   // Add this symbol to the symbol table, uniquing the name if a conflict is
   // detected.
-  if (symbolTable.insert({nameAttr.getValue(), symbol}).second)
+  StringRef name = getSymbolName(symbol);
+  if (symbolTable.insert({name, symbol}).second)
     return;
-
   // If a conflict was detected, then the symbol will not have been added to
   // the symbol table. Try suffixes until we get to a unique name that works.
-  SmallString<128> nameBuffer(nameAttr.getValue());
+  SmallString<128> nameBuffer(name);
   unsigned originalLength = nameBuffer.size();
 
   // Iteratively try suffixes until we find one that isn't used.
@@ -95,8 +159,24 @@
     nameBuffer += '_';
     nameBuffer += std::to_string(uniquingCounter++);
   } while (!symbolTable.insert({nameBuffer, symbol}).second);
+  setSymbolName(symbol, nameBuffer);
+}
+
+/// Returns true if the given operation defines a symbol.
+bool SymbolTable::isSymbol(Operation *op) {
+  return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
+}
+
+/// Returns the name of the given symbol operation.
+StringRef SymbolTable::getSymbolName(Operation *symbol) {
+  Optional<StringRef> name = getNameIfSymbol(symbol);
+  assert(name && "expected valid symbol name");
+  return *name;
+}
+/// Sets the name of the given symbol operation.
+void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
   symbol->setAttr(getSymbolAttrName(),
-                  StringAttr::get(nameBuffer, symbolTableOp->getContext()));
+                  StringAttr::get(name, symbol->getContext()));
 }
 
 /// Returns the operation registered with the given symbol name with the
@@ -109,30 +189,52 @@
 
   // Look for a symbol with the given name.
   for (auto &block : symbolTableOp->getRegion(0)) {
-    for (auto &op : block) {
-      auto nameAttr = op.template getAttrOfType<StringAttr>(
-          mlir::SymbolTable::getSymbolAttrName());
-      if (nameAttr && nameAttr.getValue() == symbol)
+    for (auto &op : block)
+      if (getNameIfSymbol(&op) == symbol)
         return &op;
-    }
   }
   return nullptr;
 }
+Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
+                                       SymbolRefAttr symbol) {
+  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
+
+  // Lookup the root reference for this symbol.
+  symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
+  if (!symbolTableOp)
+    return nullptr;
+
+  // If there are no nested references, just return the root symbol directly.
+  ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
+  if (nestedRefs.empty())
+    return symbolTableOp;
+
+  // Verify that the root is also a symbol table.
+  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+    return nullptr;
+
+  // Otherwise, lookup each of the nested non-leaf references and ensure that
+  // each corresponds to a valid symbol table.
+  for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
+    symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
+    if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+      return nullptr;
+  }
+  return lookupSymbolIn(symbolTableOp, symbol.getLeafReference());
+}
 
 /// Returns the operation registered with the given symbol name within the
 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
 /// nullptr if no valid symbol was found.
 Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
                                                 StringRef symbol) {
-  assert(from && "expected valid operation");
-  while (!from->hasTrait<OpTrait::SymbolTable>()) {
-    from = from->getParentOp();
-
-    // Check that this is a valid op and isn't an unknown symbol table.
-    if (!from || isPotentiallyUnknownSymbolTable(from))
-      return nullptr;
-  }
-  return lookupSymbolIn(from, symbol);
+  Operation *symbolTableOp = getNearestSymbolTable(from);
+  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
+}
+Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
+                                                SymbolRefAttr symbol) {
+  Operation *symbolTableOp = getNearestSymbolTable(from);
+  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
 }
 
 //===----------------------------------------------------------------------===//
@@ -148,7 +250,7 @@
            << "Operations with a 'SymbolTable' must have exactly one block";
 
   // Check that all symbols are uniquely named within child regions.
-  llvm::StringMap<Location> nameToOrigLoc;
+  DenseMap<Attribute, Location> nameToOrigLoc;
   for (auto &block : op->getRegion(0)) {
     for (auto &op : block) {
       // Check for a symbol name attribute.
@@ -158,7 +260,7 @@
         continue;
 
       // Try to insert this symbol into the table.
-      auto it = nameToOrigLoc.try_emplace(nameAttr.getValue(), op.getLoc());
+      auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
       if (!it.second)
         return op.emitError()
             .append("redefinition of symbol named '", nameAttr.getValue(), "'")
@@ -293,6 +395,100 @@
   return WalkResult::advance();
 }
 
+/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking
+/// the provided callback at each one with a properly scoped reference to
+/// 'symbol'. The callback takes as parameters the symbol reference at the
+/// current scope as well as the top-level operation representing the top of
+/// that scope.
+static Optional<WalkResult> walkSymbolScopes(
+    Operation *symbol, Operation *limit,
+    function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
+  StringRef symbolName = SymbolTable::getSymbolName(symbol);
+  assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
+
+  // Compute the ancestors of 'limit'.
+  llvm::SetVector<Operation *, SmallVector<Operation *, 4>,
+                  SmallPtrSet<Operation *, 4>>
+      limitAncestors;
+  Operation *limitAncestor = limit;
+  do {
+    // Check to see if 'symbol' is an ancestor of 'limit'.
+    if (limitAncestor == symbol) {
+      // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
+      // doesn't support parent references.
+      if (getNearestSymbolTable(limit) != symbol->getParentOp())
+        return WalkResult::advance();
+      return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
+                      limit);
+    }
+
+    limitAncestors.insert(limitAncestor);
+  } while ((limitAncestor = limitAncestor->getParentOp()));
+
+  // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
+  Operation *commonAncestor = symbol->getParentOp();
+  do {
+    if (limitAncestors.count(commonAncestor))
+      break;
+  } while ((commonAncestor = commonAncestor->getParentOp()));
+  assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
+
+  // Compute the set of valid nested references for 'symbol' as far up to the
+  // common ancestor as possible.
+  SmallVector<SymbolRefAttr, 2> references;
+  bool collectedAllReferences = succeeded(collectValidReferencesFor(
+      symbol, symbolName, commonAncestor, references));
+
+  // Handle the case where the common ancestor is 'limit'.
+  if (commonAncestor == limit) {
+    // Walk each of the ancestors of 'symbol', invoking the callback with the
+    // reference that is valid within that ancestor.
+    Operation *limitIt = symbol->getParentOp();
+    for (size_t i = 0, e = references.size(); i != e;
+         ++i, limitIt = limitIt->getParentOp()) {
+      Optional<WalkResult> callbackResult = callback(references[i], limitIt);
+      if (callbackResult != WalkResult::advance())
+        return callbackResult;
+    }
+    return WalkResult::advance();
+  }
+
+  // Otherwise, we just need the symbol reference for 'symbol' that will be
+  // used within 'limit'. This is the last reference in the list we computed
+  // above if we were able to collect all references.
+  if (!collectedAllReferences)
+    return WalkResult::advance();
+  return callback(references.back(), limit);
+}
+
+/// Walk the symbol scopes defined by 'limit' invoking the provided callback.
+static Optional<WalkResult> walkSymbolScopes(
+    StringRef symbol, Operation *limit,
+    function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
+  return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit);
+}
+
+/// Returns true if the given reference 'subRef' is equal to, or is a prefix
+/// of, the reference 'ref', i.e. 'ref' is a further qualified reference.
+static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
+  if (ref == subRef)
+    return true;
+
+  // If the references are not pointer equal, check to see if `subRef` is a
+  // prefix of `ref`.
+  if (ref.isa<FlatSymbolRefAttr>() ||
+      ref.getRootReference() != subRef.getRootReference())
+    return false;
+
+  auto refLeafs = ref.getNestedReferences();
+  auto subRefLeafs = subRef.getNestedReferences();
+  return subRefLeafs.size() < refLeafs.size() &&
+         subRefLeafs == refLeafs.take_front(subRefLeafs.size());
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolTable::getSymbolUses
+
 /// Get an iterator range for all of the uses, for any symbol, that are nested
 /// within the given operation 'from'. This does not traverse into any nested
 /// symbol tables, and will also only return uses on 'from' if it does not
@@ -302,14 +498,35 @@
 /// tables.
 auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
   std::vector<SymbolUse> uses;
-  Optional<WalkResult> result =
-      walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
-        uses.push_back(symbolUse);
-        return WalkResult::advance();
-      });
+  auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) {
+    uses.push_back(symbolUse);
+    return WalkResult::advance();
+  };
+  auto result = walkSymbolUses(from, walkFn);
   return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolTable::getSymbolUses
+
+/// The implementation of SymbolTable::getSymbolUses below.
+template <typename SymbolT>
+static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
+                                                         Operation *limit) {
+  std::vector<SymbolTable::SymbolUse> uses;
+  auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
+    return walkSymbolUses(
+        from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+          if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()))
+            uses.push_back(symbolUse);
+          return WalkResult::advance();
+        });
+  };
+  if (walkSymbolScopes(symbol, limit, walkFn))
+    return SymbolTable::UseRange(std::move(uses));
+  return llvm::None;
+}
+
 /// Get all of the uses of the given symbol that are nested within the given
 /// operation 'from', invoking the provided callback for each. This does not
 /// traverse into any nested symbol tables, and will also only return uses on
@@ -319,16 +536,29 @@
 /// potentially be symbol tables.
 auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
     -> Optional<UseRange> {
-  SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
+  return getSymbolUsesImpl(symbol, from);
+}
+auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
+    -> Optional<UseRange> {
+  return getSymbolUsesImpl(symbol, from);
+}
 
-  std::vector<SymbolUse> uses;
-  Optional<WalkResult> result =
-      walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
-        if (symbolRefAttr == symbolUse.getSymbolRef())
-          uses.push_back(symbolUse);
-        return WalkResult::advance();
-      });
-  return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
+//===----------------------------------------------------------------------===//
+// SymbolTable::symbolKnownUseEmpty
+
+/// The implementation of SymbolTable::symbolKnownUseEmpty below.
+template <typename SymbolT>
+static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) {
+  // Walk all of the symbol uses looking for a reference to 'symbol'.
+  auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
+    return walkSymbolUses(
+        from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+          return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())
+                     ? WalkResult::interrupt()
+                     : WalkResult::advance();
+        });
+  };
+  return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance();
 }
 
 /// Return if the given symbol is known to have no uses that are nested within
@@ -338,35 +568,32 @@
 /// symbol table, and not the op itself. This function will also return false if
 /// there are any unknown operations that may potentially be symbol tables.
 bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
-  SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
-
-  // Walk all of the symbol uses looking for a reference to 'symbol'.
-  Optional<WalkResult> walkResult =
-      walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
-        return symbolUse.getSymbolRef() == symbolRefAttr
-                   ? WalkResult::interrupt()
-                   : WalkResult::advance();
-      });
-  return walkResult && !walkResult->wasInterrupted();
+  return symbolKnownUseEmptyImpl(symbol, from);
+}
+bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
+  return symbolKnownUseEmptyImpl(symbol, from);
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolTable::replaceAllSymbolUses
+
 /// Rebuild the given attribute container after replacing all references to a
-/// symbol with `newSymAttr`.
-static Attribute rebuildAttrAfterRAUW(Attribute container,
-                                      ArrayRef<SmallVector<int, 1>> accesses,
-                                      SymbolRefAttr newSymAttr,
-                                      unsigned depth) {
+/// symbol with the updated attribute in 'accesses'.
+static Attribute rebuildAttrAfterRAUW(
+    Attribute container,
+    ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
+    unsigned depth) {
   // Given a range of Attributes, update the ones referred to by the given
   // access chains to point to the new symbol attribute.
   auto updateAttrs = [&](auto &&attrRange) {
     auto attrBegin = std::begin(attrRange);
     for (unsigned i = 0, e = accesses.size(); i != e;) {
-      ArrayRef<int> access = accesses[i];
+      ArrayRef<int> access = accesses[i].first;
       Attribute &attr = *std::next(attrBegin, access[depth]);
 
       // Check to see if this is a leaf access, i.e. a SymbolRef.
       if (access.size() == depth + 1) {
-        attr = newSymAttr;
+        attr = accesses[i].second;
         ++i;
         continue;
       }
@@ -374,12 +601,12 @@
       // Otherwise, this is a container. Collect all of the accesses for this
       // index and recurse. The recursion here is bounded by the size of the
       // largest access array.
-      auto nestedAccesses =
-          accesses.drop_front(i).take_while([&](ArrayRef<int> nextAccess) {
-            return nextAccess.size() > depth + 1 &&
-                   nextAccess[depth] == access[depth];
-          });
-      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, newSymAttr, depth + 1);
+      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+        ArrayRef<int> nextAccess = it.first;
+        return nextAccess.size() > depth + 1 &&
+               nextAccess[depth] == access[depth];
+      });
+      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
 
       // Skip over all of the accesses that refer to the nested container.
       i += nestedAccesses.size();
@@ -396,64 +623,115 @@
   return ArrayAttr::get(newAttrs, container.getContext());
 }
 
-/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
-/// provided symbol 'newSymbol' that are nested within the given operation
-/// 'from'. This does not traverse into any nested symbol tables, and will
-/// also only replace uses on 'from' if it does not also define a symbol
-/// table. This is because we treat the region as the boundary of the symbol
-/// table, and not the op itself. If there are any unknown operations that may
-/// potentially be symbol tables, no uses are replaced and failure is returned.
-LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
-                                                StringRef newSymbol,
-                                                Operation *from) {
-  SymbolRefAttr oldAttr = SymbolRefAttr::get(oldSymbol, from->getContext());
-  SymbolRefAttr newSymAttr = SymbolRefAttr::get(newSymbol, from->getContext());
+/// Generates a new symbol reference attribute with a new leaf reference.
+static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
+                                        FlatSymbolRefAttr newLeafAttr) {
+  if (oldAttr.isa<FlatSymbolRefAttr>())
+    return newLeafAttr;
+  auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
+  nestedRefs.back() = newLeafAttr;
+  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
+                            oldAttr.getContext());
+}
 
+/// The implementation of SymbolTable::replaceAllSymbolUses below.
+template <typename SymbolT>
+static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
+                                              StringRef newSymbol,
+                                              Operation *limit) {
   // A collection of operations along with their new attribute dictionary.
   std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
 
-  // The current operation, and its old symbol access chains, being processed.
+  // The current operation being processed.
   Operation *curOp = nullptr;
-  SmallVector<SmallVector<int, 1>, 1> accessChains;
+
+  // The set of access chains into the attribute dictionary of the current
+  // operation, as well as the replacement attribute to use.
+  SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains;
 
   // Generate a new attribute dictionary for the current operation by replacing
   // references to the old symbol.
   auto generateNewAttrDict = [&] {
-    auto newAttrDict =
-        rebuildAttrAfterRAUW(curOp->getAttrList().getDictionary(), accessChains,
-                             newSymAttr, /*depth=*/0);
-    return newAttrDict.cast<DictionaryAttr>();
+    auto oldDict = curOp->getAttrList().getDictionary();
+    auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
+    return newDict.cast<DictionaryAttr>();
   };
 
-  // Walk the symbol uses collecting uses of the old symbol.
-  auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
-                    ArrayRef<int> accessChain) {
-    if (symbolUse.getSymbolRef() != oldAttr)
+  // Generate a new attribute to replace the given attribute.
+  MLIRContext *ctx = limit->getContext();
+  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
+  auto scopeWalkFn = [&](SymbolRefAttr oldAttr,
+                         Operation *from) -> Optional<WalkResult> {
+    SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr);
+    auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
+                      ArrayRef<int> accessChain) {
+      SymbolRefAttr useRef = symbolUse.getSymbolRef();
+      if (!isReferencePrefixOf(oldAttr, useRef))
+        return WalkResult::advance();
+
+      // If we have a valid match, check to see if this is a proper
+      // subreference. If it is, then we will need to generate a different new
+      // attribute specifically for this use.
+      SymbolRefAttr replacementRef = newAttr;
+      if (useRef != oldAttr) {
+        if (oldAttr.isa<FlatSymbolRefAttr>()) {
+          replacementRef =
+              SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
+        } else {
+          auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
+          nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr;
+          replacementRef =
+              SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
+        }
+      }
+
+      // If this use belongs to a different operation than the one being
+      // processed, the current operation is finished: generate its new
+      // attribute dictionary and start a fresh set of access chains.
+      if (curOp && symbolUse.getUser() != curOp) {
+        updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
+        accessChains.clear();
+      }
+
+      // Record this access.
+      curOp = symbolUse.getUser();
+      accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
       return WalkResult::advance();
+    };
+    if (!walkSymbolUses(from, walkFn))
+      return llvm::None;
 
-    // If there was a previous operation, generate a new attribute dict for it.
-    // This means that we've finished processing the current operation, so
-    // generate a new dictionary for it.
-    if (curOp && symbolUse.getUser() != curOp) {
+    // Flush the last op of this scope and reset the state for the next scope.
+    if (curOp) {
       updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
       accessChains.clear();
+      curOp = nullptr;
     }
-
-    // Record this access.
-    curOp = symbolUse.getUser();
-    accessChains.push_back(llvm::to_vector<1>(accessChain));
     return WalkResult::advance();
   };
-  if (!walkSymbolUses(from, walkFn))
+  if (!walkSymbolScopes(symbol, limit, scopeWalkFn))
     return failure();
 
   // Update the attribute dictionaries as necessary.
   for (auto &it : updatedAttrDicts)
     it.first->setAttrs(it.second);
-
-  // Check to see if we have a dangling op that needs to be processed.
-  if (curOp)
-    curOp->setAttrs(generateNewAttrDict());
-
   return success();
 }
+
+/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
+/// provided symbol 'newSymbol' that are nested within the given operation
+/// 'from'. This does not traverse into any nested symbol tables, and will
+/// also only replace uses on 'from' if it does not also define a symbol
+/// table. This is because we treat the region as the boundary of the symbol
+/// table, and not the op itself. If there are any unknown operations that may
+/// potentially be symbol tables, no uses are replaced and failure is returned.
+LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
+                                                StringRef newSymbol,
+                                                Operation *from) {
+  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
+LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
+                                                StringRef newSymbol,
+                                                Operation *from) {
+  return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -32,6 +32,39 @@
 
 // -----
 
+// Check the support for nested references.
+
+// CHECK: module
+module {
+  // CHECK: module @module_a
+  module @module_a {
+    // CHECK: func @replaced_foo
+    func @foo() attributes {sym.new_name = "replaced_foo" }
+  }
+
+  // CHECK: module @replaced_module_b
+  module @module_b attributes {sym.new_name = "replaced_module_b"} {
+    // CHECK: module @replaced_module_c
+    module @module_c attributes {sym.new_name = "replaced_module_c"} {
+      // CHECK: func @replaced_foo
+      func @foo() attributes {sym.new_name = "replaced_foo" }
+    }
+  }
+
+  // CHECK: func @symbol_bar
+  func @symbol_bar() {
+    // CHECK: foo.op
+    // CHECK-SAME: use_1 = @module_a::@replaced_foo
+    // CHECK-SAME: use_2 = @replaced_module_b::@replaced_module_c::@replaced_foo
+    "foo.op"() {
+      use_1 = @module_a::@foo,
+      use_2 = @module_b::@module_c::@foo
+    } : () -> ()
+  }
+}
+
+// -----
+
 // Check that the replacement fails for potentially unknown symbol tables.
 module {
   // CHECK: func @failed_repl
diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir
--- a/mlir/test/IR/test-symbol-uses.mlir
+++ b/mlir/test/IR/test-symbol-uses.mlir
@@ -4,14 +4,14 @@
 // its table.
 // expected-remark@below {{symbol_removable function successfully erased}}
 module attributes {sym.outside_use = @symbol_foo } {
-  // expected-remark@+1 {{function has 2 uses}}
+  // expected-remark@+1 {{symbol has 2 uses}}
   func @symbol_foo()
 
-  // expected-remark@below {{function has no uses}}
-  // expected-remark@below {{found use of function : @symbol_foo}}
-  // expected-remark@below {{function contains 2 nested references}}
+  // expected-remark@below {{symbol has no uses}}
+  // expected-remark@below {{found use of symbol : @symbol_foo}}
+  // expected-remark@below {{symbol contains 2 nested references}}
   func @symbol_bar() attributes {sym.use = @symbol_foo} {
-    // expected-remark@+1 {{found use of function : @symbol_foo}}
+    // expected-remark@+1 {{found use of symbol : @symbol_foo}}
     "foo.op"() {
       non_symbol_attr,
       use = [{ nested_symbol = [@symbol_foo]}],
@@ -19,13 +19,13 @@
     } : () -> ()
   }
 
-  // expected-remark@below {{function has no uses}}
+  // expected-remark@below {{symbol has no uses}}
   func @symbol_removable()
 
-  // expected-remark@+1 {{function has 1 use}}
+  // expected-remark@+1 {{symbol has 1 use}}
   func @symbol_baz()
 
-  // expected-remark@+1 {{found use of function : @symbol_baz}}
+  // expected-remark@+1 {{found use of symbol : @symbol_baz}}
   module attributes {test.reference = @symbol_baz} {
     "foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
   }
@@ -33,6 +33,34 @@
 
 // -----
 
+// Test the symbol use list support for nested symbol references.
+module {
+  // expected-remark@+1 {{symbol has 2 uses}}
+  module @module_b {
+    // expected-remark@+1 {{symbol has 1 uses}}
+    module @module_c {
+      // expected-remark@+1 {{symbol has 1 uses}}
+      func @foo()
+    }
+  }
+
+  // expected-remark@below {{symbol has no uses}}
+  // expected-remark@below {{symbol contains 2 nested references}}
+  func @symbol_bar() {
+    // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "foo"}}
+    // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_c"}}
+    // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_b"}}
+    // expected-remark@below {{found use of symbol : @module_b : "module_b"}}
+    "foo.op"() {
+      use_1 = [{ nested_symbol = [@module_b::@module_c::@foo]}],
+      use_2 = @module_b
+    } : () -> ()
+  }
+}
+
+
+// -----
+
 // expected-remark@+1 {{contains an unknown nested operation that 'may' define a new symbol table}}
 func @symbol_bar() {
   "foo.possibly_unknown_symbol_table"() ({
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -16,54 +16,70 @@
 /// This is a symbol test pass that tests the symbol uselist functionality
 /// provided by the symbol table along with erasing from the symbol table.
 struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
-  void runOnModule() override {
-    auto module = getModule();
-    std::vector<FuncOp> ops_to_delete;
+  WalkResult operateOnSymbol(Operation *symbol, Operation *module,
+                             SmallVectorImpl<FuncOp> &deadFunctions) {
+    // Test computing all of the symbol references nested within 'symbol'.
+    Optional<SymbolTable::UseRange> symbolUses =
+        SymbolTable::getSymbolUses(symbol);
 
-    for (FuncOp func : module.getOps<FuncOp>()) {
-      // Test computing uses on a non symboltable op.
-      Optional<SymbolTable::UseRange> symbolUses =
-          SymbolTable::getSymbolUses(func);
+    // Test the conservative failure case: no use range could be computed.
+    if (!symbolUses) {
+      symbol->emitRemark()
+          << "symbol contains an unknown nested operation that "
+             "'may' define a new symbol table";
+      return WalkResult::interrupt();
+    }
+    if (unsigned numUses = llvm::size(*symbolUses))
+      symbol->emitRemark() << "symbol contains " << numUses
+                           << " nested references";
 
-      // Test the conservative failure case.
-      if (!symbolUses) {
-        func.emitRemark() << "function contains an unknown nested operation "
-                             "that 'may' define a new symbol table";
-        return;
-      }
-      if (unsigned numUses = llvm::size(*symbolUses))
-        func.emitRemark() << "function contains " << numUses
-                          << " nested references";
+    // Test the functionality of symbolKnownUseEmpty.
+    if (SymbolTable::symbolKnownUseEmpty(symbol, module)) {
+      FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
+      if (funcSymbol && funcSymbol.isExternal())
+        deadFunctions.push_back(funcSymbol);
 
-      // Test the functionality of symbolKnownUseEmpty.
-      if (func.symbolKnownUseEmpty(module)) {
-        func.emitRemark() << "function has no uses";
-        if (func.getBody().empty())
-          ops_to_delete.push_back(func);
-        continue;
-      }
+      symbol->emitRemark() << "symbol has no uses";
+      return WalkResult::advance();
+    }
 
-      // Test the functionality of getSymbolUses.
-      symbolUses = func.getSymbolUses(module);
-      assert(symbolUses.hasValue() && "expected no unknown operations");
-      for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+    // Test the functionality of getSymbolUses.
+    symbolUses = SymbolTable::getSymbolUses(symbol, module);
+    assert(symbolUses.hasValue() && "expected no unknown operations");
+    for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+      // Only report uses whose reference resolves to some symbol operation.
+      if (Operation *op = SymbolTable::lookupNearestSymbolFrom(
+              symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
         symbolUse.getUser()->emitRemark()
-            << "found use of function : " << symbolUse.getSymbolRef();
+            << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
+            << symbol->getAttr(SymbolTable::getSymbolAttrName());
       }
-      func.emitRemark() << "function has " << llvm::size(*symbolUses)
-                        << " uses";
     }
+    symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
+    return WalkResult::advance();
+  }
+
+  void runOnModule() override {
+    auto module = getModule();
 
-    for (FuncOp func : ops_to_delete) {
+    // Visit every nested operation that defines a symbol.
+    SmallVector<FuncOp, 4> deadFunctions;
+    module.getBodyRegion().walk([&](Operation *nestedOp) {
+      if (SymbolTable::isSymbol(nestedOp))
+        return operateOnSymbol(nestedOp, module, deadFunctions);
+      return WalkResult::advance();
+    });
+
+    for (Operation *op : deadFunctions) {
       // In order to test the SymbolTable::erase method, also erase completely
       // useless functions.
       SymbolTable table(module);
-      auto func_name = func.getName();
-      assert(table.lookup(func_name) && "expected no unknown operations");
-      table.erase(func);
-      assert(!table.lookup(func_name) &&
+      auto name = SymbolTable::getSymbolName(op);
+      assert(table.lookup(name) && "expected no unknown operations");
+      table.erase(op);
+      assert(!table.lookup(name) &&
              "expected erased operation to be unknown now");
-      module.emitRemark() << func_name << " function successfully erased";
+      module.emitRemark() << name << " function successfully erased";
     }
   }
 };
@@ -74,13 +90,15 @@
   void runOnModule() override {
     auto module = getModule();
 
-    for (FuncOp func : module.getOps<FuncOp>()) {
-      StringAttr newName = func.getAttrOfType<StringAttr>("sym.new_name");
+    // Walk all nested operations, renaming those with a 'sym.new_name' attr.
+    module.getBodyRegion().walk([&](Operation *nestedOp) {
+      StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
       if (!newName)
-        continue;
-      if (succeeded(func.replaceAllSymbolUses(newName.getValue(), module)))
-        func.setName(newName.getValue());
-    }
+        return;
+      if (succeeded(SymbolTable::replaceAllSymbolUses(
+              nestedOp, newName.getValue(), module)))
+        SymbolTable::setSymbolName(nestedOp, newName.getValue());
+    });
   }
 };
 } // end anonymous namespace