diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -93,6 +93,13 @@
   /// with the 'OpTrait::SymbolTable' trait.
   static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
+  /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
+  /// by a given SymbolRefAttr, appending them to `symbols` in order from the
+  /// root reference to the leaf reference. Returns failure if any of the
+  /// nested references could not be resolved; in that case only the symbols
+  /// resolved before the failing reference are appended.
+  static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
+                                      SmallVectorImpl<Operation *> &symbols);
 
   /// Returns the operation registered with the given symbol name within the
   /// closest parent operation of, or including, 'from' with the
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -126,6 +126,10 @@
 /// Creates a pass which inlines calls and callable operations as defined by the
 /// CallGraph.
 std::unique_ptr<Pass> createInlinerPass();
+
+/// Creates a pass which deletes symbol operations that are unreachable. This
+/// pass may *only* be scheduled on an operation that defines a SymbolTable.
+std::unique_ptr<Pass> createSymbolDCEPass();
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_PASSES_H
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -230,30 +230,47 @@
 }
 Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
                                        SymbolRefAttr symbol) {
+  SmallVector<Operation *, 4> resolvedSymbols;
+  if (failed(lookupSymbolIn(symbolTableOp, symbol, resolvedSymbols)))
+    return nullptr;
+  return resolvedSymbols.back();
+}
+
+LogicalResult
+SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
+                            SmallVectorImpl<Operation *> &symbols) {
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Lookup the root reference for this symbol.
   symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
   if (!symbolTableOp)
-    return nullptr;
+    return failure();
+  symbols.push_back(symbolTableOp);
 
   // If there are no nested references, just return the root symbol directly.
   ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
   if (nestedRefs.empty())
-    return symbolTableOp;
+    return success();
 
   // Verify that the root is also a symbol table.
   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
-    return nullptr;
+    return failure();
 
   // Otherwise, lookup each of the nested non-leaf references and ensure that
   // each corresponds to a valid symbol table.
   for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
     symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
     if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
-      return nullptr;
+      return failure();
+    symbols.push_back(symbolTableOp);
   }
-  return lookupSymbolIn(symbolTableOp, symbol.getLeafReference());
+
+  // Lookup the leaf reference, which is not required to be a symbol table.
+  Operation *leafOp = lookupSymbolIn(symbolTableOp, symbol.getLeafReference());
+  if (!leafOp)
+    return failure();
+  symbols.push_back(leafOp);
+  return success();
 }
 
 /// Returns the operation registered with the given symbol name within the
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@
   PipelineDataTransfer.cpp
   SimplifyAffineStructures.cpp
   StripDebugInfo.cpp
+  SymbolDCE.cpp
   Vectorize.cpp
   ViewOpGraph.cpp
   ViewRegionGraph.cpp
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -0,0 +1,174 @@
+//===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an algorithm for eliminating symbol operations that are
+// known to be dead.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+struct SymbolDCE : public OperationPass<SymbolDCE> {
+  void runOnOperation() override;
+
+  /// Compute the liveness of the symbols within the given symbol table.
+  /// `symbolTableIsHidden` is true if this symbol table is known to be
+  /// inaccessible from operations in its parent regions. Every symbol table
+  /// whose liveness gets computed is recorded in `processedSymbolTables`.
+  LogicalResult computeLiveness(Operation *symbolTableOp,
+                                bool symbolTableIsHidden,
+                                DenseSet<Operation *> &liveSymbols,
+                                DenseSet<Operation *> &processedSymbolTables);
+};
+} // end anonymous namespace
+
+void SymbolDCE::runOnOperation() {
+  Operation *symbolTableOp = getOperation();
+
+  // SymbolDCE should only be run on operations that define a symbol table.
+  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
+    symbolTableOp->emitOpError()
+        << "was scheduled to run under SymbolDCE, but does not define a "
+           "symbol table";
+    return signalPassFailure();
+  }
+
+  // A flag that signals if the top level symbol table is hidden, i.e. not
+  // accessible from parent scopes.
+  bool symbolTableIsHidden = true;
+  if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
+    symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
+                          SymbolTable::Visibility::Private;
+  }
+
+  // Compute the set of live symbols within the symbol table, along with the
+  // set of symbol tables whose liveness was actually computed.
+  DenseSet<Operation *> liveSymbols;
+  DenseSet<Operation *> processedSymbolTables;
+  if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols,
+                             processedSymbolTables)))
+    return signalPassFailure();
+
+  // After computing the liveness, delete all of the symbols that were found to
+  // be dead. Symbol tables whose liveness was never computed (e.g. one nested
+  // within the body of a function) are skipped, as their symbols were never
+  // analyzed and may still be live. If such a table is nested under a dead
+  // symbol, it is still erased together with that symbol.
+  symbolTableOp->walk([&](Operation *nestedSymbolTable) {
+    if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>() ||
+        !processedSymbolTables.count(nestedSymbolTable))
+      return;
+    for (auto &block : nestedSymbolTable->getRegion(0)) {
+      for (Operation &op :
+           llvm::make_early_inc_range(block.without_terminator())) {
+        if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
+          op.erase();
+      }
+    }
+  });
+}
+
+/// Compute the liveness of the symbols within the given symbol table.
+/// `symbolTableIsHidden` is true if this symbol table is known to be
+/// inaccessible from operations in its parent regions. Every symbol table
+/// whose liveness gets computed is recorded in `processedSymbolTables`.
+LogicalResult
+SymbolDCE::computeLiveness(Operation *symbolTableOp, bool symbolTableIsHidden,
+                           DenseSet<Operation *> &liveSymbols,
+                           DenseSet<Operation *> &processedSymbolTables) {
+  processedSymbolTables.insert(symbolTableOp);
+
+  // A worklist of live operations to propagate uses from.
+  SmallVector<Operation *, 16> worklist;
+
+  // Walk the symbols within the current symbol table, marking the symbols that
+  // are known to be live.
+  for (auto &block : symbolTableOp->getRegion(0)) {
+    for (Operation &op : block.without_terminator()) {
+      // Always add non symbol operations to the worklist.
+      if (!SymbolTable::isSymbol(&op)) {
+        worklist.push_back(&op);
+        continue;
+      }
+
+      // Check the visibility to see if this symbol may be referenced
+      // externally.
+      SymbolTable::Visibility visibility =
+          SymbolTable::getSymbolVisibility(&op);
+
+      // Private symbols are always initially considered dead.
+      if (visibility == SymbolTable::Visibility::Private)
+        continue;
+      // We only include nested visibility here if the symbol table isn't
+      // hidden.
+      if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
+        continue;
+
+      // TODO(riverriddle) Add hooks here to allow symbols to provide additional
+      // information, e.g. linkage can be used to drop some symbols that may
+      // otherwise be considered "live".
+      if (liveSymbols.insert(&op).second)
+        worklist.push_back(&op);
+    }
+  }
+
+  // Process the set of symbols that were known to be live, adding new symbols
+  // that are referenced within.
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+
+    // If this is a symbol table, recursively compute its liveness.
+    if (op->hasTrait<OpTrait::SymbolTable>()) {
+      // The internal symbol table is hidden if the parent is, if it's not a
+      // symbol, or if it is a private symbol.
+      bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
+                            SymbolTable::getSymbolVisibility(op) ==
+                                SymbolTable::Visibility::Private;
+      if (failed(computeLiveness(op, symbolIsHidden, liveSymbols,
+                                 processedSymbolTables)))
+        return failure();
+    }
+
+    // Collect the uses held by this operation.
+    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
+    if (!uses) {
+      return op->emitError()
+             << "operation contains potentially unknown symbol table, "
+                "meaning that we can't reliably compute symbol uses";
+    }
+
+    SmallVector<Operation *, 4> resolvedSymbols;
+    for (const SymbolTable::SymbolUse &use : *uses) {
+      // Lookup the symbols referenced by this use.
+      resolvedSymbols.clear();
+      if (failed(SymbolTable::lookupSymbolIn(
+              op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
+        return use.getUser()->emitError()
+               << "unable to resolve reference to symbol "
+               << use.getSymbolRef();
+      }
+
+      // Mark each of the resolved symbols as live.
+      for (Operation *resolvedSymbol : resolvedSymbols)
+        if (liveSymbols.insert(resolvedSymbol).second)
+          worklist.push_back(resolvedSymbol);
+    }
+  }
+
+  return success();
+}
+
+std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
+  return std::make_unique<SymbolDCE>();
+}
+
+static PassRegistration<SymbolDCE> pass("symbol-dce", "Eliminate dead symbols");
diff --git a/mlir/test/IR/test-symbol-dce.mlir b/mlir/test/IR/test-symbol-dce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/test-symbol-dce.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -symbol-dce -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="module(symbol-dce)" -split-input-file | FileCheck %s --check-prefix=NESTED
+
+// Check that symbols within a symbol table nested in the body of a non symbol
+// table operation, whose liveness is never computed, are left untouched.
+
+// CHECK-LABEL: module attributes {test.nested_in_function}
+module attributes {test.nested_in_function} {
+  // CHECK: func @public_function
+  func @public_function() {
+    // CHECK: func @nested_public_function
+    module {
+      func @nested_public_function()
+    }
+    "foo.return"() : () -> ()
+  }
+}
+
+// -----
+
+// Check that trivially dead and trivially live non-nested cases are handled.
+
+// CHECK-LABEL: module attributes {test.simple}
+module attributes {test.simple} {
+  // CHECK-NOT: func @dead_private_function
+  func @dead_private_function() attributes { sym_visibility = "private" }
+
+  // CHECK-NOT: func @dead_nested_function
+  func @dead_nested_function() attributes { sym_visibility = "nested" }
+
+  // CHECK: func @live_private_function
+  func @live_private_function() attributes { sym_visibility = "private" }
+
+  // CHECK: func @live_nested_function
+  func @live_nested_function() attributes { sym_visibility = "nested" }
+
+  // CHECK: func @public_function
+  func @public_function() {
+    "foo.return"() {uses = [@live_private_function, @live_nested_function]} : () -> ()
+  }
+
+  // CHECK: func @public_function_explicit
+  func @public_function_explicit() attributes { sym_visibility = "public" }
+}
+
+// -----
+
+// Check that we don't DCE nested symbols if they are used.
+// CHECK-LABEL: module attributes {test.nested}
+module attributes {test.nested} {
+  // CHECK: module @public_module
+  module @public_module {
+    // CHECK-NOT: func @dead_nested_function
+    func @dead_nested_function() attributes { sym_visibility = "nested" }
+
+    // CHECK: func @private_function
+    func @private_function() attributes { sym_visibility = "private" }
+
+    // CHECK: func @nested_function
+    func @nested_function() attributes { sym_visibility = "nested" } {
+      "foo.return"() {uses = [@private_function]} : () -> ()
+    }
+  }
+
+  "live.user"() {uses = [@public_module::@nested_function]} : () -> ()
+}
+
+// -----
+
+// Check that we don't DCE symbols if we can't prove that the top-level symbol
+// table that we are running on is hidden from above.
+// NESTED-LABEL: module attributes {test.no_dce_non_hidden_parent}
+module attributes {test.no_dce_non_hidden_parent} {
+  // NESTED: module @public_module
+  module @public_module {
+    // NESTED: func @nested_function
+    func @nested_function() attributes { sym_visibility = "nested" }
+  }
+  // NESTED: module @nested_module
+  module @nested_module attributes { sym_visibility = "nested" } {
+    // NESTED: func @nested_function
+    func @nested_function() attributes { sym_visibility = "nested" }
+  }
+
+  // Only private modules can be assumed to be hidden.
+  // NESTED: module @private_module
+  module @private_module attributes { sym_visibility = "private" } {
+    // NESTED-NOT: func @nested_function
+    func @nested_function() attributes { sym_visibility = "nested" }
+  }
+
+  "live.user"() {uses = [@nested_module, @private_module]} : () -> ()
+}
+
+// -----
+
+module {
+  func @private_symbol() attributes { sym_visibility = "private" }
+
+  // expected-error@+1 {{contains potentially unknown symbol table}}
+  "foo.possibly_unknown_symbol_table"() ({
+  }) : () -> ()
+}
+
+// -----
+
+module {
+  // expected-error@+1 {{unable to resolve reference to symbol}}
+  "live.user"() {uses = [@unknown_symbol]} : () -> ()
+}