diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -50,16 +50,27 @@ // Symbol Utilities //===--------------------------------------------------------------------===// + /// Returns true if the given operation defines a symbol. + static bool isSymbol(Operation *op); + + /// Returns the name of the given symbol operation. + static StringRef getSymbolName(Operation *symbol); + /// Sets the name of the given symbol operation. + static void setSymbolName(Operation *symbol, StringRef name); + /// Returns the operation registered with the given symbol name with the /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation /// with the 'OpTrait::SymbolTable' trait. static Operation *lookupSymbolIn(Operation *op, StringRef symbol); + static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); /// Returns the operation registered with the given symbol name within the /// closest parent operation of, or including, 'from' with the /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was /// found. static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); + static Operation *lookupNearestSymbolFrom(Operation *from, + SymbolRefAttr symbol); /// This class represents a specific symbol use. class SymbolUse { @@ -110,6 +121,7 @@ /// symbol table, and not the op itself. This function returns None if there /// are any unknown operations that may potentially be symbol tables. static Optional getSymbolUses(StringRef symbol, Operation *from); + static Optional getSymbolUses(Operation *symbol, Operation *from); /// Return if the given symbol is known to have no uses that are nested /// within the given operation 'from'. This does not traverse into any nested @@ -120,6 +132,7 @@ /// tables. This doesn't necessarily mean that there are no uses, we just /// can't conservatively prove it.
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); + static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); /// Attempt to replace all uses of the given symbol 'oldSymbol' with the /// provided symbol 'newSymbol' that are nested within the given operation @@ -132,6 +145,9 @@ LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Operation *from); + LLVM_NODISCARD static LogicalResult + replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbol, + Operation *from); private: Operation *symbolTableOp; @@ -207,14 +223,14 @@ /// operation 'from'. /// Note: See mlir::SymbolTable::getSymbolUses for more details. Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) { - return ::mlir::SymbolTable::getSymbolUses(getName(), from); + return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from); } /// Return if the current symbol is known to have no uses that are nested /// within the given operation 'from'. /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details. bool symbolKnownUseEmpty(Operation *from) { - return ::mlir::SymbolTable::symbolKnownUseEmpty(getName(), from); + return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from); } /// Attempt to replace all uses of the current symbol with the provided symbol @@ -222,8 +238,8 @@ /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol, Operation *from) { - return ::mlir::SymbolTable::replaceAllSymbolUses(getName(), newSymbol, - from); + return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), + newSymbol, from); } }; diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" using namespace mlir; @@ -17,6 +19,71 @@ return !op->getDialect() && op->getNumRegions() == 1; } +/// Returns the nearest symbol table at or above the operation `from`. Returns +/// nullptr if none is found or a potentially unknown symbol table is hit first. +static Operation *getNearestSymbolTable(Operation *from) { + assert(from && "expected valid operation"); + if (isPotentiallyUnknownSymbolTable(from)) + return nullptr; + + while (!from->hasTrait()) { + from = from->getParentOp(); + + // Check that this is a valid op and isn't an unknown symbol table. + if (!from || isPotentiallyUnknownSymbolTable(from)) + return nullptr; + } + return from; +} + +/// Returns the string name of the given symbol, or None if this is not a +/// symbol. +static Optional getNameIfSymbol(Operation *symbol) { + auto nameAttr = + symbol->getAttrOfType(SymbolTable::getSymbolAttrName()); + return nameAttr ? nameAttr.getValue() : Optional(); +} + +/// Computes the nested symbol reference attributes for the symbol 'symbolName' +/// that are usable within the symbol table operations from 'symbol' as far up +/// to the given operation 'within', where 'within' is an ancestor of 'symbol'. +/// Returns success if all references up to 'within' could be computed.
+static LogicalResult +collectValidReferencesFor(Operation *symbol, StringRef symbolName, + Operation *within, + SmallVectorImpl &results) { + assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); + MLIRContext *ctx = symbol->getContext(); + + auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); + results.push_back(leafRef); + + // Early exit for when 'within' is the direct parent of 'symbol'. + Operation *symbolTableOp = symbol->getParentOp(); + if (within == symbolTableOp) + return success(); + + // Collect references until 'symbolTableOp' reaches 'within'. + SmallVector nestedRefs(1, leafRef); + do { + // Each parent of 'symbol' below 'within' must define a symbol table. + if (!symbolTableOp->hasTrait()) + return failure(); + // Each such parent must also be a named symbol. + Optional symbolTableName = getNameIfSymbol(symbolTableOp); + if (!symbolTableName) + return failure(); + results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); + + symbolTableOp = symbolTableOp->getParentOp(); + if (symbolTableOp == within) + break; + nestedRefs.insert(nestedRefs.begin(), + FlatSymbolRefAttr::get(*symbolTableName, ctx)); + } while (true); + return success(); +} + //===----------------------------------------------------------------------===// // SymbolTable //===----------------------------------------------------------------------===// @@ -32,11 +99,11 @@ "expected operation to have a single block"); for (auto &op : symbolTableOp->getRegion(0).front()) { - auto nameAttr = op.getAttrOfType(getSymbolAttrName()); - if (!nameAttr) + Optional name = getNameIfSymbol(&op); + if (!name) continue; - auto inserted = symbolTable.insert({nameAttr.getValue(), &op}); + auto inserted = symbolTable.insert({*name, &op}); (void)inserted; assert(inserted.second && "expected region to contain uniquely named symbol operations"); @@ -51,13 +118,13 @@ /// Erase the given symbol from the table.
void SymbolTable::erase(Operation *symbol) { - auto nameAttr = symbol->getAttrOfType(getSymbolAttrName()); - assert(nameAttr && "expected valid 'name' attribute"); + Optional name = getNameIfSymbol(symbol); + assert(name && "expected valid 'name' attribute"); assert(symbol->getParentOp() == symbolTableOp && "expected this operation to be inside of the operation with this " "SymbolTable"); - auto it = symbolTable.find(nameAttr.getValue()); + auto it = symbolTable.find(*name); if (it != symbolTable.end() && it->second == symbol) { symbolTable.erase(it); symbol->erase(); @@ -67,9 +134,6 @@ /// Insert a new symbol into the table and associated operation, and rename it /// as necessary to avoid collisions. void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { - auto nameAttr = symbol->getAttrOfType(getSymbolAttrName()); - assert(nameAttr && "expected valid 'name' attribute"); - auto &body = symbolTableOp->getRegion(0).front(); if (insertPt == Block::iterator() || insertPt == body.end()) insertPt = Block::iterator(body.getTerminator()); @@ -81,12 +145,12 @@ // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - if (symbolTable.insert({nameAttr.getValue(), symbol}).second) + StringRef name = getSymbolName(symbol); + if (symbolTable.insert({name, symbol}).second) return; - // If a conflict was detected, then the symbol will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(nameAttr.getValue()); + SmallString<128> nameBuffer(name); unsigned originalLength = nameBuffer.size(); // Iteratively try suffixes until we find one that isn't used. @@ -95,8 +159,24 @@ nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); } while (!symbolTable.insert({nameBuffer, symbol}).second); + setSymbolName(symbol, nameBuffer); +} + +/// Returns true if the given operation defines a symbol.
+bool SymbolTable::isSymbol(Operation *op) { + return op->hasTrait() || getNameIfSymbol(op).hasValue(); +} + +/// Returns the name of the given symbol operation. +StringRef SymbolTable::getSymbolName(Operation *symbol) { + Optional name = getNameIfSymbol(symbol); + assert(name && "expected valid symbol name"); + return *name; +} +/// Sets the name of the given symbol operation. +void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { symbol->setAttr(getSymbolAttrName(), - StringAttr::get(nameBuffer, symbolTableOp->getContext())); + StringAttr::get(name, symbol->getContext())); } /// Returns the operation registered with the given symbol name with the @@ -109,30 +189,52 @@ // Look for a symbol with the given name. for (auto &block : symbolTableOp->getRegion(0)) { - for (auto &op : block) { - auto nameAttr = op.template getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()); - if (nameAttr && nameAttr.getValue() == symbol) + for (auto &op : block) + if (getNameIfSymbol(&op) == symbol) return &op; - } } return nullptr; } +Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, + SymbolRefAttr symbol) { + assert(symbolTableOp->hasTrait()); + + // Lookup the root reference for this symbol. + symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference()); + if (!symbolTableOp) + return nullptr; + + // If there are no nested references, just return the root symbol directly. + ArrayRef nestedRefs = symbol.getNestedReferences(); + if (nestedRefs.empty()) + return symbolTableOp; + + // Verify that the root is also a symbol table. + if (!symbolTableOp->hasTrait()) + return nullptr; + + // Otherwise, lookup each of the nested non-leaf references and ensure that + // each corresponds to a valid symbol table.
+ for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { + symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue()); + if (!symbolTableOp || !symbolTableOp->hasTrait()) + return nullptr; + } + return lookupSymbolIn(symbolTableOp, symbol.getLeafReference()); +} /// Returns the operation registered with the given symbol name within the /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns /// nullptr if no valid symbol was found. Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, StringRef symbol) { - assert(from && "expected valid operation"); - while (!from->hasTrait()) { - from = from->getParentOp(); - - // Check that this is a valid op and isn't an unknown symbol table. - if (!from || isPotentiallyUnknownSymbolTable(from)) - return nullptr; - } - return lookupSymbolIn(from, symbol); + Operation *symbolTableOp = getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; +} +Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, + SymbolRefAttr symbol) { + Operation *symbolTableOp = getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; } //===----------------------------------------------------------------------===// @@ -148,7 +250,7 @@ << "Operations with a 'SymbolTable' must have exactly one block"; // Check that all symbols are uniquely named within child regions. - llvm::StringMap nameToOrigLoc; + DenseMap nameToOrigLoc; for (auto &block : op->getRegion(0)) { for (auto &op : block) { // Check for a symbol name attribute. @@ -158,7 +260,7 @@ continue; // Try to insert this symbol into the table.
- auto it = nameToOrigLoc.try_emplace(nameAttr.getValue(), op.getLoc()); + auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); if (!it.second) return op.emitError() .append("redefinition of symbol named '", nameAttr.getValue(), "'") @@ -293,6 +395,100 @@ return WalkResult::advance(); } +/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking +/// the provided callback at each one with a properly scoped reference to +/// 'symbol'. The callback takes as parameters the symbol reference at the +/// current scope as well as the top-level operation representing the top of +/// that scope. The walk ends early if the callback returns None or interrupt. +static Optional walkSymbolScopes( + Operation *symbol, Operation *limit, + function_ref(SymbolRefAttr, Operation *)> callback) { + StringRef symbolName = SymbolTable::getSymbolName(symbol); + assert(!symbol->hasTrait() || symbol != limit); + + // Compute the ancestors of 'limit'. + llvm::SetVector, + SmallPtrSet> + limitAncestors; + Operation *limitAncestor = limit; + do { + // Check to see if 'symbol' is an ancestor of 'limit'. + if (limitAncestor == symbol) { + // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr + // doesn't support parent references. + if (getNearestSymbolTable(limit) != symbol->getParentOp()) + return WalkResult::advance(); + return callback(SymbolRefAttr::get(symbolName, symbol->getContext()), + limit); + } + + limitAncestors.insert(limitAncestor); + } while ((limitAncestor = limitAncestor->getParentOp())); + + // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. + Operation *commonAncestor = symbol->getParentOp(); + do { + if (limitAncestors.count(commonAncestor)) + break; + } while ((commonAncestor = commonAncestor->getParentOp())); + assert(commonAncestor && "'limit' and 'symbol' have no common ancestor"); + + // Compute the set of valid nested references for 'symbol' as far up to the + // common ancestor as possible.
+ SmallVector references; + bool collectedAllReferences = succeeded(collectValidReferencesFor( + symbol, symbolName, commonAncestor, references)); + + // Handle the case where the common ancestor is 'limit'. + if (commonAncestor == limit) { + // Walk each of the ancestors of 'symbol', calling the callback for + // each one. + Operation *limitIt = symbol->getParentOp(); + for (size_t i = 0, e = references.size(); i != e; + ++i, limitIt = limitIt->getParentOp()) { + Optional callbackResult = callback(references[i], limitIt); + if (callbackResult != WalkResult::advance()) + return callbackResult; + } + return WalkResult::advance(); + } + + // Otherwise, we just need the symbol reference for 'symbol' that will be + // used within 'limit'. This is the last reference in the list we computed + // above if we were able to collect all references. + if (!collectedAllReferences) + return WalkResult::advance(); + return callback(references.back(), limit); +} + +/// Invokes the callback once with 'limit' and a flat reference to 'symbol'. +static Optional walkSymbolScopes( + StringRef symbol, Operation *limit, + function_ref(SymbolRefAttr, Operation *)> callback) { + return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit); +} + +/// Returns true if the given reference 'subRef' is equal to, or a prefix of, +/// the reference 'ref', i.e. 'ref' is 'subRef' or further qualifies it. +static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { + if (ref == subRef) + return true; + + // If the references are not pointer equal, check to see if `subRef` is a + // prefix of `ref`.
+ if (ref.isa() || + ref.getRootReference() != subRef.getRootReference()) + return false; + + auto refLeafs = ref.getNestedReferences(); + auto subRefLeafs = subRef.getNestedReferences(); + return subRefLeafs.size() < refLeafs.size() && + subRefLeafs == refLeafs.take_front(subRefLeafs.size()); +} + +//===----------------------------------------------------------------------===// +// SymbolTable::getSymbolUses + /// Get an iterator range for all of the uses, for any symbol, that are nested /// within the given operation 'from'. This does not traverse into any nested /// symbol tables, and will also only return uses on 'from' if it does not @@ -302,14 +498,35 @@ /// tables. auto SymbolTable::getSymbolUses(Operation *from) -> Optional { std::vector uses; - Optional result = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef) { - uses.push_back(symbolUse); - return WalkResult::advance(); - }); + auto walkFn = [&](SymbolUse symbolUse, ArrayRef) { + uses.push_back(symbolUse); + return WalkResult::advance(); + }; + auto result = walkSymbolUses(from, walkFn); return result ? Optional(std::move(uses)) : Optional(); } +//===----------------------------------------------------------------------===// +// SymbolTable::getSymbolUses + +/// The implementation of SymbolTable::getSymbolUses below. +template +static Optional getSymbolUsesImpl(SymbolT symbol, + Operation *limit) { + std::vector uses; + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { + return walkSymbolUses( + from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())) + uses.push_back(symbolUse); + return WalkResult::advance(); + }); + }; + if (walkSymbolScopes(symbol, limit, walkFn)) + return SymbolTable::UseRange(std::move(uses)); + return llvm::None; +} + /// Get all of the uses of the given symbol that are nested within the given /// operation 'from', invoking the provided callback for each.
This does not /// traverse into any nested symbol tables, and will also only return uses on @@ -319,16 +536,29 @@ /// potentially be symbol tables. auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) -> Optional { - SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext()); + return getSymbolUsesImpl(symbol, from); +} +auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) + -> Optional { + return getSymbolUsesImpl(symbol, from); +} - std::vector uses; - Optional result = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef) { - if (symbolRefAttr == symbolUse.getSymbolRef()) - uses.push_back(symbolUse); - return WalkResult::advance(); - }); - return result ? Optional(std::move(uses)) : Optional(); +//===----------------------------------------------------------------------===// +// SymbolTable::symbolKnownUseEmpty + +/// The implementation of SymbolTable::symbolKnownUseEmpty below. +template +static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) { + // Walk all of the symbol uses looking for a reference to 'symbol'. + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { + return walkSymbolUses( + from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()) + ? WalkResult::interrupt() + : WalkResult::advance(); + }); + }; + return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance(); } /// Return if the given symbol is known to have no uses that are nested within @@ -338,35 +568,32 @@ /// symbol table, and not the op itself. This function will also return false if /// there are any unknown operations that may potentially be symbol tables. bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { - SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext()); - - // Walk all of the symbol uses looking for a reference to 'symbol'.
- Optional walkResult = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef) { - return symbolUse.getSymbolRef() == symbolRefAttr - ? WalkResult::interrupt() - : WalkResult::advance(); - }); - return walkResult && !walkResult->wasInterrupted(); + return symbolKnownUseEmptyImpl(symbol, from); +} +bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { + return symbolKnownUseEmptyImpl(symbol, from); } +//===----------------------------------------------------------------------===// +// SymbolTable::replaceAllSymbolUses + /// Rebuild the given attribute container after replacing all references to a -/// symbol with `newSymAttr`. -static Attribute rebuildAttrAfterRAUW(Attribute container, - ArrayRef> accesses, - SymbolRefAttr newSymAttr, - unsigned depth) { +/// symbol with the updated attribute in 'accesses'. +static Attribute rebuildAttrAfterRAUW( + Attribute container, + ArrayRef, SymbolRefAttr>> accesses, + unsigned depth) { // Given a range of Attributes, update the ones referred to by the given // access chains to point to the new symbol attribute. auto updateAttrs = [&](auto &&attrRange) { auto attrBegin = std::begin(attrRange); for (unsigned i = 0, e = accesses.size(); i != e;) { - ArrayRef access = accesses[i]; + ArrayRef access = accesses[i].first; Attribute &attr = *std::next(attrBegin, access[depth]); // Check to see if this is a leaf access, i.e. a SymbolRef. if (access.size() == depth + 1) { - attr = newSymAttr; + attr = accesses[i].second; ++i; continue; } @@ -374,12 +601,12 @@ // Otherwise, this is a container. Collect all of the accesses for this // index and recurse. The recursion here is bounded by the size of the // largest access array.
- auto nestedAccesses = - accesses.drop_front(i).take_while([&](ArrayRef nextAccess) { - return nextAccess.size() > depth + 1 && - nextAccess[depth] == access[depth]; - }); - attr = rebuildAttrAfterRAUW(attr, nestedAccesses, newSymAttr, depth + 1); + auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) { + ArrayRef nextAccess = it.first; + return nextAccess.size() > depth + 1 && + nextAccess[depth] == access[depth]; + }); + attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1); // Skip over all of the accesses that refer to the nested container. i += nestedAccesses.size(); @@ -396,64 +623,115 @@ return ArrayAttr::get(newAttrs, container.getContext()); } -/// Attempt to replace all uses of the given symbol 'oldSymbol' with the -/// provided symbol 'newSymbol' that are nested within the given operation -/// 'from'. This does not traverse into any nested symbol tables, and will -/// also only replace uses on 'from' if it does not also define a symbol -/// table. This is because we treat the region as the boundary of the symbol -/// table, and not the op itself. If there are any unknown operations that may -/// potentially be symbol tables, no uses are replaced and failure is returned. -LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, - Operation *from) { - SymbolRefAttr oldAttr = SymbolRefAttr::get(oldSymbol, from->getContext()); - SymbolRefAttr newSymAttr = SymbolRefAttr::get(newSymbol, from->getContext()); +/// Returns a copy of 'oldAttr' whose leaf reference is 'newLeafAttr'. +static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, + FlatSymbolRefAttr newLeafAttr) { + if (oldAttr.isa()) + return newLeafAttr; + auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); + nestedRefs.back() = newLeafAttr; + return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, + oldAttr.getContext()); +} +/// The implementation of SymbolTable::replaceAllSymbolUses below.
+template +static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, + StringRef newSymbol, + Operation *limit) { // A collection of operations along with their new attribute dictionary. std::vector> updatedAttrDicts; - // The current operation, and its old symbol access chains, being processed. + // The current operation being processed. Operation *curOp = nullptr; - SmallVector, 1> accessChains; + + // The set of access chains into the attribute dictionary of the current + // operation, as well as the replacement attribute to use. + SmallVector, SymbolRefAttr>, 1> accessChains; // Generate a new attribute dictionary for the current operation by replacing // references to the old symbol. auto generateNewAttrDict = [&] { - auto newAttrDict = - rebuildAttrAfterRAUW(curOp->getAttrList().getDictionary(), accessChains, - newSymAttr, /*depth=*/0); - return newAttrDict.cast(); + auto oldDict = curOp->getAttrList().getDictionary(); + auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); + return newDict.cast(); }; - // Walk the symbol uses collecting uses of the old symbol. - auto walkFn = [&](SymbolTable::SymbolUse symbolUse, - ArrayRef accessChain) { - if (symbolUse.getSymbolRef() != oldAttr) + // Generate a new attribute to replace the given attribute. + MLIRContext *ctx = limit->getContext(); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); + auto scopeWalkFn = [&](SymbolRefAttr oldAttr, + Operation *from) -> Optional { + SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr); + auto walkFn = [&](SymbolTable::SymbolUse symbolUse, + ArrayRef accessChain) { + SymbolRefAttr useRef = symbolUse.getSymbolRef(); + if (!isReferencePrefixOf(oldAttr, useRef)) + return WalkResult::advance(); + + // If we have a valid match, check to see if this is a proper + // subreference. If it is, then we will need to generate a different new + // attribute specifically for this use.
+ SymbolRefAttr replacementRef = newAttr; + if (useRef != oldAttr) { + if (oldAttr.isa()) { + replacementRef = + SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); + } else { + auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); + nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr; + replacementRef = + SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); + } + } + + // If this use belongs to a different operation than the previous one, + // we've finished collecting accesses for the previous operation, so + // generate its new attribute dictionary now. + if (curOp && symbolUse.getUser() != curOp) { + updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); + accessChains.clear(); + } + + // Record this access. + curOp = symbolUse.getUser(); + accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); return WalkResult::advance(); + }; + if (!walkSymbolUses(from, walkFn)) + return llvm::None; - // If there was a previous operation, generate a new attribute dict for it. - // This means that we've finished processing the current operation, so - // generate a new dictionary for it. - if (curOp && symbolUse.getUser() != curOp) { + // Flush any dangling op and reset the per-op state for the next scope. + if (curOp) { updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); accessChains.clear(); + curOp = nullptr; } - - // Record this access. - curOp = symbolUse.getUser(); - accessChains.push_back(llvm::to_vector<1>(accessChain)); return WalkResult::advance(); }; - if (!walkSymbolUses(from, walkFn)) + if (!walkSymbolScopes(symbol, limit, scopeWalkFn)) return failure(); // Update the attribute dictionaries as necessary. for (auto &it : updatedAttrDicts) it.first->setAttrs(it.second); - - // Check to see if we have a dangling op that needs to be processed.
- if (curOp) - curOp->setAttrs(generateNewAttrDict()); - return success(); } + +/// Attempt to replace all uses of the given symbol 'oldSymbol' with the +/// provided symbol 'newSymbol' that are nested within the given operation +/// 'from'. This does not traverse into any nested symbol tables, and will +/// also only replace uses on 'from' if it does not also define a symbol +/// table. This is because we treat the region as the boundary of the symbol +/// table, and not the op itself. If there are any unknown operations that may +/// potentially be symbol tables, no uses are replaced and failure is returned. +LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, + StringRef newSymbol, + Operation *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} +LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, + StringRef newSymbol, + Operation *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir --- a/mlir/test/IR/test-symbol-rauw.mlir +++ b/mlir/test/IR/test-symbol-rauw.mlir @@ -32,6 +32,39 @@ // ----- +// Check the support for nested references.
+ +// CHECK: module +module { + // CHECK: module @module_a + module @module_a { + // CHECK: func @replaced_foo + func @foo() attributes {sym.new_name = "replaced_foo" } + } + + // CHECK: module @replaced_module_b + module @module_b attributes {sym.new_name = "replaced_module_b"} { + // CHECK: module @replaced_module_c + module @module_c attributes {sym.new_name = "replaced_module_c"} { + // CHECK: func @replaced_foo + func @foo() attributes {sym.new_name = "replaced_foo" } + } + } + + // CHECK: func @symbol_bar + func @symbol_bar() { + // CHECK: foo.op + // CHECK-SAME: use_1 = @module_a::@replaced_foo + // CHECK-SAME: use_2 = @replaced_module_b::@replaced_module_c::@replaced_foo + "foo.op"() { + use_1 = @module_a::@foo, + use_2 = @module_b::@module_c::@foo + } : () -> () + } +} + +// ----- + // Check that the replacement fails for potentially unknown symbol tables. module { // CHECK: func @failed_repl diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir --- a/mlir/test/IR/test-symbol-uses.mlir +++ b/mlir/test/IR/test-symbol-uses.mlir @@ -4,14 +4,14 @@ // its table.
// expected-remark@below {{symbol_removable function successfully erased}} module attributes {sym.outside_use = @symbol_foo } { - // expected-remark@+1 {{function has 2 uses}} + // expected-remark@+1 {{symbol has 2 uses}} func @symbol_foo() - // expected-remark@below {{function has no uses}} - // expected-remark@below {{found use of function : @symbol_foo}} - // expected-remark@below {{function contains 2 nested references}} + // expected-remark@below {{symbol has no uses}} + // expected-remark@below {{found use of symbol : @symbol_foo}} + // expected-remark@below {{symbol contains 2 nested references}} func @symbol_bar() attributes {sym.use = @symbol_foo} { - // expected-remark@+1 {{found use of function : @symbol_foo}} + // expected-remark@+1 {{found use of symbol : @symbol_foo}} "foo.op"() { non_symbol_attr, use = [{ nested_symbol = [@symbol_foo]}], @@ -19,13 +19,13 @@ } : () -> () } - // expected-remark@below {{function has no uses}} + // expected-remark@below {{symbol has no uses}} func @symbol_removable() - // expected-remark@+1 {{function has 1 use}} + // expected-remark@+1 {{symbol has 1 use}} func @symbol_baz() - // expected-remark@+1 {{found use of function : @symbol_baz}} + // expected-remark@+1 {{found use of symbol : @symbol_baz}} module attributes {test.reference = @symbol_baz} { "foo.op"() {test.nested_reference = @symbol_baz} : () -> () } @@ -33,6 +33,34 @@ // ----- +// Test support for nested symbol references +module { + // expected-remark@+1 {{symbol has 2 uses}} + module @module_b { + // expected-remark@+1 {{symbol has 1 uses}} + module @module_c { + // expected-remark@+1 {{symbol has 1 uses}} + func @foo() + } + } + + // expected-remark@below {{symbol has no uses}} + // expected-remark@below {{symbol contains 2 nested references}} + func @symbol_bar() { + // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "foo"}} + // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_c"}} + //
expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_b"}} + // expected-remark@below {{found use of symbol : @module_b : "module_b"}} + "foo.op"() { + use_1 = [{ nested_symbol = [@module_b::@module_c::@foo]}], + use_2 = @module_b + } : () -> () + } +} + + +// ----- + // expected-remark@+1 {{contains an unknown nested operation that 'may' define a new symbol table}} func @symbol_bar() { "foo.possibly_unknown_symbol_table"() ({ diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -16,54 +16,70 @@ /// This is a symbol test pass that tests the symbol uselist functionality /// provided by the symbol table along with erasing from the symbol table. struct SymbolUsesPass : public ModulePass { - void runOnModule() override { - auto module = getModule(); - std::vector ops_to_delete; + WalkResult operateOnSymbol(Operation *symbol, Operation *module, + SmallVectorImpl &deadFunctions) { + // Test computing the uses nested within the symbol operation itself (a function or a nested module). + Optional symbolUses = + SymbolTable::getSymbolUses(symbol); - for (FuncOp func : module.getOps()) { - // Test computing uses on a non symboltable op. - Optional symbolUses = - SymbolTable::getSymbolUses(func); + // Test the conservative failure case. + if (!symbolUses) { + symbol->emitRemark() + << "symbol contains an unknown nested operation that " + "'may' define a new symbol table"; + return WalkResult::interrupt(); + } + if (unsigned numUses = llvm::size(*symbolUses)) + symbol->emitRemark() << "symbol contains " << numUses + << " nested references"; - // Test the conservative failure case.
- if (!symbolUses) { - func.emitRemark() << "function contains an unknown nested operation " - "that 'may' define a new symbol table"; - return; - } - if (unsigned numUses = llvm::size(*symbolUses)) - func.emitRemark() << "function contains " << numUses - << " nested references"; + // Test the functionality of symbolKnownUseEmpty. + if (SymbolTable::symbolKnownUseEmpty(symbol, module)) { + FuncOp funcSymbol = dyn_cast(symbol); + if (funcSymbol && funcSymbol.isExternal()) + deadFunctions.push_back(funcSymbol); - // Test the functionality of symbolKnownUseEmpty. - if (func.symbolKnownUseEmpty(module)) { - func.emitRemark() << "function has no uses"; - if (func.getBody().empty()) - ops_to_delete.push_back(func); - continue; - } + symbol->emitRemark() << "symbol has no uses"; + return WalkResult::advance(); + } - // Test the functionality of getSymbolUses. - symbolUses = func.getSymbolUses(module); - assert(symbolUses.hasValue() && "expected no unknown operations"); - for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + // Test the functionality of getSymbolUses. + symbolUses = SymbolTable::getSymbolUses(symbol, module); + assert(symbolUses.hasValue() && "expected no unknown operations"); + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + // Only report uses whose full reference resolves from the user's scope. + if (SymbolTable::lookupNearestSymbolFrom( + symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) { symbolUse.getUser()->emitRemark() - << "found use of function : " << symbolUse.getSymbolRef(); + << "found use of symbol : " << symbolUse.getSymbolRef() << " : " + << symbol->getAttr(SymbolTable::getSymbolAttrName()); } - func.emitRemark() << "function has " << llvm::size(*symbolUses) - << " uses"; } + symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses"; + return WalkResult::advance(); + } + + void runOnModule() override { + auto module = getModule(); - for (FuncOp func : ops_to_delete) { + // Walk nested symbols.
+ SmallVector deadFunctions; + module.getBodyRegion().walk([&](Operation *nestedOp) { + if (SymbolTable::isSymbol(nestedOp)) + return operateOnSymbol(nestedOp, module, deadFunctions); + return WalkResult::advance(); + }); + + for (Operation *op : deadFunctions) { // In order to test the SymbolTable::erase method, also erase completely // useless functions. - SymbolTable table(module); + SymbolTable table(op->getParentOp()); // 'op' may be nested below 'module'. - auto func_name = func.getName(); - assert(table.lookup(func_name) && "expected no unknown operations"); - table.erase(func); - assert(!table.lookup(func_name) && + auto name = SymbolTable::getSymbolName(op); + assert(table.lookup(name) && "expected no unknown operations"); + table.erase(op); + assert(!table.lookup(name) && "expected erased operation to be unknown now"); - module.emitRemark() << func_name << " function successfully erased"; + module.emitRemark() << name << " function successfully erased"; } } }; @@ -74,13 +90,15 @@ void runOnModule() override { auto module = getModule(); - for (FuncOp func : module.getOps()) { - StringAttr newName = func.getAttrOfType("sym.new_name"); + // Walk nested functions and modules. + module.getBodyRegion().walk([&](Operation *nestedOp) { + StringAttr newName = nestedOp->getAttrOfType("sym.new_name"); if (!newName) - continue; - if (succeeded(func.replaceAllSymbolUses(newName.getValue(), module))) - func.setName(newName.getValue()); - } + return; + if (succeeded(SymbolTable::replaceAllSymbolUses( + nestedOp, newName.getValue(), module))) + SymbolTable::setSymbolName(nestedOp, newName.getValue()); + }); } }; } // end anonymous namespace