diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h @@ -22,34 +22,17 @@ #ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H #define MLIR_ANALYSIS_DATAFLOWANALYSIS_H +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/Support/Allocator.h" -namespace mlir { -//===----------------------------------------------------------------------===// -// ChangeResult -//===----------------------------------------------------------------------===// - -/// A result type used to indicate if a change happened. Boolean operations on -/// ChangeResult behave as though `Change` is truthy. -enum class ChangeResult { - NoChange, - Change, -}; -inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) { - return lhs == ChangeResult::Change ? lhs : rhs; -} -inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) { - lhs = lhs | rhs; - return lhs; -} -inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { - return lhs == ChangeResult::NoChange ? lhs : rhs; -} +/// TODO: Remove this file when SCCP and integer range analysis have been ported +/// to the new framework. +namespace mlir { //===----------------------------------------------------------------------===// // AbstractLatticeElement //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -16,7 +16,6 @@ #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H -#include "mlir/Analysis/DataFlowAnalysis.h" #include "mlir/IR/Operation.h" #include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/SetVector.h" @@ -25,6 +24,27 @@ namespace mlir { +//===----------------------------------------------------------------------===// +// ChangeResult +//===----------------------------------------------------------------------===// + +/// A result type used to indicate if a change happened. Boolean operations on +/// ChangeResult behave as though `Change` is truthy. +enum class ChangeResult { + NoChange, + Change, +}; +inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) { + return lhs == ChangeResult::Change ? lhs : rhs; +} +inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) { + lhs = lhs | rhs; + return lhs; +} +inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { + return lhs == ChangeResult::NoChange ? lhs : rhs; +} + /// Forward declare the analysis state class. class AnalysisState; diff --git a/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h @@ -0,0 +1,167 @@ +//===- DenseDataFlowAnalysis.h - Dense data-flow analysis -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements dense data-flow analysis using the data-flow analysis +// framework. The analysis is forward and conditional and uses the results of +// dead code analysis to prune dead code during the analysis. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H +#define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// AbstractDenseLattice +//===----------------------------------------------------------------------===// + +/// This class represents a dense lattice. A dense lattice is attached to +/// operations to represent the program state after their execution or to blocks +/// to represent the program state at the beginning of the block. A dense +/// lattice is propagated through the IR by dense data-flow analysis. +class AbstractDenseLattice : public AnalysisState { +public: + /// A dense lattice can only be created for operations and blocks. + using AnalysisState::AnalysisState; + + /// Join the lattice across control-flow or callgraph edges. + virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0; + + /// Reset the dense lattice to a pessimistic value. This occurs when the + /// analysis cannot reason about the data-flow. + virtual ChangeResult reset() = 0; + + /// Returns true if the lattice state has reached a pessimistic fixpoint. That + /// is, no further modifications to the lattice can occur. + virtual bool isAtFixpoint() const = 0; +}; + +//===----------------------------------------------------------------------===// +// AbstractDenseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for dense data-flow analyses. Dense data-flow analysis attaches a +/// lattice between the execution of operations and implements a transfer +/// function from the lattice before each operation to the lattice after. The +/// lattice contains information about the state of the program at that point. +/// +/// In this implementation, a lattice attached to an operation represents the +/// state of the program after its execution, and a lattice attached to block +/// represents the state of the program right before it starts executing its +/// body. +class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + /// Initialize the analysis by visiting every program point whose execution + /// may modify the program state; that is, every operation and block. + LogicalResult initialize(Operation *top) override; + + /// Visit a program point that modifies the state of the program. If this is a + /// block, then the state is propagated from control-flow predecessors or + /// callsites. If this is a call operation or region control-flow operation, + /// then the state after the execution of the operation is set by control-flow + /// or the callgraph. Otherwise, this function invokes the operation transfer + /// function. + LogicalResult visit(ProgramPoint point) override; + +protected: + /// Propagate the dense lattice before the execution of an operation to the + /// lattice after its execution. 
+  virtual void visitOperationImpl(Operation *op,
+                                  const AbstractDenseLattice &before,
+                                  AbstractDenseLattice *after) = 0;
+
+  /// Get the dense lattice after the execution of the given program point.
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+
+  /// Get the dense lattice after the execution of the given program point and
+  /// add it as a dependency to a program point.
+  const AbstractDenseLattice *getLatticeFor(ProgramPoint dependee,
+                                            ProgramPoint point);
+
+  /// Mark the dense lattice as having reached its pessimistic fixpoint and
+  /// propagate an update if it changed.
+  void reset(AbstractDenseLattice *lattice) {
+    propagateIfChanged(lattice, lattice->reset());
+  }
+
+  /// Join a lattice with another and propagate an update if it changed.
+  void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
+    propagateIfChanged(lhs, lhs->join(rhs));
+  }
+
+private:
+  /// Visit an operation. If this is a call operation or region control-flow
+  /// operation, then the state after the execution of the operation is set by
+  /// control-flow or the callgraph. Otherwise, this function invokes the
+  /// operation transfer function.
+  void visitOperation(Operation *op);
+
+  /// Visit a block. The state at the start of the block is propagated from
+  /// control-flow predecessors or callsites.
+  void visitBlock(Block *block);
+
+  /// Visit a program point within a region branch operation with predecessors
+  /// in it. This can either be an entry block of one of the regions or the
+  /// parent operation itself.
+  void visitRegionBranchOperation(ProgramPoint point,
+                                  RegionBranchOpInterface branch,
+                                  AbstractDenseLattice *after);
+};
+
+//===----------------------------------------------------------------------===//
+// DenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A dense (forward) data-flow analysis for propagating lattices before and
+/// after the execution of every operation across the IR by implementing
+/// transfer functions for operations.
+///
+/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`.
+template <typename LatticeT>
+class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
+public:
+  using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis;
+
+  /// Visit an operation with the dense lattice before its execution. This
+  /// function is expected to set the dense lattice after its execution.
+  virtual void visitOperation(Operation *op, const LatticeT &before,
+                              LatticeT *after) = 0;
+
+protected:
+  /// Get the dense lattice after this program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract dense lattice to a derived
+  /// lattice and invoke the virtual hooks operating on the derived lattice.
+ void visitOperationImpl(Operation *op, const AbstractDenseLattice &before, + AbstractDenseLattice *after) override { + visitOperation(op, static_cast(before), + static_cast(after)); + } +}; + +//===----------------------------------------------------------------------===// +// DenseLattice +//===----------------------------------------------------------------------===// + +template +class DenseLattice : public AbstractDenseLattice { +public: +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H diff --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h @@ -0,0 +1,584 @@ +//===- SparseDataFlowAnalysis.h - Sparse data-flow analysis ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements sparse data-flow analysis using the data-flow analysis +// framework. The analysis is forward and conditional and uses the results of +// dead code analysis to prune dead code during the analysis. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H +#define MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/SmallPtrSet.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// AbstractLattice +//===----------------------------------------------------------------------===// + +/// This class represents an abstract lattice. A lattice contains information +/// about an SSA value and is what's propagated across the IR by sparse +/// data-flow analysis. +class AbstractLattice : public AnalysisState { +public: + /// Lattices can only be created for values. + AbstractLattice(Value value) : AnalysisState(value) {} + + /// Join the information contained in 'rhs' into this lattice. Returns + /// if the value of the lattice changed. + virtual ChangeResult join(const AbstractLattice &rhs) = 0; + + /// Returns true if the lattice element is at fixpoint and further calls to + /// `join` will not update the value of the element. + virtual bool isAtFixpoint() const = 0; + + /// Mark the lattice element as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have conflicting value states, and + /// only the most conservative value should be relied on. + virtual ChangeResult markPessimisticFixpoint() = 0; + + /// Mark the lattice element as having reached an optimistic fixpoint. This + /// means that we optimistically assume the current value is the true state. + virtual void markOptimisticFixpoint() = 0; + + /// When the lattice gets updated, propagate an update to users of the value + /// using its use-def chain to subscribed analyses. + void onUpdate(DataFlowSolver *solver) const override; + + /// Subscribe an analysis to updates of the lattice. When the lattice changes, + /// subscribed analyses are re-invoked on all users of the value. This is + /// more efficient than relying on the dependency map. 
+ void useDefSubscribe(DataFlowAnalysis *analysis) { + useDefSubscribers.insert(analysis); + } + +private: + /// A set of analyses that should be updated when this lattice changes. + SetVector, + SmallPtrSet> + useDefSubscribers; +}; + +//===----------------------------------------------------------------------===// +// Lattice +//===----------------------------------------------------------------------===// + +/// This class represents a lattice holding a specific value of type `ValueT`. +/// Lattice values (`ValueT`) are required to adhere to the following: +/// +/// * static ValueT join(const ValueT &lhs, const ValueT &rhs); +/// - This method conservatively joins the information held by `lhs` +/// and `rhs` into a new value. This method is required to be monotonic. +/// * bool operator==(const ValueT &rhs) const; +/// +template +class Lattice : public AbstractLattice { +public: + /// Construct a lattice with a known value. + explicit Lattice(Value value) + : AbstractLattice(value), + knownValue(ValueT::getPessimisticValueState(value)) {} + + /// Return the value held by this lattice. This requires that the value is + /// initialized. + ValueT &getValue() { + assert(!isUninitialized() && "expected known lattice element"); + return *optimisticValue; + } + const ValueT &getValue() const { + return const_cast *>(this)->getValue(); + } + + /// Returns true if the value of this lattice hasn't yet been initialized. + bool isUninitialized() const override { return !optimisticValue.hasValue(); } + /// Force the initialization of the element by setting it to its pessimistic + /// fixpoint. + ChangeResult defaultInitialize() override { + return markPessimisticFixpoint(); + } + + /// Returns true if the lattice has reached a fixpoint. A fixpoint is when + /// the information optimistically assumed to be true is the same as the + /// information known to be true. + bool isAtFixpoint() const override { return optimisticValue == knownValue; } + + /// Join the information contained in the 'rhs' lattice into this + /// lattice. Returns if the state of the current lattice changed. + ChangeResult join(const AbstractLattice &rhs) override { + const Lattice &rhsLattice = + static_cast &>(rhs); + + // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do. + if (isAtFixpoint() || rhsLattice.isUninitialized()) + return ChangeResult::NoChange; + + // Join the rhs value into this lattice. + return join(rhsLattice.getValue()); + } + + /// Join the information contained in the 'rhs' value into this + /// lattice. Returns if the state of the current lattice changed. + ChangeResult join(const ValueT &rhs) { + // If the current lattice is uninitialized, copy the rhs value. + if (isUninitialized()) { + optimisticValue = rhs; + return ChangeResult::Change; + } + + // Otherwise, join rhs with the current optimistic value. + ValueT newValue = ValueT::join(*optimisticValue, rhs); + assert(ValueT::join(newValue, *optimisticValue) == newValue && + "expected `join` to be monotonic"); + assert(ValueT::join(newValue, rhs) == newValue && + "expected `join` to be monotonic"); + + // Update the current optimistic value if something changed. + if (newValue == optimisticValue) + return ChangeResult::NoChange; + + optimisticValue = newValue; + return ChangeResult::Change; + } + + /// Mark the lattice element as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have conflicting value states, + /// and only the conservatively known value state should be relied on. 
+ ChangeResult markPessimisticFixpoint() override { + if (isAtFixpoint()) + return ChangeResult::NoChange; + + // For this fixed point, we take whatever we knew to be true and set that + // to our optimistic value. + optimisticValue = knownValue; + return ChangeResult::Change; + } + + /// Mark the lattice element as having reached an optimistic fixpoint. This + /// means that we optimistically assume the current value is the true state. + void markOptimisticFixpoint() override { + assert(!isUninitialized() && "expected an initialized value"); + knownValue = *optimisticValue; + } + + /// Print the lattice element. + void print(raw_ostream &os) const override { + os << "["; + knownValue.print(os); + os << ", "; + if (optimisticValue) { + optimisticValue->print(os); + } else { + os << ""; + } + os << "]"; + } + +private: + /// The value that is conservatively known to be true. + ValueT knownValue; + /// The currently computed value that is optimistically assumed to be true, + /// or None if the lattice element is uninitialized. + Optional optimisticValue; +}; + +//===----------------------------------------------------------------------===// +// Executable +//===----------------------------------------------------------------------===// + +/// This is a simple analysis state that represents whether the associated +/// program point (either a block or a control-flow edge) is live. +class Executable : public AnalysisState { +public: + using AnalysisState::AnalysisState; + + /// The state is initialized by default. + bool isUninitialized() const override { return false; } + + /// The state is always initialized. + ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + + /// Set the state of the program point to live. + ChangeResult setToLive(); + + /// Get whether the program point is live. + bool isLive() const { return live; } + + /// Print the liveness; + void print(raw_ostream &os) const override; + + /// When the state of the program point is changed to live, re-invoke + /// subscribed analyses on the operations in the block and on the block + /// itself. + void onUpdate(DataFlowSolver *solver) const override; + + /// Subscribe an analysis to changes to the liveness. + void blockContentSubscribe(DataFlowAnalysis *analysis) { + subscribers.insert(analysis); + } + +private: + /// Whether the program point is live. Optimistically assume that the program + /// point is dead. + bool live = false; + + /// A set of analyses that should be updated when this state changes. + SetVector, + SmallPtrSet> + subscribers; +}; + +//===----------------------------------------------------------------------===// +// ConstantValue +//===----------------------------------------------------------------------===// + +/// This lattice value represents a known constant value of a lattice. +class ConstantValue { +public: + /// The pessimistic value state of the constant value is unknown. + static ConstantValue getPessimisticValueState(Value value) { return {}; } + + /// Construct a constant value with a known constant. + ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr) + : constant(knownValue), dialect(dialect) {} + + /// Get the constant value. Returns null if no value was determined. + Attribute getConstantValue() const { return constant; } + + /// Get the dialect instance that can be used to materialize the constant. + Dialect *getConstantDialect() const { return dialect; } + + /// Compare the constant values. 
+  bool operator==(const ConstantValue &rhs) const {
+    return constant == rhs.constant;
+  }
+
+  /// The join of two constant values is the same value if they are equal, and
+  /// the unknown (null) value otherwise.
+  static ConstantValue join(const ConstantValue &lhs,
+                            const ConstantValue &rhs) {
+    return lhs == rhs ? lhs : ConstantValue();
+  }
+
+  /// Print the constant value.
+  void print(raw_ostream &os) const;
+
+private:
+  /// The constant value.
+  Attribute constant;
+  /// A dialect instance that can be used to materialize the constant.
+  Dialect *dialect;
+};
+
+//===----------------------------------------------------------------------===//
+// PredecessorState
+//===----------------------------------------------------------------------===//
+
+/// This analysis state represents a set of known predecessors. This state is
+/// used in sparse data-flow analysis to reason about region control-flow and
+/// callgraphs. The state may also indicate that not all predecessors can be
+/// known, if for example not all callsites of a callable are visible.
+class PredecessorState : public AnalysisState {
+public:
+  using AnalysisState::AnalysisState;
+
+  /// The state is initialized by default.
+  bool isUninitialized() const override { return false; }
+
+  /// The state is always initialized.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Print the known predecessors.
+  void print(raw_ostream &os) const override;
+
+  /// Returns true if all predecessors are known.
+  bool allPredecessorsKnown() const { return allKnown; }
+
+  /// Indicate that there are potentially unknown predecessors.
+  ChangeResult setHasUnknownPredecessors() {
+    if (!allKnown)
+      return ChangeResult::NoChange;
+    allKnown = false;
+    return ChangeResult::Change;
+  }
+
+  /// Get the known predecessors.
+  ArrayRef<Operation *> getKnownPredecessors() const {
+    return predecessors.getArrayRef();
+  }
+
+  /// Get the successor inputs from a predecessor.
+  ValueRange getSuccessorInputs(Operation *predecessor) const {
+    return successorInputs.lookup(predecessor);
+  }
+
+  /// Add a known predecessor.
+  ChangeResult join(Operation *predecessor);
+
+  /// Add a known predecessor with successor inputs.
+  ChangeResult join(Operation *predecessor, ValueRange inputs);
+
+private:
+  /// Whether all predecessors are known. Optimistically assume that we know
+  /// all predecessors.
+  bool allKnown = true;
+
+  /// The known control-flow predecessors of this program point.
+  SetVector<Operation *, SmallVector<Operation *, 4>,
+            SmallPtrSet<Operation *, 4>>
+      predecessors;
+
+  /// The successor inputs when branching from a given predecessor.
+  DenseMap<Operation *, ValueRange> successorInputs;
+};
+
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// This program point represents a control-flow edge between a block and one
+/// of its successors.
+class CFGEdge
+    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+public:
+  using Base::Base;
+
+  /// Get the block from which the edge originates.
+  Block *getFrom() const { return getValue().first; }
+  /// Get the target block.
+  Block *getTo() const { return getValue().second; }
+
+  /// Print the blocks between the control-flow edge.
+  void print(raw_ostream &os) const override;
+  /// Get a fused location of both blocks.
+ Location getLoc() const override; +}; + +//===----------------------------------------------------------------------===// +// DeadCodeAnalysis +//===----------------------------------------------------------------------===// + +/// Dead code analysis analyzes control-flow, as understood by +/// `RegionBranchOpInterface` and `BranchOpInterface`, and the callgraph, as +/// understood by `CallableOpInterface` and `CallOpInterface`. +/// +/// This analysis uses known constant values of operands to determine the +/// liveness of each block and each edge between a block and its predecessors. +/// For region control-flow, this analysis determines the predecessor operations +/// for region entry blocks and region control-flow operations. For the +/// callgraph, this analysis determines the callsites and live returns of every +/// function. +class DeadCodeAnalysis : public DataFlowAnalysis { +public: + explicit DeadCodeAnalysis(DataFlowSolver &solver); + + /// Initialize the analysis by visiting every operation with potential + /// control-flow semantics. + LogicalResult initialize(Operation *top) override; + + /// Visit an operation with control-flow semantics and deduce which of its + /// successors are live. + LogicalResult visit(ProgramPoint point) override; + +private: + /// Find and mark symbol callables with potentially unknown callsites as + /// having overdefined predecessors. `top` is the top-level operation that the + /// analysis is operating on. + void initializeSymbolCallables(Operation *top); + + /// Recursively Initialize the analysis on nested regions. + LogicalResult initializeRecursively(Operation *op); + + /// Visit the given call operation and compute any necessary lattice state. + void visitCallOperation(CallOpInterface call); + + /// Visit the given branch operation with successors and try to determine + /// which are live from the current block. + void visitBranchOperation(BranchOpInterface branch); + + /// Visit the given region branch operation, which defines regions, and + /// compute any necessary lattice state. This also resolves the lattice state + /// of both the operation results and any nested regions. + void visitRegionBranchOperation(RegionBranchOpInterface branch); + + /// Visit the given terminator operation that exits a region under an + /// operation with control-flow semantics. These are terminators with no CFG + /// successors. + void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch); + + /// Visit the given terminator operation that exits a callable region. These + /// are terminators with no CFG successors. + void visitCallableTerminator(Operation *op, CallableOpInterface callable); + + /// Mark the edge between `from` and `to` as executable. + void markEdgeLive(Block *from, Block *to); + + /// Mark the entry blocks of the operation as executable. + void markEntryBlocksLive(Operation *op); + + /// Get the constant values of the operands of the operation. Returns none if + /// any of the operand lattices are uninitialized. + Optional> getOperandValues(Operation *op); + + /// A symbol table used for O(1) symbol lookups during simplification. + SymbolTableCollection symbolTable; +}; + +//===----------------------------------------------------------------------===// +// AbstractSparseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for sparse (forward) data-flow analyses. 
A sparse analysis
+/// implements a transfer function on operations from the lattices of the
+/// operands to the lattices of the results. This analysis will propagate
+/// lattices across control-flow edges and the callgraph using liveness
+/// information.
+class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  /// Initialize the analysis by visiting every owner of an SSA value: all
+  /// operations and blocks.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point. If this is a block and all control-flow
+  /// predecessors or callsites are known, then the argument lattices are
+  /// propagated from them. If this is a call operation or an operation with
+  /// region control-flow, then its result lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver);
+
+  /// The operation transfer function. Given the operand lattices, this
+  /// function is expected to set the result lattices.
+  virtual void
+  visitOperationImpl(Operation *op,
+                     ArrayRef<const AbstractLattice *> operandLattices,
+                     ArrayRef<AbstractLattice *> resultLattices) = 0;
+
+  /// Get the lattice element of a value.
+  virtual AbstractLattice *getLatticeElement(Value value) = 0;
+
+  /// Get a read-only lattice element for a value and add it as a dependency to
+  /// a program point.
+  const AbstractLattice *getLatticeElementFor(ProgramPoint point, Value value);
+
+  /// Mark a lattice element as having reached its pessimistic fixpoint and
+  /// propagate an update if it changed.
+  void markPessimisticFixpoint(AbstractLattice *lattice);
+
+  /// Mark the given lattice elements as having reached their pessimistic
+  /// fixpoints and propagate an update if any changed.
+  void markAllPessimisticFixpoint(ArrayRef<AbstractLattice *> lattices);
+
+  /// Join the lattice element with another and propagate an update if it
+  /// changed.
+  void join(AbstractLattice *lhs, const AbstractLattice &rhs);
+
+private:
+  /// Recursively initialize the analysis on nested operations and blocks.
+  LogicalResult initializeRecursively(Operation *op);
+
+  /// Visit an operation. If this is a call operation or an operation with
+  /// region control-flow, then its result lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  void visitOperation(Operation *op);
+
+  /// If this is a block and all control-flow predecessors or callsites are
+  /// known, then the argument lattices are propagated from them.
+  void visitBlock(Block *block);
+
+  /// Visit a program point `point` with predecessors within a region branch
+  /// operation `branch`, which can either be the entry block of one of the
+  /// regions or the parent operation itself, and set either the argument or
+  /// parent result lattices.
+  void visitRegionSuccessors(ProgramPoint point,
+                             RegionBranchOpInterface branch,
+                             Optional<unsigned> successorIndex,
+                             ArrayRef<AbstractLattice *> lattices);
+};
+
+//===----------------------------------------------------------------------===//
+// SparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (forward) data-flow analysis for propagating SSA value lattices
+/// across the IR by implementing transfer functions for operations.
+///
+/// `StateT` is expected to be a subclass of `AbstractLattice`.
+template +class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis { +public: + explicit SparseDataFlowAnalysis(DataFlowSolver &solver) + : AbstractSparseDataFlowAnalysis(solver) {} + + /// Visit an operation with the lattices of its operands. This function is + /// expected to set the lattices of the operation's results. + virtual void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) = 0; + +protected: + /// Get the lattice element for a value. + StateT *getLatticeElement(Value value) override { + return getOrCreate(value); + } + + /// Get the lattice element for a value and create a dependency on the + /// provided program point. + const StateT *getLatticeElementFor(ProgramPoint point, Value value) { + return static_cast( + AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value)); + } + + /// Mark the lattice elements of a range of values as having reached their + /// pessimistic fixpoint. + void markAllPessimisticFixpoint(ArrayRef lattices) { + AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( + {reinterpret_cast(lattices.begin()), + lattices.size()}); + } + +private: + /// Type-erased wrappers that convert the abstract lattice operands to derived + /// lattices and invoke the virtual hooks operating on the derived lattices. + void visitOperationImpl(Operation *op, + ArrayRef operandLattices, + ArrayRef resultLattices) override { + visitOperation( + op, + {reinterpret_cast(operandLattices.begin()), + operandLattices.size()}, + {reinterpret_cast(resultLattices.begin()), + resultLattices.size()}); + } +}; + +//===----------------------------------------------------------------------===// +// SparseConstantPropagation +//===----------------------------------------------------------------------===// + +/// This analysis implements sparse constant propagation, which attempts to +/// determine constant-valued results for operations using constant-valued +/// operands, by speculatively folding operations. When combined with dead-code +/// analysis, this becomes sparse conditional constant propagation (SCCP). +class SparseConstantPropagation + : public SparseDataFlowAnalysis> { +public: + using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + + void visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -6,6 +6,7 @@ DataLayoutAnalysis.cpp Liveness.cpp SliceAnalysis.cpp + SparseDataFlowAnalysis.cpp AliasAnalysis/LocalAliasAnalysis.cpp ) @@ -17,8 +18,10 @@ DataFlowAnalysis.cpp DataFlowFramework.cpp DataLayoutAnalysis.cpp + DenseDataFlowAnalysis.cpp Liveness.cpp SliceAnalysis.cpp + SparseDataFlowAnalysis.cpp AliasAnalysis/LocalAliasAnalysis.cpp diff --git a/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp b/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp @@ -0,0 +1,164 @@ +//===- DenseDataFlowAnalysis.cpp - Dense data-flow analysis ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DenseDataFlowAnalysis.h" + +using namespace mlir; + +LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) { + // Visit every operation and block. + visitOperation(top); + for (Region ®ion : top->getRegions()) { + for (Block &block : region) { + visitBlock(&block); + for (Operation &op : block) + if (failed(initialize(&op))) + return failure(); + } + } + return success(); +} + +LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) { + if (auto *op = point.dyn_cast()) + visitOperation(op); + else if (auto *block = point.dyn_cast()) + visitBlock(block); + else + return failure(); + return success(); +} + +void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) { + // If the containing block is not executable, bail out. + if (!getOrCreateFor(op, op->getBlock())->isLive()) + return; + + // Get the dense lattice to update. + AbstractDenseLattice *after = getLattice(op); + if (after->isAtFixpoint()) + return; + + // If this op implements region control-flow, then control-flow dictates its + // transfer function. + if (auto branch = dyn_cast(op)) + return visitRegionBranchOperation(op, branch, after); + + // If this is a call operation, then join its lattices across known return + // sites. + if (auto call = dyn_cast(op)) { + const auto *predecessors = getOrCreateFor(op, call); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) + return reset(after); + for (Operation *predecessor : predecessors->getKnownPredecessors()) + join(after, *getLatticeFor(op, predecessor)); + return; + } + + // Get the dense state before the execution of the op. + const AbstractDenseLattice *before; + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(op, prev); + else + before = getLatticeFor(op, op->getBlock()); + // If the incoming lattice is uninitialized, bail out. + if (before->isUninitialized()) + return; + + // Invoke the operation transfer function. + visitOperationImpl(op, *before, after); +} + +void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) { + // If the block is not executable, bail out. + if (!getOrCreateFor(block, block)->isLive()) + return; + + // Get the dense lattice to update. + AbstractDenseLattice *after = getLattice(block); + if (after->isAtFixpoint()) + return; + + // The dense lattices of entry blocks are set by region control-flow or the + // callgraph. + if (block->isEntryBlock()) { + // Check if this block is the entry block of a callable region. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + const auto *callsites = getOrCreateFor(block, callable); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (!callsites->allPredecessorsKnown()) + return reset(after); + for (Operation *callsite : callsites->getKnownPredecessors()) { + // Get the dense lattice before the callsite. + if (Operation *prev = callsite->getPrevNode()) + join(after, *getLatticeFor(block, prev)); + else + join(after, *getLatticeFor(block, callsite->getBlock())); + } + return; + } + + // Check if we can reason about the control-flow. 
+ if (auto branch = dyn_cast(block->getParentOp())) + return visitRegionBranchOperation(block, branch, after); + + // Otherwise, we can't reason about the data-flow. + return reset(after); + } + + // Join the state with the state after the block's predecessors. + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); + it != e; ++it) { + // Skip control edges that aren't executable. + Block *predecessor = *it; + if (!getOrCreateFor( + block, getProgramPoint(predecessor, block)) + ->isLive()) + continue; + + // Merge in the state from the predecessor's terminator. + join(after, *getLatticeFor(block, predecessor->getTerminator())); + } +} + +void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation( + ProgramPoint point, RegionBranchOpInterface branch, + AbstractDenseLattice *after) { + // Get the terminator predecessors. + const auto *predecessors = getOrCreateFor(point, point); + assert(predecessors->allPredecessorsKnown() && + "unexpected unresolved region successors"); + + for (Operation *op : predecessors->getKnownPredecessors()) { + const AbstractDenseLattice *before; + // If the predecessor is the parent, get the state before the parent. + if (op == branch) { + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(point, prev); + else + before = getLatticeFor(point, op->getBlock()); + + // Otherwise, get the state after the terminator. + } else { + before = getLatticeFor(point, op); + } + join(after, *before); + } +} + +const AbstractDenseLattice * +AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependee, + ProgramPoint point) { + AbstractDenseLattice *state = getLattice(point); + addDependency(state, dependee); + return state; +} diff --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp @@ -0,0 +1,736 @@ +//===- SparseDataFlowAnalysis.cpp - Sparse data-flow analysis -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "dataflow" + +using namespace mlir; + +void AbstractLattice::onUpdate(DataFlowSolver *solver) const { + // Push all users of the value to the queue. + for (Operation *user : point.get().getUsers()) + for (DataFlowAnalysis *analysis : useDefSubscribers) + solver->enqueue({user, analysis}); +} + +ChangeResult Executable::setToLive() { + if (live) + return ChangeResult::NoChange; + live = true; + return ChangeResult::Change; +} + +void Executable::print(raw_ostream &os) const { + os << (live ? "live" : "dead"); +} + +Location CFGEdge::getLoc() const { + return FusedLoc::get( + getFrom()->getParent()->getContext(), + {getFrom()->getParent()->getLoc(), getTo()->getParent()->getLoc()}); +} + +void CFGEdge::print(raw_ostream &os) const { + getFrom()->print(os); + os << "\n -> \n"; + getTo()->print(os); +} + +void PredecessorState::print(raw_ostream &os) const { + if (allPredecessorsKnown()) + os << "(all) "; + os << "predecessors:\n"; + for (Operation *op : getKnownPredecessors()) + os << " " << *op << "\n"; +} + +ChangeResult PredecessorState::join(Operation *predecessor) { + return predecessors.insert(predecessor) ? 
ChangeResult::Change + : ChangeResult::NoChange; +} + +ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) { + ChangeResult result = join(predecessor); + if (!inputs.empty()) { + ValueRange &curInputs = successorInputs[predecessor]; + if (curInputs != inputs) { + curInputs = inputs; + result |= ChangeResult::Change; + } + } + return result; +} + +DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) + : DataFlowAnalysis(solver) { + registerPointKind(); +} + +LogicalResult DeadCodeAnalysis::initialize(Operation *top) { + // Mark the top-level blocks as executable. + for (Region ®ion : top->getRegions()) { + if (region.empty()) + continue; + auto *state = getOrCreate(®ion.front()); + propagateIfChanged(state, state->setToLive()); + } + + // Mark as overdefined the predecessors of symbol callables with potentially + // unknown predecessors. + initializeSymbolCallables(top); + + return initializeRecursively(top); +} + +void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { + auto walkFn = [&](Operation *symTable, bool allUsesVisible) { + Region &symbolTableRegion = symTable->getRegion(0); + Block *symbolTableBlock = &symbolTableRegion.front(); + + bool foundSymbolCallable = false; + for (auto callable : symbolTableBlock->getOps()) { + Region *callableRegion = callable.getCallableRegion(); + if (!callableRegion) + continue; + auto symbol = dyn_cast(callable.getOperation()); + if (!symbol) + continue; + + // Public symbol callables or those for which we can't see all uses have + // potentially unknown callsites. + if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { + auto *state = getOrCreate(callable); + propagateIfChanged(state, state->setHasUnknownPredecessors()); + } + foundSymbolCallable = true; + } + + // Exit early if no eligible symbol callables were found in the table. + if (!foundSymbolCallable) + return; + + // Walk the symbol table to check for non-call uses of symbols. + Optional uses = + SymbolTable::getSymbolUses(&symbolTableRegion); + if (!uses) { + // If we couldn't gather the symbol uses, conservatively assume that + // we can't track information for any nested symbols. + return top->walk([&](CallableOpInterface callable) { + auto *state = getOrCreate(callable); + propagateIfChanged(state, state->setHasUnknownPredecessors()); + }); + } + + for (const SymbolTable::SymbolUse &use : *uses) { + if (isa(use.getUser())) + continue; + // If a callable symbol has a non-call use, then we can't be guaranteed to + // know all callsites. + Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef()); + auto *state = getOrCreate(symbol); + propagateIfChanged(state, state->setHasUnknownPredecessors()); + } + }; + SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), + walkFn); +} + +LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { + // Initialize the analysis by visiting every op with control-flow semantics. + if (op->getNumRegions() || op->getNumSuccessors() || + op->hasTrait() || isa(op)) { + // When the liveness of the parent block changes, make sure to re-invoke the + // analysis on the op. + if (op->getBlock()) + getOrCreate(op->getBlock())->blockContentSubscribe(this); + // Visit the op. + if (failed(visit(op))) + return failure(); + } + // Recurse on nested operations. 
+  for (Region &region : op->getRegions())
+    for (Operation &op : region.getOps())
+      if (failed(initializeRecursively(&op)))
+        return failure();
+  return success();
+}
+
+void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
+  auto *state = getOrCreate<Executable>(to);
+  propagateIfChanged(state, state->setToLive());
+  auto *edgeState =
+      getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
+  propagateIfChanged(edgeState, edgeState->setToLive());
+}
+
+void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
+  for (Region &region : op->getRegions()) {
+    if (region.empty())
+      continue;
+    auto *state = getOrCreate<Executable>(&region.front());
+    propagateIfChanged(state, state->setToLive());
+  }
+}
+
+LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
+  if (point.is<Block *>())
+    return success();
+  auto *op = point.dyn_cast<Operation *>();
+  if (!op)
+    return emitError(point.getLoc(), "unknown program point kind");
+
+  // If the parent block is not executable, there is nothing to do.
+  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+    return success();
+
+  // We have a live call op. Add this as a live predecessor of the callee.
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    visitCallOperation(call);
+
+  // Visit the regions.
+  if (op->getNumRegions()) {
+    // Check if we can reason about the region control-flow.
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+      visitRegionBranchOperation(branch);
+
+      // Check if this is a callable operation.
+    } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
+      const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+
+      // If the callsites could not be resolved or are known to be non-empty,
+      // mark the callable as executable.
+      if (!callsites->allPredecessorsKnown() ||
+          !callsites->getKnownPredecessors().empty())
+        markEntryBlocksLive(callable);
+
+      // Otherwise, conservatively mark all entry blocks as executable.
+    } else {
+      markEntryBlocksLive(op);
+    }
+  }
+
+  if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      // Visit the exiting terminator of a region.
+      visitRegionTerminator(op, branch);
+    } else if (auto callable =
+                   dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      // Visit the exiting terminator of a callable.
+      visitCallableTerminator(op, callable);
+    }
+  }
+  // Visit the successors.
+  if (op->getNumSuccessors()) {
+    // Check if we can reason about the control-flow.
+    if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+      visitBranchOperation(branch);
+
+      // Otherwise, conservatively mark all successors as executable.
+    } else {
+      for (Block *successor : op->getSuccessors())
+        markEdgeLive(op->getBlock(), successor);
+    }
+  }
+
+  return success();
+}
+
+void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
+  Operation *callableOp = nullptr;
+  if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
+    callableOp = callableValue.getDefiningOp();
+  else
+    callableOp = call.resolveCallable(&symbolTable);
+
+  // A call to an externally-defined callable has unknown predecessors.
+  const auto isExternalCallable = [](Operation *op) {
+    if (auto callable = dyn_cast<CallableOpInterface>(op))
+      return !callable.getCallableRegion();
+    return false;
+  };
+
+  // TODO: Add support for non-symbol callables when necessary. If the
+  // callable has non-call uses, we would mark it as having reached the
+  // pessimistic fixpoint; otherwise we allow propagating the return values
+  // out.
+  if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
+      !isExternalCallable(callableOp)) {
+    // Add the live callsite.
+    auto *callsites = getOrCreate<PredecessorState>(callableOp);
+    propagateIfChanged(callsites, callsites->join(call));
+  } else {
+    // Mark this call op's predecessors as overdefined.
+ auto *predecessors = getOrCreate(call); + propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); + } +} + +void ConstantValue::print(raw_ostream &os) const { + if (constant) + return constant.print(os); + os << ""; +} + +static Optional> getOperandValuesImpl( + Operation *op, + function_ref *(Value)> getLattice) { + SmallVector operands; + operands.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + const Lattice *cv = getLattice(operand); + // If any of the operands' values are uninitialized, bail out. + if (cv->isUninitialized()) + return {}; + operands.push_back(cv->getValue().getConstantValue()); + } + return operands; +} + +Optional> +DeadCodeAnalysis::getOperandValues(Operation *op) { + return getOperandValuesImpl(op, [&](Value value) { + auto *lattice = getOrCreate>(value); + lattice->useDefSubscribe(this); + return lattice; + }); +} + +void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { + // Try to deduce a single successor for the branch. + Optional> operands = getOperandValues(branch); + if (!operands) + return; + + if (Block *successor = branch.getSuccessorForOperands(*operands)) { + markEdgeLive(branch->getBlock(), successor); + } else { + // Otherwise, mark all successors as executable and outgoing edges. + for (Block *successor : branch->getSuccessors()) + markEdgeLive(branch->getBlock(), successor); + } +} + +void DeadCodeAnalysis::visitRegionBranchOperation( + RegionBranchOpInterface branch) { + // Try to deduce which regions are executable. + Optional> operands = getOperandValues(branch); + if (!operands) + return; + + SmallVector successors; + branch.getSuccessorRegions(/*index=*/{}, *operands, successors); + + for (const RegionSuccessor &successor : successors) { + // Mark the entry block as executable. + Region *region = successor.getSuccessor(); + assert(region && "expected a region successor"); + auto *state = getOrCreate(®ion->front()); + propagateIfChanged(state, state->setToLive()); + // Add the parent op as a predecessor. + auto *predecessors = getOrCreate(®ion->front()); + propagateIfChanged( + predecessors, + predecessors->join(branch, successor.getSuccessorInputs())); + } +} + +void DeadCodeAnalysis::visitRegionTerminator(Operation *op, + RegionBranchOpInterface branch) { + Optional> operands = getOperandValues(op); + if (!operands) + return; + + SmallVector successors; + branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(), + *operands, successors); + + // Mark successor region entry blocks as executable and add this op to the + // list of predecessors. + for (const RegionSuccessor &successor : successors) { + PredecessorState *predecessors; + if (Region *region = successor.getSuccessor()) { + auto *state = getOrCreate(®ion->front()); + propagateIfChanged(state, state->setToLive()); + predecessors = getOrCreate(®ion->front()); + } else { + // Add this terminator as a predecessor to the parent op. + predecessors = getOrCreate(branch); + } + propagateIfChanged(predecessors, + predecessors->join(op, successor.getSuccessorInputs())); + } +} + +void DeadCodeAnalysis::visitCallableTerminator(Operation *op, + CallableOpInterface callable) { + // If there are no exiting values, we have nothing to do. + if (op->getNumOperands() == 0) + return; + + // Add as predecessors to all callsites this return op. 
+ auto *callsites = getOrCreateFor(op, callable); + bool canResolve = op->hasTrait(); + for (Operation *predecessor : callsites->getKnownPredecessors()) { + assert(isa(predecessor)); + auto *predecessors = getOrCreate(predecessor); + if (canResolve) { + propagateIfChanged(predecessors, predecessors->join(op)); + } else { + // If the terminator is not a return-like, then conservatively assume we + // can't resolve the predecessor. + propagateIfChanged(predecessors, + predecessors->setHasUnknownPredecessors()); + } + } +} + +void Executable::onUpdate(DataFlowSolver *solver) const { + if (auto *block = point.dyn_cast()) { + // Re-invoke the analyses on the block itself. + for (DataFlowAnalysis *analysis : subscribers) + solver->enqueue({block, analysis}); + // Re-invoke the analyses on all operations in the block. + for (DataFlowAnalysis *analysis : subscribers) + for (Operation &op : *block) + solver->enqueue({&op, analysis}); + } else if (auto *programPoint = point.dyn_cast()) { + // Re-invoke the analysis on the successor block. + if (auto *edge = dyn_cast(programPoint)) + for (DataFlowAnalysis *analysis : subscribers) + solver->enqueue({edge->getTo(), analysis}); + } +} + +AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis( + DataFlowSolver &solver) + : DataFlowAnalysis(solver) { + registerPointKind(); +} + +LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) { + // Mark the entry block arguments as having reached their pessimistic + // fixpoints. + for (Region ®ion : top->getRegions()) { + if (region.empty()) + continue; + for (Value argument : region.front().getArguments()) + markPessimisticFixpoint(getLatticeElement(argument)); + } + + return initializeRecursively(top); +} + +LogicalResult +AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) { + // Initialize the analysis by visiting every owner of an SSA value (all + // operations and blocks). + visitOperation(op); + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + getOrCreate(&block)->blockContentSubscribe(this); + visitBlock(&block); + for (Operation &op : block) + if (failed(initializeRecursively(&op))) + return failure(); + } + } + + return success(); +} + +LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) { + if (Operation *op = point.dyn_cast()) + visitOperation(op); + else if (Block *block = point.dyn_cast()) + visitBlock(block); + else + return failure(); + return success(); +} + +void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) { + // Exit early on operations with no results. + if (op->getNumResults() == 0) + return; + + // If the containing block is not executable, bail out. + if (!getOrCreate(op->getBlock())->isLive()) + return; + + // Get the result lattices. + SmallVector resultLattices; + resultLattices.reserve(op->getNumResults()); + // Track whether all results have reached their fixpoint. + bool allAtFixpoint = true; + for (Value result : op->getResults()) { + AbstractLattice *resultLattice = getLatticeElement(result); + allAtFixpoint &= resultLattice->isAtFixpoint(); + resultLattices.push_back(resultLattice); + } + // If all result lattices have reached a fixpoint, there is nothing to do. + if (allAtFixpoint) + return; + + // The results of a region branch operation are determined by control-flow. + if (auto branch = dyn_cast(op)) { + return visitRegionSuccessors({branch}, branch, + /*successorIndex=*/llvm::None, resultLattices); + } + + // The results of a call operation are determined by the callgraph. 
+ if (auto call = dyn_cast(op)) { + const auto *predecessors = getOrCreateFor(op, call); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) + return markAllPessimisticFixpoint(resultLattices); + for (Operation *predecessor : predecessors->getKnownPredecessors()) + for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) + join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); + return; + } + + // Grab the lattice elements of the operands. + SmallVector operandLattices; + operandLattices.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + AbstractLattice *operandLattice = getLatticeElement(operand); + operandLattice->useDefSubscribe(this); + // If any of the operand states are not initialized, bail out. + if (operandLattice->isUninitialized()) + return; + operandLattices.push_back(operandLattice); + } + + // Invoke the operation transfer function. + visitOperationImpl(op, operandLattices, resultLattices); +} + +void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { + // Exit early on blocks with no arguments. + if (block->getNumArguments() == 0) + return; + + // If the block is not executable, bail out. + if (!getOrCreate(block)->isLive()) + return; + + // Get the argument lattices. + SmallVector argLattices; + argLattices.reserve(block->getNumArguments()); + bool allAtFixpoint = true; + for (BlockArgument argument : block->getArguments()) { + AbstractLattice *argLattice = getLatticeElement(argument); + allAtFixpoint &= argLattice->isAtFixpoint(); + argLattices.push_back(argLattice); + } + // If all argument lattices have reached their fixpoints, then there is + // nothing to do. + if (allAtFixpoint) + return; + + // The argument lattices of entry blocks are set by region control-flow or the + // callgraph. + if (block->isEntryBlock()) { + // Check if this block is the entry block of a callable region. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + const auto *callsites = getOrCreateFor(block, callable); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (!callsites->allPredecessorsKnown()) + return markAllPessimisticFixpoint(argLattices); + for (Operation *callsite : callsites->getKnownPredecessors()) { + auto call = cast(callsite); + for (auto it : llvm::zip(call.getArgOperands(), argLattices)) + join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); + } + return; + } + + // Check if the lattices can be determined from region control flow. + if (auto branch = dyn_cast(block->getParentOp())) { + return visitRegionSuccessors( + block, branch, block->getParent()->getRegionNumber(), argLattices); + } + + // Otherwise, we can't reason about the data-flow. + return markAllPessimisticFixpoint(argLattices); + } + + // Iterate over the predecessors of the non-entry block. + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); + it != e; ++it) { + Block *predecessor = *it; + + // If the edge from the predecessor block to the current block is not live, + // bail out. + auto *edgeExecutable = + getOrCreate(getProgramPoint(predecessor, block)); + edgeExecutable->blockContentSubscribe(this); + if (!edgeExecutable->isLive()) + continue; + + // Check if we can reason about the data-flow from the predecessor. 
+    if (auto branch =
+            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
+      SuccessorOperands operands =
+          branch.getSuccessorOperands(it.getSuccessorIndex());
+      for (auto &it : llvm::enumerate(argLattices)) {
+        if (Value operand = operands[it.index()]) {
+          join(it.value(), *getLatticeElementFor(block, operand));
+        } else {
+          // Conservatively mark internally produced arguments as having
+          // reached their pessimistic fixpoint.
+          markPessimisticFixpoint(it.value());
+        }
+      }
+    } else {
+      return markAllPessimisticFixpoint(argLattices);
+    }
+  }
+}
+
+void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
+    ProgramPoint point, RegionBranchOpInterface branch,
+    Optional<unsigned> successorIndex, ArrayRef<AbstractLattice *> lattices) {
+  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
+  assert(predecessors->allPredecessorsKnown() &&
+         "unexpected unresolved region successors");
+
+  for (Operation *op : predecessors->getKnownPredecessors()) {
+    // Get the incoming successor operands.
+    Optional<OperandRange> operands;
+    Optional<unsigned> predecessorIndex;
+
+    // Check if the predecessor is the parent op.
+    if (op == branch) {
+      operands = branch.getSuccessorEntryOperands(*successorIndex);
+      predecessorIndex = llvm::None;
+
+      // Otherwise, try to deduce the operands from a region return-like op.
+    } else {
+      assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
+      if (isRegionReturnLike(op)) {
+        operands = getRegionBranchSuccessorOperands(op, successorIndex);
+        predecessorIndex = op->getParentRegion()->getRegionNumber();
+      }
+    }
+
+    if (!operands) {
+      // We can't reason about the data-flow.
+      return markAllPessimisticFixpoint(lattices);
+    }
+
+    ValueRange inputs = predecessors->getSuccessorInputs(op);
+    assert(inputs.size() == operands->size() &&
+           "expected the same number of successor inputs as operands");
+
+    // TODO: This was updated to be exposed upstream.
+    unsigned firstIndex = 0;
+    if (inputs.size() != lattices.size()) {
+      if (inputs.empty()) {
+        markAllPessimisticFixpoint(lattices);
+        return;
+      }
+      firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+      markAllPessimisticFixpoint(lattices.take_front(firstIndex));
+      markAllPessimisticFixpoint(
+          lattices.drop_front(firstIndex + inputs.size()));
+    }
+
+    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
+      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
+  }
+}
+
+const AbstractLattice *
+AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+                                                     Value value) {
+  AbstractLattice *state = getLatticeElement(value);
+  addDependency(state, point);
+  return state;
+}
+
+void AbstractSparseDataFlowAnalysis::markPessimisticFixpoint(
+    AbstractLattice *lattice) {
+  propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+}
+
+void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+    ArrayRef<AbstractLattice *> lattices) {
+  for (AbstractLattice *lattice : lattices) {
+    markPessimisticFixpoint(lattice);
+  }
+}
+
+void AbstractSparseDataFlowAnalysis::join(AbstractLattice *lhs,
+                                          const AbstractLattice &rhs) {
+  propagateIfChanged(lhs, lhs->join(rhs));
+}
+
+void SparseConstantPropagation::visitOperation(
+    Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
+    ArrayRef<Lattice<ConstantValue> *> results) {
+  LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+
+  // Don't try to simulate the results of a region operation as we can't
+  // guarantee that folding will be out-of-place. We don't allow in-place
+  // folds as the desire here is for simulated execution, and not general
+  // folding.
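The transfer function below hinges on how `Operation::fold` reports its outcome: a failed fold, a successful fold that leaves the result vector empty (an in-place fold that may have mutated the operation), or a successful fold that yields one `OpFoldResult` per result, each either a constant `Attribute` or an existing `Value`. A minimal sketch of that three-way distinction, outside this patch, might look as follows; the helper name `classifyFoldOutcome` is hypothetical and exists only for illustration.

#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical helper: enumerates the three outcomes of Operation::fold that
// the transfer function below has to handle.
static void classifyFoldOutcome(Operation *op,
                                ArrayRef<Attribute> constantOperands) {
  SmallVector<OpFoldResult> foldResults;
  foldResults.reserve(op->getNumResults());
  if (failed(op->fold(constantOperands, foldResults))) {
    // Case 1: the operation could not be folded; its results stay unknown.
    llvm::outs() << "fold failed\n";
    return;
  }
  if (foldResults.empty()) {
    // Case 2: in-place fold; `op` itself may have been mutated, so a pure
    // simulation must restore its operands/attributes and give up.
    llvm::outs() << "in-place fold\n";
    return;
  }
  // Case 3: out-of-place fold; one entry per result, each a constant
  // attribute or an already-existing SSA value.
  for (OpFoldResult result : foldResults) {
    if (auto attr = result.dyn_cast<Attribute>())
      llvm::outs() << "constant: " << attr << "\n";
    else
      llvm::outs() << "existing value: " << result.get<Value>() << "\n";
  }
}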
+ if (op->getNumRegions()) + return; + + SmallVector constantOperands; + constantOperands.reserve(op->getNumOperands()); + for (auto *operandLattice : operands) + constantOperands.push_back(operandLattice->getValue().getConstantValue()); + + // Save the original operands and attributes just in case the operation + // folds in-place. The constant passed in may not correspond to the real + // runtime value, so in-place updates are not allowed. + SmallVector originalOperands(op->getOperands()); + DictionaryAttr originalAttrs = op->getAttrDictionary(); + + // Simulate the result of folding this operation to a constant. If folding + // fails or was an in-place fold, mark the results as overdefined. + SmallVector foldResults; + foldResults.reserve(op->getNumResults()); + if (failed(op->fold(constantOperands, foldResults))) { + markAllPessimisticFixpoint(results); + return; + } + + // If the folding was in-place, mark the results as overdefined and reset + // the operation. We don't allow in-place folds as the desire here is for + // simulated execution, and not general folding. + if (foldResults.empty()) { + op->setOperands(originalOperands); + op->setAttrs(originalAttrs); + return; + } + + // Merge the fold results into the lattice for this operation. + assert(foldResults.size() == op->getNumResults() && "invalid result size"); + for (const auto it : llvm::zip(results, foldResults)) { + Lattice *lattice = std::get<0>(it); + + // Merge in the result of the fold, either a constant or a value. + OpFoldResult foldResult = std::get<1>(it); + if (Attribute attr = foldResult.dyn_cast()) { + LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + propagateIfChanged(lattice, + lattice->join(ConstantValue(attr, op->getDialect()))); + } else { + LLVM_DEBUG(llvm::dbgs() + << "Folded to value: " << foldResult.get() << "\n"); + AbstractSparseDataFlowAnalysis::join( + lattice, *getLatticeElement(foldResult.get())); + } + } +} diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -15,151 +15,16 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Analysis/SparseDataFlowAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "sccp" using namespace mlir; -//===----------------------------------------------------------------------===// -// SCCP Analysis -//===----------------------------------------------------------------------===// - -namespace { -struct SCCPLatticeValue { - SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr) - : constant(constant), constantDialect(dialect) {} - - /// The pessimistic state of SCCP is non-constant. - static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) { - return SCCPLatticeValue(); - } - static SCCPLatticeValue getPessimisticValueState(Value value) { - return SCCPLatticeValue(); - } - - /// Equivalence for SCCP only accounts for the constant, not the originating - /// dialect. - bool operator==(const SCCPLatticeValue &rhs) const { - return constant == rhs.constant; - } - - /// To join the state of two values, we simply check for equivalence. 
- static SCCPLatticeValue join(const SCCPLatticeValue &lhs, - const SCCPLatticeValue &rhs) { - return lhs == rhs ? lhs : SCCPLatticeValue(); - } - - /// The constant attribute value. - Attribute constant; - - /// The dialect the constant originated from. This is not used as part of the - /// key, and is only needed to materialize the held constant if necessary. - Dialect *constantDialect; -}; - -struct SCCPAnalysis : public ForwardDataFlowAnalysis { - using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; - ~SCCPAnalysis() override = default; - - ChangeResult - visitOperation(Operation *op, - ArrayRef *> operands) final { - - LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n"); - - // Don't try to simulate the results of a region operation as we can't - // guarantee that folding will be out-of-place. We don't allow in-place - // folds as the desire here is for simulated execution, and not general - // folding. - if (op->getNumRegions()) - return markAllPessimisticFixpoint(op->getResults()); - - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - - // Save the original operands and attributes just in case the operation - // folds in-place. The constant passed in may not correspond to the real - // runtime value, so in-place updates are not allowed. - SmallVector originalOperands(op->getOperands()); - DictionaryAttr originalAttrs = op->getAttrDictionary(); - - // Simulate the result of folding this operation to a constant. If folding - // fails or was an in-place fold, mark the results as overdefined. - SmallVector foldResults; - foldResults.reserve(op->getNumResults()); - if (failed(op->fold(constantOperands, foldResults))) - return markAllPessimisticFixpoint(op->getResults()); - - // If the folding was in-place, mark the results as overdefined and reset - // the operation. We don't allow in-place folds as the desire here is for - // simulated execution, and not general folding. - if (foldResults.empty()) { - op->setOperands(originalOperands); - op->setAttrs(originalAttrs); - return markAllPessimisticFixpoint(op->getResults()); - } - - // Merge the fold results into the lattice for this operation. - assert(foldResults.size() == op->getNumResults() && "invalid result size"); - Dialect *dialect = op->getDialect(); - ChangeResult result = ChangeResult::NoChange; - for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { - LatticeElement &lattice = - getLatticeElement(op->getResult(i)); - - // Merge in the result of the fold, either a constant or a value. - OpFoldResult foldResult = foldResults[i]; - if (Attribute attr = foldResult.dyn_cast()) - result |= lattice.join(SCCPLatticeValue(attr, dialect)); - else - result |= lattice.join(getLatticeElement(foldResult.get())); - } - return result; - } - - /// Implementation of `getSuccessorsForOperands` that uses constant operands - /// to potentially remove dead successors. - LogicalResult getSuccessorsForOperands( - BranchOpInterface branch, - ArrayRef *> operands, - SmallVectorImpl &successors) final { - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { - successors.push_back(singleSucc); - return success(); - } - return failure(); - } - - /// Implementation of `getSuccessorsForOperands` that uses constant operands - /// to potentially remove dead region successors. 
- void getSuccessorsForOperands( - RegionBranchOpInterface branch, Optional sourceIndex, - ArrayRef *> operands, - SmallVectorImpl &successors) final { - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - branch.getSuccessorRegions(sourceIndex, constantOperands, successors); - } -}; -} // namespace - //===----------------------------------------------------------------------===// // SCCP Rewrites //===----------------------------------------------------------------------===// @@ -167,21 +32,21 @@ /// Replace the given value with a constant if the corresponding lattice /// represents a constant. Returns success if the value was replaced, failure /// otherwise. -static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, +static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value) { - LatticeElement *lattice = - analysis.lookupLatticeElement(value); + auto *lattice = solver.lookupState>(value); if (!lattice) return failure(); - SCCPLatticeValue &latticeValue = lattice->getValue(); - if (!latticeValue.constant) + const ConstantValue &latticeValue = lattice->getValue(); + if (!latticeValue.getConstantValue()) return failure(); // Attempt to materialize a constant for the given value. - Dialect *dialect = latticeValue.constantDialect; - Value constant = folder.getOrCreateConstant( - builder, dialect, latticeValue.constant, value.getType(), value.getLoc()); + Dialect *dialect = latticeValue.getConstantDialect(); + Value constant = folder.getOrCreateConstant(builder, dialect, + latticeValue.getConstantValue(), + value.getType(), value.getLoc()); if (!constant) return failure(); @@ -192,7 +57,7 @@ /// Rewrite the given regions using the computing analysis. This replaces the /// uses of all values that have been computed to be constant, and erases as /// many newly dead operations. -static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, +static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef initialRegions) { SmallVector worklist; auto addToWorklist = [&](MutableArrayRef regions) { @@ -216,7 +81,7 @@ bool replacedAll = op.getNumResults() != 0; for (Value res : op.getResults()) replacedAll &= - succeeded(replaceWithConstant(analysis, builder, folder, res)); + succeeded(replaceWithConstant(solver, builder, folder, res)); // If all of the results of the operation were replaced, try to erase // the operation completely. @@ -233,7 +98,7 @@ // Replace any block arguments with constants. 
builder.setInsertionPointToStart(block); for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(analysis, builder, folder, arg); + (void)replaceWithConstant(solver, builder, folder, arg); } } @@ -250,9 +115,14 @@ void SCCP::runOnOperation() { Operation *op = getOperation(); - SCCPAnalysis analysis(op->getContext()); - analysis.run(op); - rewrite(analysis, op->getContext(), op->getRegions()); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) { + op->emitError("SCCP analysis failed\n"); + return signalPassFailure(); + } + rewrite(solver, op->getContext(), op->getRegions()); } std::unique_ptr mlir::createSCCPPass() { diff --git a/mlir/test/Analysis/test-dead-code-analysis.mlir b/mlir/test/Analysis/test-dead-code-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-dead-code-analysis.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt -test-dead-code-analysis 2>&1 %s | FileCheck %s + +// CHECK: test_cfg: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: ^bb1 = live +// CHECK: from ^bb1 = live +// CHECK: from ^bb0 = live +// CHECK: ^bb2 = live +// CHECK: from ^bb1 = live +func.func @test_cfg(%cond: i1) -> () + attributes {tag = "test_cfg"} { + cf.br ^bb1 + +^bb1: + cf.cond_br %cond, ^bb1, ^bb2 + +^bb2: + return +} + +func.func @test_region_control_flow(%cond: i1, %arg0: i64, %arg1: i64) -> () { + // CHECK: test_if: + // CHECK: region #0 + // CHECK: region_preds: (all) predecessors: + // CHECK: scf.if + // CHECK: region #1 + // CHECK: region_preds: (all) predecessors: + // CHECK: scf.if + // CHECK: op_preds: (all) predecessors: + // CHECK: scf.yield {then} + // CHECK: scf.yield {else} + scf.if %cond { + scf.yield {then} + } else { + scf.yield {else} + } {tag = "test_if"} + + // test_while: + // region #0 + // region_preds: (all) predecessors: + // scf.while + // scf.yield + // region #1 + // region_preds: (all) predecessors: + // scf.condition + // op_preds: (all) predecessors: + // scf.condition + %c2_i64 = arith.constant 2 : i64 + %0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) { + %1 = arith.cmpi slt, %arg2, %arg1 : i64 + scf.condition(%1) %arg2, %arg2 : i64, i64 + } do { + ^bb0(%arg2: i64, %arg3: i64): + %1 = arith.muli %arg3, %c2_i64 : i64 + scf.yield %1 : i64 + } attributes {tag = "test_while"} + + return +} + +// CHECK: foo: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: (all) predecessors: +// CHECK: func.call @foo(%{{.*}}) {tag = "a"} +// CHECK: func.call @foo(%{{.*}}) {tag = "b"} +func.func private @foo(%arg0: i32) -> i32 + attributes {tag = "foo"} { + return {a} %arg0 : i32 +} + +// CHECK: bar: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: predecessors: +// CHECK: func.call @bar(%{{.*}}) {tag = "c"} +func.func @bar(%cond: i1) -> i32 + attributes {tag = "bar"} { + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + %c0 = arith.constant 0 : i32 + return {b} %c0 : i32 + +^bb2: + %c1 = arith.constant 1 : i32 + return {c} %c1 : i32 +} + +// CHECK: baz +// CHECK: op_preds: (all) predecessors: +func.func private @baz(i32) -> i32 attributes {tag = "baz"} + +func.func @test_callgraph(%cond: i1, %arg0: i32) -> i32 { + // CHECK: a: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {a} + %0 = func.call @foo(%arg0) {tag = "a"} : (i32) -> i32 + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: b: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {a} + %1 = func.call @foo(%arg0) {tag = "b"} : (i32) -> i32 + return %1 : i32 + +^bb2: + // CHECK: 
c: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {b} + // CHECK: func.return {c} + %2 = func.call @bar(%cond) {tag = "c"} : (i1) -> i32 + // CHECK: d: + // CHECK: op_preds: predecessors: + %3 = func.call @baz(%arg0) {tag = "d"} : (i32) -> i32 + return %2 : i32 +} + +// CHECK: test_unknown_branch: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: ^bb1 = live +// CHECK: from ^bb0 = live +// CHECK: ^bb2 = live +// CHECK: from ^bb0 = live +func.func @test_unknown_branch() -> () + attributes {tag = "test_unknown_branch"} { + "test.unknown_br"() [^bb1, ^bb2] : () -> () + +^bb1: + return + +^bb2: + return +} + +// CHECK: test_unknown_region: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: region #1 +// CHECK: ^bb0 = live +func.func @test_unknown_region() -> () { + "test.unknown_region_br"() ({ + ^bb0: + "test.unknown_region_end"() : () -> () + }, { + ^bb0: + "test.unknown_region_end"() : () -> () + }) {tag = "test_unknown_region"} : () -> () + return +} + +// CHECK: test_known_dead_block: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: ^bb1 = live +// CHECK: ^bb2 = dead +func.func @test_known_dead_block() -> () + attributes {tag = "test_known_dead_block"} { + %true = arith.constant true + cf.cond_br %true, ^bb1, ^bb2 + +^bb1: + return + +^bb2: + return +} + +// CHECK: test_known_dead_edge: +// CHECK: ^bb2 = live +// CHECK: from ^bb1 = dead +// CHECK: from ^bb0 = live +func.func @test_known_dead_edge(%arg0: i1) -> () + attributes {tag = "test_known_dead_edge"} { + cf.cond_br %arg0, ^bb1, ^bb2 + +^bb1: + %true = arith.constant true + cf.cond_br %true, ^bb3, ^bb2 + +^bb2: + return + +^bb3: + return +} + +func.func @test_known_region_predecessors() -> () { + %false = arith.constant false + // CHECK: test_known_if: + // CHECK: region #0 + // CHECK: ^bb0 = dead + // CHECK: region #1 + // CHECK: ^bb0 = live + // CHECK: region_preds: (all) predecessors: + // CHECK: scf.if + // CHECK: op_preds: (all) predecessors: + // CHECK: scf.yield {else} + scf.if %false { + scf.yield {then} + } else { + scf.yield {else} + } {tag = "test_known_if"} + return +} + +// CHECK: callable: +// CHECK: region #0 +// CHECK: ^bb0 = live +// CHECK: op_preds: predecessors: +// CHECK: func.call @callable() {then} +func.func @callable() attributes {tag = "callable"} { + return +} + +func.func @test_dead_callsite() -> () { + %true = arith.constant true + scf.if %true { + func.call @callable() {then} : () -> () + scf.yield + } else { + func.call @callable() {else} : () -> () + scf.yield + } + return +} + +func.func private @test_dead_return(%arg0: i32) -> i32 { + %true = arith.constant true + cf.cond_br %true, ^bb1, ^bb1 + +^bb1: + return {true} %arg0 : i32 + +^bb2: + return {false} %arg0 : i32 +} + +func.func @test_call_dead_return(%arg0: i32) -> () { + // CHECK: test_dead_return: + // CHECK: op_preds: (all) predecessors: + // CHECK: func.return {true} + %0 = func.call @test_dead_return(%arg0) {tag = "test_dead_return"} : (i32) -> i32 + return +} diff --git a/mlir/test/Analysis/test-last-modified-callgraph.mlir b/mlir/test/Analysis/test-last-modified-callgraph.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-last-modified-callgraph.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s + +// CHECK-LABEL: test_tag: test_callsite +// CHECK: operand #0 +// CHECK-NEXT: - a +func.func private @single_callsite_fn(%ptr: memref) -> memref { + return {tag = "test_callsite"} %ptr : memref +} + +func.func @test_callsite() { + %ptr = memref.alloc() : memref + 
%c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + %0 = func.call @single_callsite_fn(%ptr) : (memref) -> memref + return +} + +// CHECK-LABEL: test_tag: test_return_site +// CHECK: operand #0 +// CHECK-NEXT: - b +func.func private @single_return_site_fn(%ptr: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "b"} : memref + return %ptr : memref +} + +// CHECK-LABEL: test_tag: test_multiple_callsites +// CHECK: operand #0 +// CHECK-NEXT: write0 +// CHECK-NEXT: write1 +func.func @test_return_site(%ptr: memref) -> memref { + %0 = func.call @single_return_site_fn(%ptr) : (memref) -> memref + return {tag = "test_return_site"} %0 : memref +} + +func.func private @multiple_callsite_fn(%ptr: memref) -> memref { + return {tag = "test_multiple_callsites"} %ptr : memref +} + +func.func @test_multiple_callsites(%a: i32, %ptr: memref) -> memref { + memref.store %a, %ptr[] {tag_name = "write0"} : memref + %0 = func.call @multiple_callsite_fn(%ptr) : (memref) -> memref + memref.store %a, %ptr[] {tag_name = "write1"} : memref + %1 = func.call @multiple_callsite_fn(%ptr) : (memref) -> memref + return %ptr : memref +} + +// CHECK-LABEL: test_tag: test_multiple_return_sites +// CHECK: operand #0 +// CHECK-NEXT: return0 +// CHECK-NEXT: return1 +func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref) -> memref { + cf.cond_br %cond, ^a, ^b + +^a: + memref.store %a, %ptr[] {tag_name = "return0"} : memref + return %ptr : memref + +^b: + memref.store %a, %ptr[] {tag_name = "return1"} : memref + return %ptr : memref +} + +func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref) -> memref { + %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref) -> memref + return {tag = "test_multiple_return_sites"} %0 : memref +} \ No newline at end of file diff --git a/mlir/test/Analysis/test-last-modified.mlir b/mlir/test/Analysis/test-last-modified.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-last-modified.mlir @@ -0,0 +1,115 @@ +// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s + +// CHECK-LABEL: test_tag: test_simple_mod +// CHECK: operand #0 +// CHECK-NEXT: - a +// CHECK: operand #1 +// CHECK-NEXT: - b +func.func @test_simple_mod(%arg0: memref, %arg1: memref) -> (memref, memref) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + memref.store %c0, %arg0[] {tag_name = "a"} : memref + memref.store %c1, %arg1[] {tag_name = "b"} : memref + return {tag = "test_simple_mod"} %arg0, %arg1 : memref, memref +} + +// CHECK-LABEL: test_tag: test_simple_mod_overwrite_a +// CHECK: operand #1 +// CHECK-NEXT: - a +// CHECK-LABEL: test_tag: test_simple_mod_overwrite_b +// CHECK: operand #0 +// CHECK-NEXT: - b +func.func @test_simple_mod_overwrite(%arg0: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %arg0[] {tag = "test_simple_mod_overwrite_a", tag_name = "a"} : memref + %c1 = arith.constant 1 : i32 + memref.store %c1, %arg0[] {tag_name = "b"} : memref + return {tag = "test_simple_mod_overwrite_b"} %arg0 : memref +} + +// CHECK-LABEL: test_tag: test_mod_control_flow +// CHECK: operand #0 +// CHECK-NEXT: - b +// CHECK-NEXT: - a +func.func @test_mod_control_flow(%cond: i1, %ptr: memref) -> memref { + cf.cond_br %cond, ^a, ^b + +^a: + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + cf.br ^c + +^b: + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "b"} : memref + cf.br ^c + +^c: + return 
{tag = "test_mod_control_flow"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_dead_branch +// CHECK: operand #0 +// CHECK-NEXT: - a +func.func @test_mod_dead_branch(%arg: i32, %ptr: memref) -> memref { + %0 = arith.subi %arg, %arg : i32 + %1 = arith.constant -1 : i32 + %2 = arith.cmpi sgt, %0, %1 : i32 + cf.cond_br %2, ^a, ^b + +^a: + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + cf.br ^c + +^b: + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "b"} : memref + cf.br ^c + +^c: + return {tag = "test_mod_dead_branch"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_region_control_flow +// CHECK: operand #0 +// CHECK-NEXT: then +// CHECK-NEXT: else +func.func @test_mod_region_control_flow(%cond: i1, %ptr: memref) -> memref { + scf.if %cond { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "then"}: memref + } else { + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "else"} : memref + } + return {tag = "test_mod_region_control_flow"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_dead_region +// CHECK: operand #0 +// CHECK-NEXT: else +func.func @test_mod_dead_region(%ptr: memref) -> memref { + %false = arith.constant false + scf.if %false { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "then"}: memref + } else { + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "else"} : memref + } + return {tag = "test_mod_dead_region"} %ptr : memref +} + +// CHECK-LABEL: test_tag: unknown_memory_effects_a +// CHECK: operand #1 +// CHECK-NEXT: - a +// CHECK-LABEL: test_tag: unknown_memory_effects_b +// CHECK: operand #0 +// CHECK-NEXT: - +func.func @unknown_memory_effects(%ptr: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag = "unknown_memory_effects_a", tag_name = "a"} : memref + "test.unknown_effects"() : () -> () + return {tag = "unknown_memory_effects_b"} %ptr : memref +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -4,6 +4,8 @@ TestCallGraph.cpp TestDataFlow.cpp TestDataFlowFramework.cpp + TestDeadCodeAnalysis.cpp + TestDenseDataFlowAnalysis.cpp TestLiveness.cpp TestMatchReduction.cpp TestMemRefBoundCheck.cpp diff --git a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp @@ -0,0 +1,116 @@ +//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +/// Print the liveness of every block, control-flow edge, and the predecessors +/// of all regions, callables, and calls. 
+static void printAnalysisResults(DataFlowSolver &solver, Operation *op, + raw_ostream &os) { + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag) + return; + os << tag.getValue() << ":\n"; + for (Region ®ion : op->getRegions()) { + os << " region #" << region.getRegionNumber() << "\n"; + for (Block &block : region) { + os << " "; + block.printAsOperand(os); + os << " = "; + auto *live = solver.lookupState(&block); + if (live) + os << *live; + else + os << "dead"; + os << "\n"; + for (Block *pred : block.getPredecessors()) { + os << " from "; + pred->printAsOperand(os); + os << " = "; + auto *live = solver.lookupState( + solver.getProgramPoint(pred, &block)); + if (live) + os << *live; + else + os << "dead"; + os << "\n"; + } + } + if (!region.empty()) { + auto *preds = solver.lookupState(®ion.front()); + if (preds) + os << "region_preds: " << *preds << "\n"; + } + } + auto *preds = solver.lookupState(op); + if (preds) + os << "op_preds: " << *preds << "\n"; + }); +} + +namespace { +/// This is a simple analysis that implements a transfer function for constant +/// operations. +struct ConstantAnalysis : public DataFlowAnalysis { + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { + if (op->hasTrait()) + if (failed(visit(op))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + + LogicalResult visit(ProgramPoint point) override { + Operation *op = point.get(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>(op->getResult(0)); + propagateIfChanged( + constant, constant->join(ConstantValue(value, op->getDialect()))); + } + return success(); + } +}; + +/// This is a simple pass that runs dead code analysis with no constant value +/// provider. It marks everything as live. +struct TestDeadCodeAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass) + + StringRef getArgument() const override { return "test-dead-code-analysis"; } + + void runOnOperation() override { + Operation *op = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + printAnalysisResults(solver, op, llvm::errs()); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestDeadCodeAnalysisPass() { + PassRegistration(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp @@ -0,0 +1,274 @@ +//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DenseDataFlowAnalysis.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// This lattice represents a single underlying value for an SSA value. 
+class UnderlyingValue { +public: + /// The pessimistic underlying value of a value is itself. + static UnderlyingValue getPessimisticValueState(Value value) { + return {value}; + } + + /// Create an underlying value state with a known underlying value. + UnderlyingValue(Value underlyingValue = {}) + : underlyingValue(underlyingValue) {} + + /// Returns the underlying value. + Value getUnderlyingValue() const { return underlyingValue; } + + /// Join two underlying values. If there are conflicting underlying values, + /// go to the pessimistic value. + static UnderlyingValue join(const UnderlyingValue &lhs, + const UnderlyingValue &rhs) { + return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue(); + } + + /// Compare underlying values. + bool operator==(const UnderlyingValue &rhs) const { + return underlyingValue == rhs.underlyingValue; + } + + void print(raw_ostream &os) const { os << underlyingValue; } + +private: + Value underlyingValue; +}; + +/// This lattice represents, for a given memory resource, the potential last +/// operations that modified the resource. +class LastModification : public AbstractDenseLattice { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) + + using AbstractDenseLattice::AbstractDenseLattice; + + /// The lattice is always initialized. + bool isUninitialized() const override { return false; } + + /// Initialize the lattice. Does nothing. + ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + + /// Mark the lattice as having reached its pessimistic fixpoint. That is, the + /// last modifications of all memory resources are unknown. + ChangeResult reset() override { + if (lastMods.empty()) + return ChangeResult::NoChange; + lastMods.clear(); + return ChangeResult::Change; + } + + /// The lattice is never at a fixpoint. + bool isAtFixpoint() const override { return false; } + + /// Join the last modifications. + ChangeResult join(const AbstractDenseLattice &lattice) override { + const auto &rhs = static_cast(lattice); + ChangeResult result = ChangeResult::NoChange; + for (const auto &mod : rhs.lastMods) { + auto &lhsMod = lastMods[mod.first]; + if (lhsMod != mod.second) { + lhsMod.insert(mod.second.begin(), mod.second.end()); + result |= ChangeResult::Change; + } + } + return result; + } + + /// Set the last modification of a value. + ChangeResult set(Value value, Operation *op) { + auto &lastMod = lastMods[value]; + ChangeResult result = ChangeResult::NoChange; + if (lastMod.size() != 1 || *lastMod.begin() != op) { + result = ChangeResult::Change; + lastMod.clear(); + lastMod.insert(op); + } + return result; + } + + /// Get the last modifications of a value. Returns none if the last + /// modifications are not known. + Optional> getLastModifiers(Value value) const { + auto it = lastMods.find(value); + if (it == lastMods.end()) + return {}; + return it->second.getArrayRef(); + } + + void print(raw_ostream &os) const override { + for (const auto &lastMod : lastMods) { + os << lastMod.first << ":\n"; + for (Operation *op : lastMod.second) + os << " " << *op << "\n"; + } + } + +private: + /// The potential last modifications of a memory resource. Use a set vector to + /// keep the results deterministic. + DenseMap, + SmallPtrSet>> + lastMods; +}; + +class LastModifiedAnalysis : public DenseDataFlowAnalysis { +public: + using DenseDataFlowAnalysis::DenseDataFlowAnalysis; + + /// Visit an operation. If the operation has no memory effects, then the state + /// is propagated with no change. 
If the operation allocates a resource, then + /// its reaching definitions is set to empty. If the operation writes to a + /// resource, then its reaching definition is set to the written value. + void visitOperation(Operation *op, const LastModification &before, + LastModification *after) override; +}; + +/// Define the lattice class explicitly to provide a type ID. +struct UnderlyingValueLattice : public Lattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) + using Lattice::Lattice; +}; + +/// An analysis that uses forwarding of values along control-flow and callgraph +/// edges to determine single underlying values for block arguments. This +/// analysis exists so that the test analysis and pass can test the behaviour of +/// the dense data-flow analysis on the callgraph. +class UnderlyingValueAnalysis + : public SparseDataFlowAnalysis { +public: + using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + + /// The underlying value of the results of an operation are not known. + void visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override { + markAllPessimisticFixpoint(results); + } +}; +} // end anonymous namespace + +/// Look for the most underlying value of a value. +static Value getMostUnderlyingValue( + Value value, + function_ref getUnderlyingValueFn) { + const UnderlyingValueLattice *underlying; + do { + underlying = getUnderlyingValueFn(value); + if (!underlying || underlying->isUninitialized()) + return {}; + Value underlyingValue = underlying->getValue().getUnderlyingValue(); + if (underlyingValue == value) + break; + value = underlyingValue; + } while (true); + return value; +} + +void LastModifiedAnalysis::visitOperation(Operation *op, + const LastModification &before, + LastModification *after) { + auto memory = dyn_cast(op); + // If we can't reason about the memory effects, then conservatively assume we + // can't deduce anything about the last modifications. + if (!memory) + return reset(after); + + SmallVector effects; + memory.getEffects(effects); + + ChangeResult result = after->join(before); + for (const auto &effect : effects) { + Value value = effect.getValue(); + + // If we see an effect on anything other than a value, assume we can't + // deduce anything about the last modifications. + if (!value) + return reset(after); + + value = getMostUnderlyingValue(value, [&](Value value) { + return getOrCreateFor(op, value); + }); + if (!value) + return; + + // Nothing to do for reads. 
+ if (isa(effect.getEffect())) + continue; + + result |= after->set(value, op); + } + propagateIfChanged(after, result); +} + +namespace { +struct TestLastModifiedPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) + + StringRef getArgument() const override { return "test-last-modified"; } + + void runOnOperation() override { + Operation *op = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + raw_ostream &os = llvm::errs(); + + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag) + return; + os << "test_tag: " << tag.getValue() << ":\n"; + const LastModification *lastMods = + solver.lookupState(op); + assert(lastMods && "expected a dense lattice"); + for (auto &it : llvm::enumerate(op->getOperands())) { + os << " operand #" << it.index() << "\n"; + Value value = getMostUnderlyingValue(it.value(), [&](Value value) { + return solver.lookupState(value); + }); + assert(value && "expected an underlying value"); + if (Optional> lastMod = + lastMods->getLastModifiers(value)) { + for (Operation *lastModifier : *lastMod) { + if (auto tagName = + lastModifier->getAttrOfType("tag_name")) { + os << " - " << tagName.getValue() << "\n"; + } else { + os << " - " << lastModifier->getName() << "\n"; + } + } + } else { + os << " - \n"; + } + } + }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestLastModifiedPass() { + PassRegistration(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -72,6 +72,7 @@ void registerTestGpuSerializeToHsacoPass(); void registerTestDataFlowPass(); void registerTestDataLayoutQuery(); +void registerTestDeadCodeAnalysisPass(); void registerTestDecomposeCallGraphTypes(); void registerTestDiagnosticsPass(); void registerTestDominancePass(); @@ -84,6 +85,7 @@ void registerTestGenericIRVisitorsPass(); void registerTestGenericIRVisitorsInterruptPass(); void registerTestInterfaces(); +void registerTestLastModifiedPass(); void registerTestLinalgCodegenStrategy(); void registerTestLinalgElementwiseFusion(); void registerTestLinalgFusionTransforms(); @@ -171,6 +173,7 @@ mlir::test::registerTestDecomposeCallGraphTypes(); mlir::test::registerTestDataFlowPass(); mlir::test::registerTestDataLayoutQuery(); + mlir::test::registerTestDeadCodeAnalysisPass(); mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); mlir::test::registerTestExpandTanhPass(); @@ -180,6 +183,7 @@ mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); + mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgCodegenStrategy(); mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgFusionTransforms();
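The pieces added in this patch are assembled the same way in the rewritten SCCP pass and in the test passes: load the analyses into a `DataFlowSolver`, run it over the root operation, then query the per-value or per-point states. A minimal usage sketch is shown below; it reuses the analyses and state types introduced above (`DeadCodeAnalysis`, `SparseConstantPropagation`, `Lattice<ConstantValue>`), and the free-standing helper `lookupConstant` is hypothetical rather than part of the patch.

#include "mlir/Analysis/SparseDataFlowAnalysis.h"

using namespace mlir;

// Hypothetical helper mirroring the patched SCCP::runOnOperation: runs the
// solver over `top` and returns the constant attribute computed for `value`,
// or a null attribute if the value is not known to be constant.
static Attribute lookupConstant(Operation *top, Value value) {
  DataFlowSolver solver;
  solver.load<DeadCodeAnalysis>();
  solver.load<SparseConstantPropagation>();
  if (failed(solver.initializeAndRun(top)))
    return {};
  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
  if (!lattice)
    return {};
  return lattice->getValue().getConstantValue();
}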