diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -89,6 +89,11 @@ /// with the 'OpTrait::SymbolTable' trait. static Operation *lookupSymbolIn(Operation *op, StringRef symbol); static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); + /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced + /// by a given SymbolRefAttr. Returns failure if any of the nested references + /// could not be resolved; `symbols` is then left in an unspecified state. + static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, + SmallVectorImpl &symbols); /// Returns the operation registered with the given symbol name within the /// closest parent operation of, or including, 'from' with the @@ -134,46 +139,57 @@ /// Get an iterator range for all of the uses, for any symbol, that are nested /// within the given operation 'from'. This does not traverse into any nested /// symbol tables, and will also only return uses on 'from' if it does not - /// also define a symbol table. This is because we treat the region as the - /// boundary of the symbol table, and not the op itself. This function returns - /// None if there are any unknown operations that may potentially be symbol - /// tables. - static Optional getSymbolUses(Operation *from); + /// also define a symbol table. If `recurseIfSymTab` is false and 'from' is a + /// symbol table, this will only return uses on 'from' and not any nested + /// within. This is because we treat the region as the boundary of the symbol + /// table, and not the op itself. This function returns None if there are any + /// unknown operations that may potentially be symbol tables. + static Optional getSymbolUses(Operation *from, + bool recurseIfSymTab = true); /// Get all of the uses of the given symbol that are nested within the given /// operation 'from'.
This does not traverse into any nested symbol tables, /// and will also only return uses on 'from' if it does not also define a - /// symbol table. This is because we treat the region as the boundary of the - /// symbol table, and not the op itself. This function returns None if there - /// are any unknown operations that may potentially be symbol tables. - static Optional getSymbolUses(StringRef symbol, Operation *from); - static Optional getSymbolUses(Operation *symbol, Operation *from); - - /// Return if the given symbol is known to have no uses that are nested - /// within the given operation 'from'. This does not traverse into any nested - /// symbol tables, and will also only count uses on 'from' if it does not also - /// define a symbol table. This is because we treat the region as the boundary - /// of the symbol table, and not the op itself. This function will also return - /// false if there are any unknown operations that may potentially be symbol - /// tables. This doesn't necessarily mean that there are no uses, we just - /// can't conservatively prove it. - static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); - static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); + /// symbol table. If `recurseIfSymTab` is false and 'from' is a symbol table, + /// this will only return uses on 'from' and not any nested within. This is + /// because we treat the region as the boundary of the symbol table, and not + /// the op itself. This function returns None if there are any unknown + /// operations that may potentially be symbol tables. + static Optional getSymbolUses(StringRef symbol, Operation *from, + bool recurseIfSymTab = true); + static Optional getSymbolUses(Operation *symbol, Operation *from, + bool recurseIfSymTab = true); + + /// Return if the given symbol is known to have no uses that are nested within + /// the given operation 'from'.
This does not traverse into any nested symbol + /// tables, and will also only count uses on 'from' if it does not also define + /// a symbol table. If `recurseIfSymTab` is false and 'from' is a symbol + /// table, this will only count uses on 'from' and not any nested within. + /// This is because we treat the region as the boundary of the symbol table, + /// and not the op itself. This function will also return false if there are + /// any unknown operations that may potentially be symbol tables. This doesn't + /// necessarily mean that there are no uses, we just can't conservatively + /// prove it. + static bool symbolKnownUseEmpty(StringRef symbol, Operation *from, + bool recurseIfSymTab = true); + static bool symbolKnownUseEmpty(Operation *symbol, Operation *from, + bool recurseIfSymTab = true); /// Attempt to replace all uses of the given symbol 'oldSymbol' with the /// provided symbol 'newSymbol' that are nested within the given operation /// 'from'. This does not traverse into any nested symbol tables, and will /// also only replace uses on 'from' if it does not also define a symbol - /// table. This is because we treat the region as the boundary of the symbol - /// table, and not the op itself. If there are any unknown operations that may - /// potentially be symbol tables, no uses are replaced and failure is - /// returned. - LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, - Operation *from); + /// table. If `recurseIfSymTab` is false and 'from' is a symbol table, this + /// will only replace uses on 'from', not any nested within. This is because + /// we treat the region as the boundary of the symbol table, and not the op + /// itself. If there are any unknown operations that may potentially be symbol + /// tables, no uses are replaced and failure is returned.
+ LLVM_NODISCARD static LogicalResult + replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, + Operation *from, bool recurseIfSymTab = true); LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, - Operation *from); + Operation *from, bool recurseIfSymTab = true); private: Operation *symbolTableOp; diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -126,6 +126,11 @@ /// Creates a pass which inlines calls and callable operations as defined by the /// CallGraph. std::unique_ptr createInlinerPass(); + +/// Creates a pass which deletes symbol operations that are unreachable. +/// This pass may *only* be scheduled on an operation that defines a +/// SymbolTable. +std::unique_ptr createSymbolDCEPass(); } // end namespace mlir #endif // MLIR_TRANSFORMS_PASSES_H diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -230,30 +230,42 @@ } Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol) { + SmallVector resolvedSymbols; + if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols))) + return nullptr; + return resolvedSymbols.back(); +} + +LogicalResult +SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol, + SmallVectorImpl &symbols) { assert(symbolTableOp->hasTrait()); // Lookup the root reference for this symbol. symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference()); if (!symbolTableOp) - return nullptr; + return failure(); + symbols.push_back(symbolTableOp); // If there are no nested references, just return the root symbol directly. ArrayRef nestedRefs = symbol.getNestedReferences(); if (nestedRefs.empty()) - return symbolTableOp; + return success(); // Verify that the root is also a symbol table.
if (!symbolTableOp->hasTrait()) - return nullptr; + return failure(); // Otherwise, lookup each of the nested non-leaf references and ensure that // each corresponds to a valid symbol table. for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue()); if (!symbolTableOp || !symbolTableOp->hasTrait()) - return nullptr; + return failure(); + symbols.push_back(symbolTableOp); } - return lookupSymbolIn(symbolTableOp, symbol.getLeafReference()); + symbols.push_back(lookupSymbolIn(symbolTableOp, symbol.getLeafReference())); + return success(symbols.back()); } /// Returns the operation registered with the given symbol name within the @@ -405,13 +417,26 @@ /// traverse into any nested symbol tables, and will also only return uses on /// 'from' if it does not also define a symbol table. static Optional walkSymbolUses( - Operation *from, + Operation *from, bool recurseIfSymTab, function_ref)> callback) { + // If this operation has regions, and it, as well as its dialect, isn't + // registered then conservatively fail. The operation may define a + // symbol table, so we can't opaquely know if we should traverse to find + // nested uses. + if (isPotentiallyUnknownSymbolTable(from)) + return llvm::None; + + // If from is not a symbol table, check for uses. A symbol table defines a new + // scope, so we can't walk the attributes from the symbol table op. - if (!from->hasTrait()) { + bool isSymTab = from->hasTrait(); + if (!isSymTab || !recurseIfSymTab) { if (walkSymbolRefs(from, callback).wasInterrupted()) return WalkResult::interrupt(); + + // If this is a symbol table and we were asked not to recurse, stop here + // rather than walking nested uses. + if (isSymTab) + return WalkResult::advance(); } SmallVector worklist; @@ -426,10 +451,7 @@ if (walkSymbolRefs(&op, callback).wasInterrupted()) return WalkResult::interrupt(); - // If this operation has regions, and it as well as its dialect aren't - // registered then conservatively fail.
The operation may define a - // symbol table, so we can't opaquely know if we should traverse to find - // nested uses. + // Check that this isn't a potentially unknown symbol table. if (isPotentiallyUnknownSymbolTable(&op)) return llvm::None; @@ -451,8 +473,9 @@ /// current scope as well as the top-level operation representing the top of /// that scope. static Optional walkSymbolScopes( - Operation *symbol, Operation *limit, - function_ref(SymbolRefAttr, Operation *)> callback) { + Operation *symbol, Operation *limit, bool recurseIfSymTab, + function_ref(SymbolRefAttr, Operation *, bool)> + callback) { StringRef symbolName = SymbolTable::getSymbolName(symbol); assert(!symbol->hasTrait() || symbol != limit); @@ -469,7 +492,7 @@ if (getNearestSymbolTable(limit) != symbol->getParentOp()) return WalkResult::advance(); return callback(SymbolRefAttr::get(symbolName, symbol->getContext()), - limit); + limit, recurseIfSymTab); } limitAncestors.insert(limitAncestor); @@ -494,28 +517,32 @@ // Walk each of the ancestors of 'symbol', calling the compute function for // each one. Operation *limitIt = symbol->getParentOp(); for (size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) { - Optional callbackResult = callback(references[i], limitIt); + auto callbackResult = callback(references[i], limitIt, recurseIfSymTab); if (callbackResult != WalkResult::advance()) return callbackResult; + + // After the first iteration, always recurse into symbol tables. + recurseIfSymTab = true; } + return WalkResult::advance(); + } else if (!collectedAllReferences) { + // Otherwise, we just need the symbol reference for 'symbol' that will be + // used within 'limit', i.e. the last reference computed above. Bail out + // if we weren't able to collect all of the references.
return WalkResult::advance(); } - - // Otherwise, we just need the symbol reference for 'symbol' that will be - // used within 'limit'. This is the last reference in the list we computed - // above if we were able to collect all references. - if (!collectedAllReferences) - return WalkResult::advance(); - return callback(references.back(), limit); + return callback(references.back(), limit, recurseIfSymTab); } /// Walk the symbol scopes defined by 'limit' invoking the provided callback. static Optional walkSymbolScopes( - StringRef symbol, Operation *limit, - function_ref(SymbolRefAttr, Operation *)> callback) { - return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit); + StringRef symbol, Operation *limit, bool recurseIfSymTab, + function_ref(SymbolRefAttr, Operation *, bool)> + callback) { + return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit, + recurseIfSymTab); } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -546,13 +573,14 @@ /// boundary of the symbol table, and not the op itself. This function returns /// None if there are any unknown operations that may potentially be symbol /// tables. -auto SymbolTable::getSymbolUses(Operation *from) -> Optional { +auto SymbolTable::getSymbolUses(Operation *from, bool recurseIfSymTab) + -> Optional { std::vector uses; auto walkFn = [&](SymbolUse symbolUse, ArrayRef) { uses.push_back(symbolUse); return WalkResult::advance(); }; - auto result = walkSymbolUses(from, walkFn); + auto result = walkSymbolUses(from, recurseIfSymTab, walkFn); return result ? Optional(std::move(uses)) : Optional(); } @@ -561,18 +589,20 @@ /// The implementation of SymbolTable::getSymbolUses below.
template -static Optional getSymbolUsesImpl(SymbolT symbol, - Operation *limit) { +static Optional +getSymbolUsesImpl(SymbolT symbol, Operation *limit, bool recurseIfSymTab) { std::vector uses; - auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from, + bool recurse) { return walkSymbolUses( - from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + from, recurse, + [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())) uses.push_back(symbolUse); return WalkResult::advance(); }); }; - if (walkSymbolScopes(symbol, limit, walkFn)) + if (walkSymbolScopes(symbol, limit, recurseIfSymTab, walkFn)) return SymbolTable::UseRange(std::move(uses)); return llvm::None; } @@ -584,13 +614,13 @@ /// the region as the boundary of the symbol table, and not the op itself. This /// function returns None if there are any unknown operations that may /// potentially be symbol tables. -auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) - -> Optional { - return getSymbolUsesImpl(symbol, from); +auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from, + bool recurseIfSymTab) -> Optional { + return getSymbolUsesImpl(symbol, from, recurseIfSymTab); } -auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) - -> Optional { - return getSymbolUsesImpl(symbol, from); +auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from, + bool recurseIfSymTab) -> Optional { + return getSymbolUsesImpl(symbol, from, recurseIfSymTab); } //===----------------------------------------------------------------------===// @@ -598,17 +628,21 @@ /// The implementation of SymbolTable::symbolKnownUseEmpty below.
template -static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) { +static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit, + bool recurseIfSymTab) { // Walk all of the symbol uses looking for a reference to 'symbol'. - auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { - return walkSymbolUses( - from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { - return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()) - ? WalkResult::interrupt() - : WalkResult::advance(); - }); + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from, + bool recurse) { + return walkSymbolUses(from, recurse, + [&](SymbolTable::SymbolUse symbolUse, ArrayRef) { + return isReferencePrefixOf(symbolRefAttr, + symbolUse.getSymbolRef()) + ? WalkResult::interrupt() + : WalkResult::advance(); + }); }; - return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance(); + return walkSymbolScopes(symbol, limit, recurseIfSymTab, walkFn) == + WalkResult::advance(); } /// Return if the given symbol is known to have no uses that are nested within @@ -617,11 +651,13 @@ /// a symbol table. This is because we treat the region as the boundary of the /// symbol table, and not the op itself. This function will also return false if /// there are any unknown operations that may potentially be symbol tables.
-bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { - return symbolKnownUseEmptyImpl(symbol, from); +bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from, + bool recurseIfSymTab) { + return symbolKnownUseEmptyImpl(symbol, from, recurseIfSymTab); } -bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { - return symbolKnownUseEmptyImpl(symbol, from); +bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from, + bool recurseIfSymTab) { + return symbolKnownUseEmptyImpl(symbol, from, recurseIfSymTab); } //===----------------------------------------------------------------------===// @@ -686,9 +722,9 @@ /// The implementation of SymbolTable::replaceAllSymbolUses below. template -static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, - StringRef newSymbol, - Operation *limit) { +static LogicalResult +replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, Operation *limit, + bool recurseIfSymTab) { // A collection of operations along with their new attribute dictionary. std::vector> updatedAttrDicts; @@ -710,8 +746,8 @@ // Generate a new attribute to replace the given attribute. MLIRContext *ctx = limit->getContext(); FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); - auto scopeWalkFn = [&](SymbolRefAttr oldAttr, - Operation *from) -> Optional { + auto scopeWalkFn = [&](SymbolRefAttr oldAttr, Operation *from, + bool recurse) -> Optional { SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, ArrayRef accessChain) { @@ -748,7 +784,7 @@ accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); return WalkResult::advance(); }; - if (!walkSymbolUses(from, walkFn)) + if (!walkSymbolUses(from, recurse, walkFn)) return llvm::None; // Check to see if we have a dangling op that needs to be processed.
@@ -758,7 +794,7 @@ } return WalkResult::advance(); }; - if (!walkSymbolScopes(symbol, limit, scopeWalkFn)) + if (!walkSymbolScopes(symbol, limit, recurseIfSymTab, scopeWalkFn)) return failure(); // Update the attribute dictionaries as necessary. @@ -776,11 +812,13 @@ /// potentially be symbol tables, no uses are replaced and failure is returned. LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, - Operation *from) { - return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); + Operation *from, + bool recurseIfSymTab) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from, recurseIfSymTab); } LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbol, - Operation *from) { - return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); + Operation *from, + bool recurseIfSymTab) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from, recurseIfSymTab); } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ PipelineDataTransfer.cpp SimplifyAffineStructures.cpp StripDebugInfo.cpp + SymbolDCE.cpp Vectorize.cpp ViewOpGraph.cpp ViewRegionGraph.cpp diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -0,0 +1,169 @@ +//===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an algorithm for eliminating symbol operations that are +// known to be dead.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +struct SymbolDCE : public OperationPass { + void runOnOperation() override; + + /// Compute the liveness of the symbols within the given symbol table. + /// `symbolTableIsHidden` is true if this symbol table is known to be + /// inaccessible from operations in its parent regions. + LogicalResult computeLiveness(Operation *symbolTableOp, + bool symbolTableIsHidden, + DenseSet &liveSymbols); +}; +} // end anonymous namespace + +void SymbolDCE::runOnOperation() { + Operation *symbolTableOp = getOperation(); + + // SymbolDCE should only be run on operations that define a symbol table. + if (!symbolTableOp->hasTrait()) { + symbolTableOp->emitOpError() + << "was scheduled to run under SymbolDCE, but does not define a " + "symbol table"; + return signalPassFailure(); + } + + // A flag that signals if the top level symbol table is hidden, i.e. not + // accessible from parent scopes. + bool symbolTableIsHidden = true; + if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) { + symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) == + SymbolTable::Visibility::Private; + } + + // Compute the set of live symbols within the symbol table. + DenseSet liveSymbols; + if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols))) + return signalPassFailure(); + + // After computing the liveness, delete all of the symbols that were found to + // be dead.
+ symbolTableOp->walk([&](Operation *nestedSymbolTable) { + if (!nestedSymbolTable->hasTrait()) + return; + // Liveness is only computed for symbol tables reached from the root + // through a chain of symbol tables; conservatively leave alone any table + // nested under another kind of operation, as it was never analyzed. + for (Operation *it = nestedSymbolTable; it != symbolTableOp; + it = it->getParentOp()) + if (!it->getParentOp()->hasTrait<OpTrait::SymbolTable>()) + return; + for (auto &block : nestedSymbolTable->getRegion(0)) { + for (Operation &op : + llvm::make_early_inc_range(block.without_terminator())) { + if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op)) + op.erase(); + } + } + }); +} + +/// Compute the liveness of the symbols within the given symbol table. +/// `symbolTableIsHidden` is true if this symbol table is known to be +/// inaccessible from operations in its parent regions. +LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, + bool symbolTableIsHidden, + DenseSet &liveSymbols) { + // A worklist of live operations to propagate uses from. + SmallVector worklist; + + // Walk the symbols within the current symbol table, marking the symbols that + // are known to be live. + for (auto &block : symbolTableOp->getRegion(0)) { + for (Operation &op : block.without_terminator()) { + // Always add non symbol operations to the worklist. + if (!SymbolTable::isSymbol(&op)) { + worklist.push_back(&op); + continue; + } + + // Check the visibility to see if this symbol may be referenced + // externally. + SymbolTable::Visibility visibility = + SymbolTable::getSymbolVisibility(&op); + + // Private symbols are always initially considered dead. + if (visibility == SymbolTable::Visibility::Private) + continue; + // We only include nested visibility here if the symbol table isn't + // hidden. + if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested) + continue; + + // TODO(riverriddle) Add hooks here to allow symbols to provide additional + // information, e.g. linkage can be used to drop some symbols that may + // otherwise be considered "live". + if (liveSymbols.insert(&op).second) + worklist.push_back(&op); + } + } + + // Process the set of symbols that were known to be live, adding new symbols + // that are referenced within.
+ while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + + // Make sure not to recurse when computing uses if this is a symbol table, + // we only want uses within the current 'symbolTableOp'. + Optional uses = + SymbolTable::getSymbolUses(op, /*recurseIfSymTab=*/false); + if (!uses) { + return op->emitError() + << "operation contains potentially unknown symbol table, " + "meaning that we can't reliably compute symbol uses"; + } + + // If this is a symbol table, recursively compute its liveness. + if (op->hasTrait()) { + // The internal symbol table is hidden if the parent is, if it's not a + // symbol, or if it is a private symbol. + bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) || + SymbolTable::getSymbolVisibility(op) == + SymbolTable::Visibility::Private; + if (failed(computeLiveness(op, symbolIsHidden, liveSymbols))) + return failure(); + } + + SmallVector resolvedSymbols; + for (const SymbolTable::SymbolUse &use : *uses) { + // Lookup the symbols referenced by this use. + resolvedSymbols.clear(); + if (failed(SymbolTable::lookupSymbolIn( + op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) { + return use.getUser()->emitError() + << "unable to resolve reference to symbol " + << use.getSymbolRef(); + } + + // Mark each of the resolved symbols as live.
+ for (Operation *resolvedSymbol : resolvedSymbols) + if (liveSymbols.insert(resolvedSymbol).second) + worklist.push_back(resolvedSymbol); + } + } + + return success(); +} + +std::unique_ptr mlir::createSymbolDCEPass() { + return std::make_unique(); +} + +static PassRegistration pass("symbol-dce", "Eliminate dead symbols"); diff --git a/mlir/test/IR/test-symbol-dce.mlir b/mlir/test/IR/test-symbol-dce.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/test-symbol-dce.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s -symbol-dce -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline="module(symbol-dce)" -split-input-file | FileCheck %s --check-prefix=NESTED + +// Check that trivially dead and trivially live non-nested cases are handled. + +// CHECK-LABEL: module attributes {test.simple} +module attributes {test.simple} { + // CHECK-NOT: func @dead_private_function + func @dead_private_function() attributes { sym_visibility = "private" } + + // CHECK-NOT: func @dead_nested_function + func @dead_nested_function() attributes { sym_visibility = "nested" } + + // CHECK: func @live_private_function + func @live_private_function() attributes { sym_visibility = "private" } + + // CHECK: func @live_nested_function + func @live_nested_function() attributes { sym_visibility = "nested" } + + // CHECK: func @public_function + func @public_function() { + "foo.return"() {uses = [@live_private_function, @live_nested_function]} : () -> () + } + + // CHECK: func @public_function_explicit + func @public_function_explicit() attributes { sym_visibility = "public" } +} + +// ----- + +// Check that we don't DCE nested symbols if they are used. 
+// CHECK-LABEL: module attributes {test.nested} +module attributes {test.nested} { + // CHECK: module @public_module + module @public_module { + // CHECK-NOT: func @dead_nested_function + func @dead_nested_function() attributes { sym_visibility = "nested" } + + // CHECK: func @private_function + func @private_function() attributes { sym_visibility = "private" } + + // CHECK: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } { + "foo.return"() {uses = [@private_function]} : () -> () + } + } + + "live.user"() {uses = [@public_module::@nested_function]} : () -> () +} + +// ----- + +// Check that we don't DCE symbols if we can't prove that the top-level symbol +// table that we are running on is hidden from above. +// NESTED-LABEL: module attributes {test.no_dce_non_hidden_parent} +module attributes {test.no_dce_non_hidden_parent} { + // NESTED: module @public_module + module @public_module { + // NESTED: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + // NESTED: module @nested_module + module @nested_module attributes { sym_visibility = "nested" } { + // NESTED: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + + // Only private modules can be assumed to be hidden. 
+ // NESTED: module @private_module + module @private_module attributes { sym_visibility = "private" } { + // NESTED-NOT: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + + "live.user"() {uses = [@nested_module, @private_module]} : () -> () +} + +// ----- + +// Check that we don't DCE symbols within symbol tables nested under +// non-symbol-table operations, as their liveness isn't computed. +// CHECK-LABEL: module attributes {test.nested_under_region} +module attributes {test.nested_under_region} { + func @public_function() { + // CHECK: module @inner_module + module @inner_module { + // CHECK: func @public_inner_function + func @public_inner_function() + } + return + } +} + +// ----- + +module { + func @private_symbol() attributes { sym_visibility = "private" } + + // expected-error@+1 {{contains potentially unknown symbol table}} + "foo.possibly_unknown_symbol_table"() ({ + }) : () -> () +} + +// ----- + +module { + // expected-error@+1 {{unable to resolve reference to symbol}} + "live.user"() {uses = [@unknown_symbol]} : () -> () +}
+// CHECK-LABEL: module attributes {test.nested} +module attributes {test.nested} { + // CHECK: module @public_module + module @public_module { + // CHECK-NOT: func @dead_nested_function + func @dead_nested_function() attributes { sym_visibility = "nested" } + + // CHECK: func @private_function + func @private_function() attributes { sym_visibility = "private" } + + // CHECK: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } { + "foo.return"() {uses = [@private_function]} : () -> () + } + } + + "live.user"() {uses = [@public_module::@nested_function]} : () -> () +} + +// ----- + +// Check that we don't DCE symbols if we can't prove that the top-level symbol +// table that we are running on is hidden from above. +// NESTED-LABEL: module attributes {test.no_dce_non_hidden_parent} +module attributes {test.no_dce_non_hidden_parent} { + // NESTED: module @public_module + module @public_module { + // NESTED: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + // NESTED: module @nested_module + module @nested_module attributes { sym_visibility = "nested" } { + // NESTED: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + + // Only private modules can be assumed to be hidden. + // NESTED: module @private_module + module @private_module attributes { sym_visibility = "private" } { + // NESTED-NOT: func @nested_function + func @nested_function() attributes { sym_visibility = "nested" } + } + + "live.user"() {uses = [@nested_module, @private_module]} : () -> () +} + +// ----- + +module { + func @private_symbol() attributes { sym_visibility = "private" } + + // expected-error@+1 {{contains potentially unknown symbol table}} + "foo.possibly_unknown_symbol_table"() ({ + }) : () -> () +} + +// ----- + +module { + // expected-error@+1 {{unable to resolve reference to symbol}} + "live.user"() {uses = [@unknown_symbol]} : () -> () +}