diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -137,47 +137,47 @@ /// Get an iterator range for all of the uses, for any symbol, that are nested /// within the given operation 'from'. This does not traverse into any nested - /// symbol tables, and will also only return uses on 'from' if it does not - /// also define a symbol table. This is because we treat the region as the - /// boundary of the symbol table, and not the op itself. This function returns - /// None if there are any unknown operations that may potentially be symbol - /// tables. + /// symbol tables. This function returns None if there are any unknown + /// operations that may potentially be symbol tables. static Optional getSymbolUses(Operation *from); + static Optional getSymbolUses(Region *from); /// Get all of the uses of the given symbol that are nested within the given - /// operation 'from'. This does not traverse into any nested symbol tables, - /// and will also only return uses on 'from' if it does not also define a - /// symbol table. This is because we treat the region as the boundary of the - /// symbol table, and not the op itself. This function returns None if there - /// are any unknown operations that may potentially be symbol tables. + /// operation 'from'. This does not traverse into any nested symbol tables. + /// This function returns None if there are any unknown operations that may + /// potentially be symbol tables. static Optional getSymbolUses(StringRef symbol, Operation *from); static Optional getSymbolUses(Operation *symbol, Operation *from); + static Optional getSymbolUses(StringRef symbol, Region *from); + static Optional getSymbolUses(Operation *symbol, Region *from); /// Return if the given symbol is known to have no uses that are nested /// within the given operation 'from'.
This does not traverse into any nested - /// symbol tables, and will also only count uses on 'from' if it does not also - /// define a symbol table. This is because we treat the region as the boundary - /// of the symbol table, and not the op itself. This function will also return - /// false if there are any unknown operations that may potentially be symbol - /// tables. This doesn't necessarily mean that there are no uses, we just - /// can't conservatively prove it. + /// symbol tables. This function will also return false if there are any + /// unknown operations that may potentially be symbol tables. This doesn't + /// necessarily mean that there are no uses, we just can't conservatively + /// prove it. static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); + static bool symbolKnownUseEmpty(StringRef symbol, Region *from); + static bool symbolKnownUseEmpty(Operation *symbol, Region *from); /// Attempt to replace all uses of the given symbol 'oldSymbol' with the /// provided symbol 'newSymbol' that are nested within the given operation - /// 'from'. This does not traverse into any nested symbol tables, and will - /// also only replace uses on 'from' if it does not also define a symbol - /// table. This is because we treat the region as the boundary of the symbol - /// table, and not the op itself. If there are any unknown operations that may - /// potentially be symbol tables, no uses are replaced and failure is - /// returned. + /// 'from'. This does not traverse into any nested symbol tables. If there are + /// any unknown operations that may potentially be symbol tables, no uses are + /// replaced and failure is returned.
LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Operation *from); LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, Operation *from); + LLVM_NODISCARD static LogicalResult + replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from); + LLVM_NODISCARD static LogicalResult + replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, + Region *from); private: Operation *symbolTableOp; diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -401,35 +401,19 @@ } /// Walk all of the uses, for any symbol, that are nested within the given -/// operation 'from', invoking the provided callback for each. This does not -/// traverse into any nested symbol tables, and will also only return uses on -/// 'from' if it does not also define a symbol table. +/// regions, invoking the provided callback for each. This does not traverse +/// into nested symbol tables; returns None on possibly-unknown symbol tables. static Optional walkSymbolUses( - Operation *from, + MutableArrayRef regions, function_ref)> callback) { - // If from is not a symbol table, check for uses. A symbol table defines a new - // scope, so we can't walk the attributes from the symbol table op.
- if (!from->hasTrait()) { - if (walkSymbolRefs(from, callback).wasInterrupted()) - return WalkResult::interrupt(); - } - - SmallVector worklist; - worklist.reserve(from->getNumRegions()); - for (Region &region : from->getRegions()) - worklist.push_back(&region); - + SmallVector worklist(llvm::make_pointer_range(regions)); while (!worklist.empty()) { - Region *region = worklist.pop_back_val(); - for (Block &block : *region) { + for (Block &block : *worklist.pop_back_val()) { for (Operation &op : block) { if (walkSymbolRefs(&op, callback).wasInterrupted()) return WalkResult::interrupt(); - // If this operation has regions, and it as well as its dialect aren't - // registered then conservatively fail. The operation may define a - // symbol table, so we can't opaquely know if we should traverse to find - // nested uses. + // Check that this isn't a potentially unknown symbol table. if (isPotentiallyUnknownSymbolTable(&op)) return llvm::None; @@ -444,16 +428,74 @@ } return WalkResult::advance(); } +/// Walk all of the uses, for any symbol, that are nested within the given +/// operation 'from', invoking the provided callback for each. This does not +/// traverse into any nested symbol tables. +static Optional walkSymbolUses( + Operation *from, + function_ref)> callback) { + // If this operation has regions, and it, as well as its dialect, isn't + // registered then conservatively fail. The operation may define a + // symbol table, so we can't opaquely know if we should traverse to find + // nested uses. + if (isPotentiallyUnknownSymbolTable(from)) + return llvm::None; + + // Walk the uses on this operation. + if (walkSymbolRefs(from, callback).wasInterrupted()) + return WalkResult::interrupt(); + + // Only recurse if this operation is not a symbol table. A symbol table + // defines a new scope, so the operations nested within it are not part of + // the scope being walked.
+ if (!from->hasTrait()) + return walkSymbolUses(from->getRegions(), callback); + return WalkResult::advance(); +} + +namespace { +/// This class represents a single symbol scope. A symbol scope represents the +/// set of operations nested within a symbol table that may reference symbols +/// within that table. A symbol scope does not contain the symbol table +/// operation itself, just its contained operations. A scope ends at leaf +/// operations or another symbol table operation. +struct SymbolScope { + /// Walk the symbol uses within this scope, invoking the given callback. + /// This variant is used when the callback type matches that expected by + /// 'walkSymbolUses'. + template ::result_t, void>::value> * = + nullptr> + Optional walk(CallbackT cback) { + if (Region *region = limit.dyn_cast()) + return walkSymbolUses(*region, cback); + return walkSymbolUses(limit.get(), cback); + } + /// This variant is used when the callback type matches a stripped-down type: + /// void(SymbolTable::SymbolUse use) + template ::result_t, void>::value> * = + nullptr> + Optional walk(CallbackT cback) { + return walk([=](SymbolTable::SymbolUse use, ArrayRef) { + return cback(use), WalkResult::advance(); + }); + } -/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking -/// the provided callback at each one with a properly scoped reference to -/// 'symbol'. The callback takes as parameters the symbol reference at the -/// current scope as well as the top-level operation representing the top of -/// that scope. -static Optional walkSymbolScopes( - Operation *symbol, Operation *limit, - function_ref(SymbolRefAttr, Operation *)> callback) { - StringRef symbolName = SymbolTable::getSymbolName(symbol); + /// The representation of the symbol within this scope. + SymbolRefAttr symbol; + + /// The IR unit representing this scope.
+ llvm::PointerUnion limit; +}; +} // end anonymous namespace + +/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'. +static SmallVector collectSymbolScopes(Operation *symbol, + Operation *limit) { + StringRef symName = SymbolTable::getSymbolName(symbol); assert(!symbol->hasTrait() || symbol != limit); // Compute the ancestors of 'limit'. @@ -466,10 +508,10 @@ if (limitAncestor == symbol) { // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr // doesn't support parent references. - if (SymbolTable::getNearestSymbolTable(limit) != symbol->getParentOp()) - return WalkResult::advance(); - return callback(SymbolRefAttr::get(symbolName, symbol->getContext()), - limit); + if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == + symbol->getParentOp()) + return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}}; + return {}; } limitAncestors.insert(limitAncestor); @@ -486,36 +528,45 @@ // Compute the set of valid nested references for 'symbol' as far up to the // common ancestor as possible. SmallVector references; - bool collectedAllReferences = succeeded(collectValidReferencesFor( - symbol, symbolName, commonAncestor, references)); + bool collectedAllReferences = succeeded( + collectValidReferencesFor(symbol, symName, commonAncestor, references)); // Handle the case where the common ancestor is 'limit'. if (commonAncestor == limit) { + SmallVector scopes; + // Walk each of the ancestors of 'symbol', calling the compute function for // each one.
Operation *limitIt = symbol->getParentOp(); for (size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) { - Optional callbackResult = callback(references[i], limitIt); - if (callbackResult != WalkResult::advance()) - return callbackResult; + assert(limitIt->hasTrait()); + scopes.push_back({references[i], &limitIt->getRegion(0)}); } - return WalkResult::advance(); + return scopes; } // Otherwise, we just need the symbol reference for 'symbol' that will be // used within 'limit'. This is the last reference in the list we computed // above if we were able to collect all references. if (!collectedAllReferences) - return WalkResult::advance(); - return callback(references.back(), limit); + return {}; + return {{references.back(), limit}}; } +static SmallVector collectSymbolScopes(Operation *symbol, + Region *limit) { + auto scopes = collectSymbolScopes(symbol, limit->getParentOp()); -/// Walk the symbol scopes defined by 'limit' invoking the provided callback. -static Optional walkSymbolScopes( - StringRef symbol, Operation *limit, - function_ref(SymbolRefAttr, Operation *)> callback) { - return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit); + // If we collected some scopes to walk, make sure to constrain the one for + // limit to the specific region requested. + if (!scopes.empty()) + scopes.back().limit = limit; + return scopes; +} +template +static SmallVector collectSymbolScopes(StringRef symbol, + IRUnit *limit) { + return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}}; } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -539,6 +590,18 @@ //===----------------------------------------------------------------------===// // SymbolTable::getSymbolUses +/// The implementation of the SymbolTable::getSymbolUses(from) overloads below.
+template +static Optional getSymbolUsesImpl(FromT from) { + std::vector uses; + auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + uses.push_back(symbolUse); + return WalkResult::advance(); + }; + auto result = walkSymbolUses(from, walkFn); + return result ? Optional(std::move(uses)) : llvm::None; +} + /// Get an iterator range for all of the uses, for any symbol, that are nested /// within the given operation 'from'. This does not traverse into any nested /// symbol tables, and will also only return uses on 'from' if it does not @@ -547,43 +610,34 @@ /// None if there are any unknown operations that may potentially be symbol /// tables. auto SymbolTable::getSymbolUses(Operation *from) -> Optional { - std::vector uses; - auto walkFn = [&](SymbolUse symbolUse, ArrayRef) { - uses.push_back(symbolUse); - return WalkResult::advance(); - }; - auto result = walkSymbolUses(from, walkFn); - return result ? Optional(std::move(uses)) : Optional(); + return getSymbolUsesImpl(from); +} +auto SymbolTable::getSymbolUses(Region *from) -> Optional { + return getSymbolUsesImpl(MutableArrayRef(*from)); } //===----------------------------------------------------------------------===// // SymbolTable::getSymbolUses /// The implementation of SymbolTable::getSymbolUses below.
-template +template static Optional getSymbolUsesImpl(SymbolT symbol, - Operation *limit) { + IRUnitT *limit) { std::vector uses; - auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { - return walkSymbolUses( - from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { - if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())) + for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { + if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) { + if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) uses.push_back(symbolUse); - return WalkResult::advance(); - }); - }; - if (walkSymbolScopes(symbol, limit, walkFn)) - return SymbolTable::UseRange(std::move(uses)); - return llvm::None; + })) + return llvm::None; + } + return SymbolTable::UseRange(std::move(uses)); } /// Get all of the uses of the given symbol that are nested within the given /// operation 'from', invoking the provided callback for each. This does not -/// traverse into any nested symbol tables, and will also only return uses on -/// 'from' if it does not also define a symbol table. This is because we treat -/// the region as the boundary of the symbol table, and not the op itself. This -/// function returns None if there are any unknown operations that may -/// potentially be symbol tables. +/// traverse into any nested symbol tables. This function returns None if there +/// are any unknown operations that may potentially be symbol tables.
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) -> Optional { return getSymbolUsesImpl(symbol, from); @@ -592,37 +646,49 @@ -> Optional { return getSymbolUsesImpl(symbol, from); } +auto SymbolTable::getSymbolUses(StringRef symbol, Region *from) + -> Optional { + return getSymbolUsesImpl(symbol, from); +} +auto SymbolTable::getSymbolUses(Operation *symbol, Region *from) + -> Optional { + return getSymbolUsesImpl(symbol, from); +} //===----------------------------------------------------------------------===// // SymbolTable::symbolKnownUseEmpty /// The implementation of SymbolTable::symbolKnownUseEmpty below. -template -static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) { - // Walk all of the symbol uses looking for a reference to 'symbol'. - auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { - return walkSymbolUses( - from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { - return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()) +template +static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) { + for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { + // Walk looking for 'symbol'; an interrupted/failed walk yields false. + if (scope.walk([&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt() : WalkResult::advance(); - }); - }; - return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance(); + }) != WalkResult::advance()) + return false; + } + return true; } /// Return if the given symbol is known to have no uses that are nested within /// the given operation 'from'. This does not traverse into any nested symbol -/// tables, and will also only count uses on 'from' if it does not also define -/// a symbol table. This is because we treat the region as the boundary of the -/// symbol table, and not the op itself.
This function will also return false if -/// there are any unknown operations that may potentially be symbol tables. +/// tables. This function will also return false if there are any unknown +/// operations that may potentially be symbol tables. bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { return symbolKnownUseEmptyImpl(symbol, from); } bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { return symbolKnownUseEmptyImpl(symbol, from); } +bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) { + return symbolKnownUseEmptyImpl(symbol, from); +} +bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { + return symbolKnownUseEmptyImpl(symbol, from); +} //===----------------------------------------------------------------------===// // SymbolTable::replaceAllSymbolUses @@ -685,10 +751,9 @@ } /// The implementation of SymbolTable::replaceAllSymbolUses below. -template -static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, - StringRef newSymbol, - Operation *limit) { +template +static LogicalResult +replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { // A collection of operations along with their new attribute dictionary. std::vector> updatedAttrDicts; @@ -710,26 +775,26 @@ // Generate a new attribute to replace the given attribute.
MLIRContext *ctx = limit->getContext(); FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); - auto scopeWalkFn = [&](SymbolRefAttr oldAttr, - Operation *from) -> Optional { - SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr); + for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { + SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef accessChain) { SymbolRefAttr useRef = symbolUse.getSymbolRef(); - if (!isReferencePrefixOf(oldAttr, useRef)) + if (!isReferencePrefixOf(scope.symbol, useRef)) return WalkResult::advance(); // If we have a valid match, check to see if this is a proper // subreference. If it is, then we will need to generate a different new // attribute specifically for this use. SymbolRefAttr replacementRef = newAttr; - if (useRef != oldAttr) { - if (oldAttr.isa()) { + if (useRef != scope.symbol) { + if (scope.symbol.isa()) { replacementRef = SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); } else { auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); - nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr; + nestedRefs[scope.symbol.getNestedReferences().size() - 1] = + newLeafAttr; replacementRef = SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); } @@ -748,18 +813,15 @@ accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); return WalkResult::advance(); }; - if (!walkSymbolUses(from, walkFn)) - return llvm::None; + if (!scope.walk(walkFn)) + return failure(); // Check to see if we have a dangling op that needs to be processed. if (curOp) { updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); curOp = nullptr; } - return WalkResult::advance(); - }; - if (!walkSymbolScopes(symbol, limit, scopeWalkFn)) - return failure(); + } // Update the attribute dictionaries as necessary.
for (auto &it : updatedAttrDicts) @@ -769,11 +831,9 @@ /// Attempt to replace all uses of the given symbol 'oldSymbol' with the /// provided symbol 'newSymbol' that are nested within the given operation -/// 'from'. This does not traverse into any nested symbol tables, and will -/// also only replace uses on 'from' if it does not also define a symbol -/// table. This is because we treat the region as the boundary of the symbol -/// table, and not the op itself. If there are any unknown operations that may -/// potentially be symbol tables, no uses are replaced and failure is returned. +/// 'from'. This does not traverse into any nested symbol tables. If there are +/// any unknown operations that may potentially be symbol tables, no uses are +/// replaced and failure is returned. LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Operation *from) { @@ -784,3 +844,13 @@ Operation *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } +LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, + StringRef newSymbol, + Region *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} +LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, + StringRef newSymbol, + Region *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -16,7 +16,7 @@ /// This is a symbol test pass that tests the symbol uselist functionality /// provided by the symbol table along with erasing from the symbol table. struct SymbolUsesPass : public ModulePass { - WalkResult operateOnSymbol(Operation *symbol, Operation *module, + WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, SmallVectorImpl &deadFunctions) { // Test computing uses on a non symboltable op.
Optional symbolUses = @@ -34,7 +34,7 @@ << " nested references"; // Test the functionality of symbolKnownUseEmpty. - if (SymbolTable::symbolKnownUseEmpty(symbol, module)) { + if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) { FuncOp funcSymbol = dyn_cast(symbol); if (funcSymbol && funcSymbol.isExternal()) deadFunctions.push_back(funcSymbol); @@ -44,7 +44,7 @@ } // Test the functionality of getSymbolUses. - symbolUses = SymbolTable::getSymbolUses(symbol, module); + symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion()); assert(symbolUses.hasValue() && "expected no unknown operations"); for (SymbolTable::SymbolUse symbolUse : *symbolUses) { // Check that we can resolve back to our symbol. @@ -70,10 +70,10 @@ return WalkResult::advance(); }); + SymbolTable table(module); for (Operation *op : deadFunctions) { // In order to test the SymbolTable::erase method, also erase completely // useless functions. - SymbolTable table(module); auto name = SymbolTable::getSymbolName(op); assert(table.lookup(name) && "expected no unknown operations"); table.erase(op); @@ -96,7 +96,7 @@ if (!newName) return; if (succeeded(SymbolTable::replaceAllSymbolUses( - nestedOp, newName.getValue(), module))) + nestedOp, newName.getValue(), &module.getBodyRegion()))) SymbolTable::setSymbolName(nestedOp, newName.getValue()); }); }