diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h @@ -22,34 +22,17 @@ #ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H #define MLIR_ANALYSIS_DATAFLOWANALYSIS_H +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/Support/Allocator.h" -namespace mlir { -//===----------------------------------------------------------------------===// -// ChangeResult -//===----------------------------------------------------------------------===// - -/// A result type used to indicate if a change happened. Boolean operations on -/// ChangeResult behave as though `Change` is truthy. -enum class ChangeResult { -  NoChange, -  Change, -}; -inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) { -  return lhs == ChangeResult::Change ? lhs : rhs; -} -inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) { -  lhs = lhs | rhs; -  return lhs; -} -inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { -  return lhs == ChangeResult::NoChange ? lhs : rhs; -} +/// TODO: Remove this file once SCCP and integer range analysis have been ported +/// to the new data-flow framework in DataFlowFramework.h.
+namespace mlir { //===----------------------------------------------------------------------===// // AbstractLatticeElement //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -16,7 +16,6 @@ #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H -#include "mlir/Analysis/DataFlowAnalysis.h" #include "mlir/IR/Operation.h" #include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/SetVector.h" @@ -25,6 +24,27 @@ namespace mlir { +//===----------------------------------------------------------------------===// +// ChangeResult +//===----------------------------------------------------------------------===// + +/// A result type used to indicate if a change happened. Boolean operations on +/// ChangeResult behave as though `Change` is truthy. +enum class ChangeResult { +  NoChange, +  Change, +}; +inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) { +  return lhs == ChangeResult::Change ? lhs : rhs; +} +inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) { +  lhs = lhs | rhs; +  return lhs; +} +inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { +  return lhs == ChangeResult::NoChange ? lhs : rhs; +} + /// Forward declare the analysis state class. class AnalysisState; diff --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h @@ -0,0 +1,429 @@ +//===- SparseDataFlowAnalysis.h - Sparse data-flow analysis ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements sparse data-flow analysis using the data-flow analysis +// framework. The analysis is forward and conditional and uses the results of +// dead code analysis to prune dead code during the analysis. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H +#define MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/SmallPtrSet.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// AbstractLattice +//===----------------------------------------------------------------------===// + +/// This class represents an abstract lattice. A lattice contains information +/// about an SSA value and is what's propagated across the IR by sparse +/// data-flow analysis. +class AbstractLattice : public AnalysisState { +public: +  /// Lattices can only be created for values. +  AbstractLattice(Value value) : AnalysisState(value) {} + +  /// Join the information contained in 'rhs' into this lattice. Returns +  /// whether the value of the lattice changed. +  virtual ChangeResult join(const AbstractLattice &rhs) = 0; + +  /// Returns true if the lattice element is at fixpoint and further calls to +  /// `join` will not update the value of the element. +  virtual bool isAtFixpoint() const = 0; + +  /// Mark the lattice element as having reached a pessimistic fixpoint. This +  /// means that the lattice may potentially have conflicting value states, and +  /// only the most conservative value should be relied on.
+  virtual ChangeResult markPessimisticFixpoint() = 0; + +  /// Mark the lattice element as having reached an optimistic fixpoint. This +  /// means that we optimistically assume the current value is the true state. +  virtual void markOptimisticFixpoint() = 0; + +  /// When the lattice gets updated, propagate an update to users of the value +  /// using its use-def chain to subscribed analyses. +  void onUpdate(DataFlowSolver *solver) const override; + +  /// Subscribe an analysis to updates of the lattice. When the lattice changes, +  /// subscribed analyses are re-invoked on all users of the value. This is +  /// more efficient than relying on the dependency map. +  void useDefSubscribe(DataFlowAnalysis *analysis) { +    useDefSubscribers.insert(analysis); +  } + +private: +  /// A set of analyses that should be updated when this lattice changes. +  SetVector, +            SmallPtrSet> +      useDefSubscribers; +}; + +//===----------------------------------------------------------------------===// +// Lattice +//===----------------------------------------------------------------------===// + +/// This class represents a lattice holding a specific value of type `ValueT`. +/// Lattice values (`ValueT`) are required to adhere to the following: +/// +/// * static ValueT join(const ValueT &lhs, const ValueT &rhs); +///   - This method conservatively joins the information held by `lhs` +///     and `rhs` into a new value. This method is required to be monotonic. +/// * bool operator==(const ValueT &rhs) const; +/// * void print(raw_ostream &os) const; +template +class Lattice : public AbstractLattice { +public: +  using AbstractLattice::AbstractLattice; + +  /// Get a lattice element with a known value. +  Lattice(const ValueT &knownValue = ValueT()) +      : AbstractLattice(Value()), knownValue(knownValue) {} + +  /// Return the value held by this lattice. This requires that the value is +  /// initialized.
+  ValueT &getValue() { +    assert(!isUninitialized() && "expected known lattice element"); +    return *optimisticValue; +  } +  const ValueT &getValue() const { +    return const_cast *>(this)->getValue(); +  } + +  /// Returns true if the value of this lattice hasn't yet been initialized. +  bool isUninitialized() const override { return !optimisticValue.hasValue(); } +  /// Force the initialization of the element by setting it to its pessimistic +  /// fixpoint. +  ChangeResult defaultInitialize() override { +    return markPessimisticFixpoint(); +  } + +  /// Returns true if the lattice has reached a fixpoint. A fixpoint is when +  /// the information optimistically assumed to be true is the same as the +  /// information known to be true. +  bool isAtFixpoint() const override { return optimisticValue == knownValue; } + +  /// Join the information contained in the 'rhs' lattice into this +  /// lattice. Returns whether the state of the current lattice changed. +  ChangeResult join(const AbstractLattice &rhs) override { +    const Lattice &rhsLattice = +        static_cast &>(rhs); + +    // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do. +    if (isAtFixpoint() || rhsLattice.isUninitialized()) +      return ChangeResult::NoChange; + +    // Join the rhs value into this lattice. +    return join(rhsLattice.getValue()); +  } + +  /// Join the information contained in the 'rhs' value into this +  /// lattice. Returns whether the state of the current lattice changed. +  ChangeResult join(const ValueT &rhs) { +    // If the current lattice is uninitialized, copy the rhs value. +    if (isUninitialized()) { +      optimisticValue = rhs; +      return ChangeResult::Change; +    } + +    // Otherwise, join rhs with the current optimistic value.
+    ValueT newValue = ValueT::join(*optimisticValue, rhs); +    assert(ValueT::join(newValue, *optimisticValue) == newValue && +           "expected `join` to be monotonic"); +    assert(ValueT::join(newValue, rhs) == newValue && +           "expected `join` to be monotonic"); + +    // Update the current optimistic value if something changed. +    if (newValue == optimisticValue) +      return ChangeResult::NoChange; + +    optimisticValue = newValue; +    return ChangeResult::Change; +  } + +  /// Mark the lattice element as having reached a pessimistic fixpoint. This +  /// means that the lattice may potentially have conflicting value states, +  /// and only the conservatively known value state should be relied on. +  ChangeResult markPessimisticFixpoint() override { +    if (isAtFixpoint()) +      return ChangeResult::NoChange; + +    // For this fixed point, we take whatever we knew to be true and set that +    // to our optimistic value. +    optimisticValue = knownValue; +    return ChangeResult::Change; +  } + +  /// Mark the lattice element as having reached an optimistic fixpoint. This +  /// means that we optimistically assume the current value is the true state. +  void markOptimisticFixpoint() override { +    assert(!isUninitialized() && "expected an initialized value"); +    knownValue = *optimisticValue; +  } + +  /// Print the lattice element. +  void print(raw_ostream &os) const override { +    os << "["; +    knownValue.print(os); +    os << ", "; +    if (optimisticValue) { +      optimisticValue->print(os); +    } else { +      os << ""; +    } +    os << "]"; +  } + +private: +  /// The value that is conservatively known to be true. +  ValueT knownValue; +  /// The currently computed value that is optimistically assumed to be true, +  /// or None if the lattice element is uninitialized.
+  Optional optimisticValue; +}; + +//===----------------------------------------------------------------------===// +// Executable +//===----------------------------------------------------------------------===// + +/// This is a simple analysis state that represents whether the associated +/// program point (either a block or a control-flow edge) is live. +class Executable : public AnalysisState { +public: +  using AnalysisState::AnalysisState; + +  /// The state is initialized by default. +  bool isUninitialized() const override { return false; } + +  /// The state is always initialized. +  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + +  /// Set the state of the program point to live. +  ChangeResult setToLive(); + +  /// Get whether the program point is live. +  bool isLive() const { return live; } + +  /// Print the liveness. +  void print(raw_ostream &os) const override; + +  /// When the state of the program point is changed to live, re-invoke +  /// subscribed analyses: on the block and its operations for a block point, +  /// or on the destination block for a `CFGEdge` point. +  void onUpdate(DataFlowSolver *solver) const override; + +  /// Subscribe an analysis to changes to the liveness. +  void blockContentSubscribe(DataFlowAnalysis *analysis) { +    subscribers.insert(analysis); +  } + +private: +  /// Whether the program point is live. Optimistically assume that the program +  /// point is dead. +  bool live = false; + +  /// A set of analyses that should be updated when this state changes. +  SetVector, +            SmallPtrSet> +      subscribers; +}; + +//===----------------------------------------------------------------------===// +// ConstantValue +//===----------------------------------------------------------------------===// + +/// This lattice value represents a known constant value of a lattice. +class ConstantValue { +public: +  /// Construct a constant value with a known constant.
+  ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr) +      : constant(knownValue), dialect(dialect) {} + +  /// Get the constant value. Returns null if no value was determined. +  Attribute getConstantValue() const { return constant; } + +  /// Get the dialect instance that can be used to materialize the constant. +  Dialect *getConstantDialect() const { return dialect; } + +  /// Compare the constant values; the dialect is not part of the comparison. +  bool operator==(const ConstantValue &rhs) const { +    return constant == rhs.constant; +  } + +  /// The union with another constant value is null if they are different, and +  /// the same if they are the same. +  static ConstantValue join(const ConstantValue &lhs, +                            const ConstantValue &rhs) { +    return lhs == rhs ? lhs : ConstantValue(); +  } + +  /// Print the constant value. +  void print(raw_ostream &os) const; + +private: +  /// The constant value. +  Attribute constant; +  /// A dialect instance that can be used to materialize the constant. +  Dialect *dialect; +}; + +//===----------------------------------------------------------------------===// +// PredecessorState +//===----------------------------------------------------------------------===// + +/// This analysis state represents a set of known predecessors. This state is +/// used in sparse data-flow analysis to reason about region control-flow and +/// callgraphs. The state may also indicate that not all predecessors can be +/// known, if for example not all callsites of a callable are visible. +class PredecessorState : public AnalysisState { +public: +  using AnalysisState::AnalysisState; + +  /// The state is initialized by default. +  bool isUninitialized() const override { return false; } + +  /// The state is always initialized. +  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + +  /// Print the known predecessors. +  void print(raw_ostream &os) const override; + +  /// Returns true if all predecessors are known.
+  bool allPredecessorsKnown() const { return allKnown; } + +  /// Indicate that there are potentially unknown predecessors. +  ChangeResult setHasUnknownPredecessors() { +    if (!allKnown) +      return ChangeResult::NoChange; +    allKnown = false; +    return ChangeResult::Change; +  } + +  /// Get the known predecessors. +  ArrayRef getKnownPredecessors() const { +    return predecessors.getArrayRef(); +  } + +  /// Add a known predecessor. Returns `Change` if it was not already known. +  ChangeResult join(Operation *predecessor) { +    return predecessors.insert(predecessor) ? ChangeResult::Change +                                            : ChangeResult::NoChange; +  } + +private: +  /// Whether all predecessors are known. Optimistically assume that we know +  /// all predecessors. +  bool allKnown = true; + +  /// The known control-flow predecessors of this program point. +  SetVector, +            SmallPtrSet> +      predecessors; +}; + +//===----------------------------------------------------------------------===// +// CFGEdge +//===----------------------------------------------------------------------===// + +/// This program point represents a control-flow edge between a block and one +/// of its successors. +class CFGEdge +    : public GenericProgramPointBase> { +public: +  using Base::Base; + +  /// Get the block from which the edge originates. +  Block *getFrom() const { return getValue().first; } +  /// Get the target block. +  Block *getTo() const { return getValue().second; } + +  /// Print the blocks between the control-flow edge. +  void print(raw_ostream &os) const override; +  /// Get a fused location of the regions containing both blocks. +  Location getLoc() const override; +}; + +//===----------------------------------------------------------------------===// +// DeadCodeAnalysis +//===----------------------------------------------------------------------===// + +/// Dead code analysis analyzes control-flow, as understood by +/// `RegionBranchOpInterface` and `BranchOpInterface`, and the callgraph, as +/// understood by `CallableOpInterface` and `CallOpInterface`.
+/// +/// This analysis uses known constant values of operands to determine the +/// liveness of each block and each edge between a block and its predecessors. +/// For region control-flow, this analysis determines the predecessor operations +/// for region entry blocks and region control-flow operations. For the +/// callgraph, this analysis determines the callsites and live returns of every +/// function. +class DeadCodeAnalysis : public DataFlowAnalysis { +public: +  explicit DeadCodeAnalysis(DataFlowSolver &solver); + +  /// Initialize the analysis by visiting every operation with potential +  /// control-flow semantics. +  LogicalResult initialize(Operation *top) override; + +  /// Visit an operation with control-flow semantics and deduce which of its +  /// successors are live. +  LogicalResult visit(ProgramPoint point) override; + +private: +  /// Find and mark symbol callables with potentially unknown callsites as +  /// having overdefined predecessors. `top` is the top-level operation that the +  /// analysis is operating on. +  void initializeSymbolCallables(Operation *top); + +  /// Recursively initialize the analysis on nested regions. +  LogicalResult initializeRecursively(Operation *op); + +  /// Visit the given call operation and compute any necessary lattice state. +  void visitCallOperation(CallOpInterface call); + +  /// Visit the given branch operation with successors and try to determine +  /// which are live from the current block. +  void visitBranchOperation(BranchOpInterface branch); + +  /// Visit the given region branch operation, which defines regions. Mark the +  /// entry blocks of the regions it may enter as live and record the operation +  /// as a predecessor of those regions, or of itself if it may bypass them. +  void visitRegionBranchOperation(RegionBranchOpInterface branch); + +  /// Visit the given terminator operation that exits a region under an +  /// operation with control-flow semantics. These are terminators with no CFG +  /// successors.
+  void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch); + +  /// Visit the given terminator operation that exits a callable region. These +  /// are terminators with no CFG successors. +  void visitCallableTerminator(Operation *op, CallableOpInterface callable); + +  /// Mark the edge between `from` and `to` as executable. +  void markEdgeLive(Block *from, Block *to); + +  /// Mark the entry blocks of the operation as executable. +  void markEntryBlocksLive(Operation *op); + +  /// Get the constant values of the operands of the operation. Returns none if +  /// any of the operand lattices are uninitialized. +  Optional> getOperandValues(Operation *op); + +  /// A symbol table used for O(1) symbol lookups during the analysis. +  SymbolTableCollection symbolTable; +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -7,6 +7,7 @@ IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp +  SparseDataFlowAnalysis.cpp AliasAnalysis/LocalAliasAnalysis.cpp ) @@ -21,6 +22,7 @@ IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp +  SparseDataFlowAnalysis.cpp AliasAnalysis/LocalAliasAnalysis.cpp diff --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp @@ -0,0 +1,420 @@ +//===- SparseDataFlowAnalysis.cpp - Sparse data-flow analysis -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" + +#define DEBUG_TYPE "dataflow" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AbstractLattice +//===----------------------------------------------------------------------===// + +void AbstractLattice::onUpdate(DataFlowSolver *solver) const { +  // Push all users of the value to the queue. +  for (Operation *user : point.get().getUsers()) +    for (DataFlowAnalysis *analysis : useDefSubscribers) +      solver->enqueue({user, analysis}); +} + +//===----------------------------------------------------------------------===// +// Executable +//===----------------------------------------------------------------------===// + +ChangeResult Executable::setToLive() { +  if (live) +    return ChangeResult::NoChange; +  live = true; +  return ChangeResult::Change; +} + +void Executable::print(raw_ostream &os) const { +  os << (live ? "live" : "dead"); +} + +void Executable::onUpdate(DataFlowSolver *solver) const { +  if (auto *block = point.dyn_cast()) { +    // Re-invoke the analyses on the block itself. +    for (DataFlowAnalysis *analysis : subscribers) +      solver->enqueue({block, analysis}); +    // Re-invoke the analyses on all operations in the block. +    for (DataFlowAnalysis *analysis : subscribers) +      for (Operation &op : *block) +        solver->enqueue({&op, analysis}); +  } else if (auto *programPoint = point.dyn_cast()) { +    // Re-invoke the analysis on the successor block.
+    if (auto *edge = dyn_cast(programPoint)) +      for (DataFlowAnalysis *analysis : subscribers) +        solver->enqueue({edge->getTo(), analysis}); +  } +} + +//===----------------------------------------------------------------------===// +// ConstantValue +//===----------------------------------------------------------------------===// + +void ConstantValue::print(raw_ostream &os) const { +  if (constant) +    return constant.print(os); +  os << ""; +} + +//===----------------------------------------------------------------------===// +// PredecessorState +//===----------------------------------------------------------------------===// + +void PredecessorState::print(raw_ostream &os) const { +  if (allPredecessorsKnown()) +    os << "(all) "; +  os << "predecessors:\n"; +  for (Operation *op : getKnownPredecessors()) +    os << "  " << *op << "\n"; +} + +//===----------------------------------------------------------------------===// +// CFGEdge +//===----------------------------------------------------------------------===// + +Location CFGEdge::getLoc() const { +  return FusedLoc::get( +      getFrom()->getParent()->getContext(), +      {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()}); +} + +void CFGEdge::print(raw_ostream &os) const { +  getFrom()->print(os); +  os << "\n -> \n"; +  getTo()->print(os); +} + +//===----------------------------------------------------------------------===// +// DeadCodeAnalysis +//===----------------------------------------------------------------------===// + +DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) +    : DataFlowAnalysis(solver) { +  registerPointKind(); +} + +LogicalResult DeadCodeAnalysis::initialize(Operation *top) { +  // Mark the top-level blocks as executable.
+  for (Region ®ion : top->getRegions()) { +    if (region.empty()) +      continue; +    auto *state = getOrCreate(®ion.front()); +    propagateIfChanged(state, state->setToLive()); +  } + +  // Mark as overdefined the predecessors of symbol callables with potentially +  // unknown predecessors. +  initializeSymbolCallables(top); + +  return initializeRecursively(top); +} + +void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { +  auto walkFn = [&](Operation *symTable, bool allUsesVisible) { +    Region &symbolTableRegion = symTable->getRegion(0); +    Block *symbolTableBlock = &symbolTableRegion.front(); + +    bool foundSymbolCallable = false; +    for (auto callable : symbolTableBlock->getOps()) { +      Region *callableRegion = callable.getCallableRegion(); +      if (!callableRegion) +        continue; +      auto symbol = dyn_cast(callable.getOperation()); +      if (!symbol) +        continue; + +      // Public symbol callables or those for which we can't see all uses have +      // potentially unknown callsites. +      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { +        auto *state = getOrCreate(callable); +        propagateIfChanged(state, state->setHasUnknownPredecessors()); +      } +      foundSymbolCallable = true; +    } + +    // Exit early if no eligible symbol callables were found in the table. +    if (!foundSymbolCallable) +      return; + +    // Walk the symbol table to check for non-call uses of symbols. +    Optional uses = +        SymbolTable::getSymbolUses(&symbolTableRegion); +    if (!uses) { +      // If we couldn't gather the symbol uses, conservatively assume that +      // we can't track information for any nested symbols. +      return top->walk([&](CallableOpInterface callable) { +        auto *state = getOrCreate(callable); +        propagateIfChanged(state, state->setHasUnknownPredecessors()); +      }); +    } + +    for (const SymbolTable::SymbolUse &use : *uses) { +      if (isa(use.getUser())) +        continue; +      // If a callable symbol has a non-call use, then we can't be guaranteed to +      // know all callsites. Resolve the use relative to the table it is in.
+      Operation *symbol = +          symbolTable.lookupSymbolIn(symTable, use.getSymbolRef()); +      if (!symbol) +        continue; +      auto *state = getOrCreate(symbol); +      propagateIfChanged(state, state->setHasUnknownPredecessors()); +    } +  }; +  SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), +                                walkFn); +} + +LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { +  // Initialize the analysis by visiting every op with control-flow semantics. +  if (op->getNumRegions() || op->getNumSuccessors() || +      op->hasTrait() || isa(op)) { +    // When the liveness of the parent block changes, make sure to re-invoke the +    // analysis on the op. +    if (op->getBlock()) +      getOrCreate(op->getBlock())->blockContentSubscribe(this); +    // Visit the op. +    if (failed(visit(op))) +      return failure(); +  } +  // Recurse on nested operations. +  for (Region ®ion : op->getRegions()) +    for (Operation &nestedOp : region.getOps()) +      if (failed(initializeRecursively(&nestedOp))) +        return failure(); +  return success(); +} + +void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { +  auto *state = getOrCreate(to); +  propagateIfChanged(state, state->setToLive()); +  auto *edgeState = getOrCreate(getProgramPoint(from, to)); +  propagateIfChanged(edgeState, edgeState->setToLive()); +} + +void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { +  for (Region ®ion : op->getRegions()) { +    if (region.empty()) +      continue; +    auto *state = getOrCreate(®ion.front()); +    propagateIfChanged(state, state->setToLive()); +  } +} + +LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { +  if (point.is()) +    return success(); +  auto *op = point.dyn_cast(); +  if (!op) +    return emitError(point.getLoc(), "unknown program point kind"); + +  // If the parent block is not executable, there is nothing to do. +  if (!getOrCreate(op->getBlock())->isLive()) +    return success(); + +  // We have a live call op. Add this as a live predecessor of the callee. +  if (auto call = dyn_cast(op)) +    visitCallOperation(call); + +  // Visit the regions.
+  if (op->getNumRegions()) { +    // Check if we can reason about the region control-flow. +    if (auto branch = dyn_cast(op)) { +      visitRegionBranchOperation(branch); + +      // Check if this is a callable operation. +    } else if (auto callable = dyn_cast(op)) { +      const auto *callsites = getOrCreateFor(op, callable); + +      // If the callsites could not be resolved or are known to be non-empty, +      // mark the callable as executable. +      if (!callsites->allPredecessorsKnown() || +          !callsites->getKnownPredecessors().empty()) +        markEntryBlocksLive(callable); + +      // Otherwise, conservatively mark all entry blocks as executable. +    } else { +      markEntryBlocksLive(op); +    } +  } + +  if (op->hasTrait() && !op->getNumSuccessors()) { +    if (auto branch = dyn_cast(op->getParentOp())) { +      // Visit the exiting terminator of a region. +      visitRegionTerminator(op, branch); +    } else if (auto callable = +                   dyn_cast(op->getParentOp())) { +      // Visit the exiting terminator of a callable. +      visitCallableTerminator(op, callable); +    } +  } +  // Visit the successors. +  if (op->getNumSuccessors()) { +    // Check if we can reason about the control-flow. +    if (auto branch = dyn_cast(op)) { +      visitBranchOperation(branch); + +      // Otherwise, conservatively mark all successors as executable. +    } else { +      for (Block *successor : op->getSuccessors()) +        markEdgeLive(op->getBlock(), successor); +    } +  } + +  return success(); +} + +void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { +  Operation *callableOp = nullptr; +  if (Value callableValue = call.getCallableForCallee().dyn_cast()) +    callableOp = callableValue.getDefiningOp(); +  else +    callableOp = call.resolveCallable(&symbolTable); + +  // A call to an externally-defined callable has unknown predecessors. +  const auto isExternalCallable = [](Operation *op) { +    if (auto callable = dyn_cast(op)) +      return !callable.getCallableRegion(); +    return false; +  }; + +  // TODO: Add support for non-symbol callables when necessary.
If the +  // callable has non-call uses we would mark as having reached pessimistic +  // fixpoint, otherwise allow for propagating the return values out. +  if (isa_and_nonnull(callableOp) && +      !isExternalCallable(callableOp)) { +    // Add the live callsite. +    auto *callsites = getOrCreate(callableOp); +    propagateIfChanged(callsites, callsites->join(call)); +  } else { +    // Mark this call op's predecessors as overdefined. +    auto *predecessors = getOrCreate(call); +    propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); +  } +} + +/// Get the constant values of the operands of an operation. If any of the +/// constant value lattices are uninitialized, return none to indicate the +/// analysis should bail out. +static Optional> getOperandValuesImpl( +    Operation *op, +    function_ref *(Value)> getLattice) { +  SmallVector operands; +  operands.reserve(op->getNumOperands()); +  for (Value operand : op->getOperands()) { +    const Lattice *cv = getLattice(operand); +    // If any of the operands' values are uninitialized, bail out. +    if (cv->isUninitialized()) +      return {}; +    operands.push_back(cv->getValue().getConstantValue()); +  } +  return operands; +} + +Optional> +DeadCodeAnalysis::getOperandValues(Operation *op) { +  return getOperandValuesImpl(op, [&](Value value) { +    auto *lattice = getOrCreate>(value); +    lattice->useDefSubscribe(this); +    return lattice; +  }); +} + +void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { +  // Try to deduce a single successor for the branch. +  Optional> operands = getOperandValues(branch); +  if (!operands) +    return; + +  if (Block *successor = branch.getSuccessorForOperands(*operands)) { +    markEdgeLive(branch->getBlock(), successor); +  } else { +    // Otherwise, mark all successors and the edges to them as executable.
+    for (Block *successor : branch->getSuccessors()) +      markEdgeLive(branch->getBlock(), successor); +  } +} + +void DeadCodeAnalysis::visitRegionBranchOperation( +    RegionBranchOpInterface branch) { +  // Try to deduce which regions are executable. +  Optional> operands = getOperandValues(branch); +  if (!operands) +    return; + +  SmallVector successors; +  branch.getSuccessorRegions(/*index=*/{}, *operands, successors); + +  for (const RegionSuccessor &successor : successors) { +    // The successor is either a region, whose entry block becomes executable, +    // or the parent op itself when control may bypass all of its regions. +    PredecessorState *predecessors; +    if (Region *region = successor.getSuccessor()) { +      auto *state = getOrCreate(®ion->front()); +      propagateIfChanged(state, state->setToLive()); +      predecessors = getOrCreate(region); +    } else { +      predecessors = getOrCreate(branch); +    } +    // Add the parent op as a predecessor. +    propagateIfChanged(predecessors, predecessors->join(branch)); +  } +} + +void DeadCodeAnalysis::visitRegionTerminator(Operation *op, +                                             RegionBranchOpInterface branch) { +  Optional> operands = getOperandValues(op); +  if (!operands) +    return; + +  SmallVector successors; +  branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(), +                             *operands, successors); + +  // Mark successor region entry blocks as executable and add this op to the +  // list of predecessors. +  for (const RegionSuccessor &successor : successors) { +    PredecessorState *predecessors; +    if (Region *region = successor.getSuccessor()) { +      auto *state = getOrCreate(®ion->front()); +      propagateIfChanged(state, state->setToLive()); +      predecessors = getOrCreate(region); +    } else { +      // Add this terminator as a predecessor to the parent op. +      predecessors = getOrCreate(branch); +    } +    propagateIfChanged(predecessors, predecessors->join(op)); +  } +} + +void DeadCodeAnalysis::visitCallableTerminator(Operation *op, +                                               CallableOpInterface callable) { +  // If there are no exiting values, we have nothing to do. +  if (op->getNumOperands() == 0) +    return; + +  // Add as predecessors to all callsites this return op.
+  auto *callsites = getOrCreateFor(op, callable); +  bool canResolve = op->hasTrait(); +  for (Operation *predecessor : callsites->getKnownPredecessors()) { +    assert(isa(predecessor)); +    auto *predecessors = getOrCreate(predecessor); +    if (canResolve) { +      propagateIfChanged(predecessors, predecessors->join(op)); +    } else { +      // If the terminator is not a return-like, then conservatively assume we +      // can't resolve the predecessor. +      propagateIfChanged(predecessors, +                         predecessors->setHasUnknownPredecessors()); +    } +  } +} diff --git a/mlir/test/Analysis/test-dead-code-analysis.mlir b/mlir/test/Analysis/test-dead-code-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-dead-code-analysis.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt -test-dead-code-analysis 2>&1 %s | FileCheck %s + +// CHECK: test_cfg: +// CHECK:  region #0 +// CHECK:   ^bb0 = live +// CHECK:   ^bb1 = live +// CHECK:    from ^bb1 = live +// CHECK:    from ^bb0 = live +// CHECK:   ^bb2 = live +// CHECK:    from ^bb1 = live +func.func @test_cfg(%cond: i1) -> () +  attributes {tag = "test_cfg"} { +  cf.br ^bb1 + +^bb1: +  cf.cond_br %cond, ^bb1, ^bb2 + +^bb2: +  return +} + +func.func @test_region_control_flow(%cond: i1, %arg0: i64, %arg1: i64) -> () { +  // CHECK: test_if: +  // CHECK:  region #0 +  // CHECK:   region_preds: (all) predecessors: +  // CHECK:     scf.if +  // CHECK:  region #1 +  // CHECK:   region_preds: (all) predecessors: +  // CHECK:     scf.if +  // CHECK:  op_preds: (all) predecessors: +  // CHECK:    scf.yield {then} +  // CHECK:    scf.yield {else} +  scf.if %cond { +    scf.yield {then} +  } else { +    scf.yield {else} +  } {tag = "test_if"} + +  // test_while: +  //  region #0 +  //   region_preds: (all) predecessors: +  //     scf.while +  //     scf.yield +  //  region #1 +  //   region_preds: (all) predecessors: +  //     scf.condition +  //  op_preds: (all) predecessors: +  //     scf.condition +  %c2_i64 = arith.constant 2 : i64 +  %0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) { +    %1 = arith.cmpi slt, %arg2, %arg1 : i64
scf.condition(%1) %arg2, %arg2 : i64, i64 + } do { + ^bb0(%arg2: i64, %arg3: i64): + %1 = arith.muli %arg3, %c2_i64 : i64 + scf.yield %1 : i64 + } attributes {tag = "test_while"} + + return +} + +// CHECK: foo: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: (all) predecessors: +// CHECK: func.call @foo(%{{.*}}) {tag = "a"} +// CHECK: func.call @foo(%{{.*}}) {tag = "b"} +func.func private @foo(%arg0: i32) -> i32 + attributes {tag = "foo"} { + return {a} %arg0 : i32 +} + +// CHECK: bar: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: predecessors: +// CHECK: func.call @bar(%{{.*}}) {tag = "c"} +func.func @bar(%cond: i1) -> i32 + attributes {tag = "bar"} { + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + %c0 = arith.constant 0 : i32 + return {b} %c0 : i32 + +^bb2: + %c1 = arith.constant 1 : i32 + return {c} %c1 : i32 +} + +// CHECK: baz +// CHECK: op_preds: (all) predecessors: +func.func private @baz(i32) -> i32 attributes {tag = "baz"} + +func.func @test_callgraph(%cond: i1, %arg0: i32) -> i32 { + // CHECK: a: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {a} + %0 = func.call @foo(%arg0) {tag = "a"} : (i32) -> i32 + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: b: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {a} + %1 = func.call @foo(%arg0) {tag = "b"} : (i32) -> i32 + return %1 : i32 + +^bb2: + // CHECK: c: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {b} + // CHECK: func.return {c} + %2 = func.call @bar(%cond) {tag = "c"} : (i1) -> i32 + // CHECK: d: + // CHECK: op_preds: predecessors: + %3 = func.call @baz(%arg0) {tag = "d"} : (i32) -> i32 + return %2 : i32 +} + +// CHECK: test_unknown_branch: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: ^bb1 = live +// CHECK: from ^bb0 = live +// CHECK: ^bb2 = live +// CHECK: from ^bb0 = live +func.func @test_unknown_branch() -> () + attributes {tag = "test_unknown_branch"} { + "test.unknown_br"() [^bb1, ^bb2] : () -> () + 
+^bb1: + return + +^bb2: + return +} + +// CHECK: test_unknown_region: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: region #1 +// CHECK: ^bb0 = live +func.func @test_unknown_region() -> () { + "test.unknown_region_br"() ({ + ^bb0: + "test.unknown_region_end"() : () -> () + }, { + ^bb0: + "test.unknown_region_end"() : () -> () + }) {tag = "test_unknown_region"} : () -> () + return +} + +// CHECK: test_known_dead_block: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: ^bb1 = live +// CHECK: ^bb2 = dead +func.func @test_known_dead_block() -> () + attributes {tag = "test_known_dead_block"} { + %true = arith.constant true + cf.cond_br %true, ^bb1, ^bb2 + +^bb1: + return + +^bb2: + return +} + +// CHECK: test_known_dead_edge: +// CHECK: ^bb2 = live +// CHECK: from ^bb1 = dead +// CHECK: from ^bb0 = live +func.func @test_known_dead_edge(%arg0: i1) -> () + attributes {tag = "test_known_dead_edge"} { + cf.cond_br %arg0, ^bb1, ^bb2 + +^bb1: + %true = arith.constant true + cf.cond_br %true, ^bb3, ^bb2 + +^bb2: + return + +^bb3: + return +} + +func.func @test_known_region_predecessors() -> () { + %false = arith.constant false + // CHECK: test_known_if: + // CHECK: region #0 + // CHECK: ^bb0 = dead + // CHECK: region #1 + // CHECK: ^bb0 = live + // CHECK: region_preds: (all) predecessors: + // CHECK: scf.if + // CHECK: op_preds: (all) predecessors: + // CHECK: scf.yield {else} + scf.if %false { + scf.yield {then} + } else { + scf.yield {else} + } {tag = "test_known_if"} + return +} + +// CHECK: callable: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: predecessors: +// CHECK: func.call @callable() {then} +func.func @callable() attributes {tag = "callable"} { + return +} + +func.func @test_dead_callsite() -> () { + %true = arith.constant true + scf.if %true { + func.call @callable() {then} : () -> () + scf.yield + } else { + func.call @callable() {else} : () -> () + scf.yield + } + return +} + +func.func private @test_dead_return(%arg0: i32) -> 
i32 { + %true = arith.constant true + cf.cond_br %true, ^bb1, ^bb1 + +^bb1: + return {true} %arg0 : i32 + +^bb2: + return {false} %arg0 : i32 +} + +func.func @test_call_dead_return(%arg0: i32) -> () { + // CHECK: test_dead_return: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {true} + %0 = func.call @test_dead_return(%arg0) {tag = "test_dead_return"} : (i32) -> i32 + return +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ TestCallGraph.cpp TestDataFlow.cpp TestDataFlowFramework.cpp + TestDeadCodeAnalysis.cpp TestLiveness.cpp TestMatchReduction.cpp TestMemRefBoundCheck.cpp diff --git a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp @@ -0,0 +1,116 @@ +//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +/// Print the liveness of every block, control-flow edge, and the predecessors +/// of all regions, callables, and calls. 
+static void printAnalysisResults(DataFlowSolver &solver, Operation *root,
+                                 raw_ostream &os) {
+  root->walk([&](Operation *op) {
+    auto tag = op->getAttrOfType<StringAttr>("tag");
+    if (!tag)
+      return;
+    os << tag.getValue() << ":\n";
+    for (Region &region : op->getRegions()) {
+      os << " region #" << region.getRegionNumber() << "\n";
+      for (Block &block : region) {
+        os << " ";
+        block.printAsOperand(os);
+        os << " = ";
+        auto *blockLive = solver.lookupState<Executable>(&block);
+        if (blockLive)
+          os << *blockLive;
+        else
+          os << "dead";
+        os << "\n";
+        for (Block *pred : block.getPredecessors()) {
+          os << " from ";
+          pred->printAsOperand(os);
+          os << " = ";
+          auto *edgeLive = solver.lookupState<Executable>(
+              solver.getProgramPoint<CFGEdge>(pred, &block));
+          if (edgeLive)
+            os << *edgeLive;
+          else
+            os << "dead";
+          os << "\n";
+        }
+      }
+      if (!region.empty()) {
+        auto *preds = solver.lookupState<PredecessorState>(&region);
+        if (preds)
+          os << "region_preds: " << *preds << "\n";
+      }
+    }
+    auto *preds = solver.lookupState<PredecessorState>(op);
+    if (preds)
+      os << "op_preds: " << *preds << "\n";
+  });
+}
+
+namespace {
+/// This is a simple analysis that implements a transfer function for constant
+/// operations: it joins each constant's attribute into its result's lattice.
+struct ConstantAnalysis : public DataFlowAnalysis {
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  LogicalResult initialize(Operation *top) override {
+    WalkResult result = top->walk([&](Operation *op) {
+      if (op->hasTrait<OpTrait::ConstantLike>())
+        if (failed(visit(op)))
+          return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return success(!result.wasInterrupted());
+  }
+
+  LogicalResult visit(ProgramPoint point) override {
+    Operation *op = point.get<Operation *>();
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
+      propagateIfChanged(
+          constant, constant->join(ConstantValue(value, op->getDialect())));
+    }
+    return success();
+  }
+};
+
+/// This is a simple pass that runs dead code analysis, fed by the constant
+/// analysis above, and prints the results for every op tagged with `tag`.
+struct TestDeadCodeAnalysisPass
+    : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
+
+  StringRef getArgument() const override { return "test-dead-code-analysis"; }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<ConstantAnalysis>(); // Seeds constants for branch folding.
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+    printAnalysisResults(solver, op, llvm::errs());
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDeadCodeAnalysisPass() {
+  PassRegistration<TestDeadCodeAnalysisPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataFlowPass();
 void registerTestDataLayoutQuery();
+void registerTestDeadCodeAnalysisPass();
 void registerTestDecomposeCallGraphTypes();
 void registerTestDiagnosticsPass();
 void registerTestDominancePass();
@@ -172,6 +173,7 @@
   mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataFlowPass();
   mlir::test::registerTestDataLayoutQuery();
+  mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestExpandMathPass();