diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h @@ -29,7 +29,7 @@ class ConstantValue { public: /// The pessimistic value state of the constant value is unknown. - static ConstantValue getPessimisticValueState(Value value) { return {}; } + static ConstantValue getPessimisticValue(Value value) { return {}; } /// Construct a constant value with a known constant. ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr) @@ -53,6 +53,17 @@ return lhs == rhs ? lhs : ConstantValue(); } + static ConstantValue meet(const ConstantValue &lhs, + const ConstantValue &rhs) { + if (lhs == rhs) + return lhs; + if (!lhs.constant) + return rhs; + if (!rhs.constant) + return lhs; + return ConstantValue(); + } + /// Print the constant value. void print(raw_ostream &os) const; @@ -63,6 +74,12 @@ Dialect *dialect; }; +class ConstantValueState : public OptimisticSparseState { +public: + using OptimisticSparseState::OptimisticSparseState; + using ElementT = SparseElement; +}; + //===----------------------------------------------------------------------===// // SparseConstantPropagation //===----------------------------------------------------------------------===// @@ -72,13 +89,13 @@ /// operands, by speculatively folding operations. When combined with dead-code /// analysis, this becomes sparse conditional constant propagation (SCCP). 
class SparseConstantPropagation - : public SparseDataFlowAnalysis> { + : public SparseDataFlowAnalysis { public: using SparseDataFlowAnalysis::SparseDataFlowAnalysis; - void visitOperation(Operation *op, - ArrayRef *> operands, - ArrayRef *> results) override; + void + visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; }; } // end namespace dataflow diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -24,21 +24,77 @@ namespace mlir { namespace dataflow { +//===----------------------------------------------------------------------===// +// CFGEdge +//===----------------------------------------------------------------------===// + +/// This program point represents a control-flow edge between a block and one +/// of its successors. +class CFGEdge + : public GenericProgramPointBase> { +public: + using Base::Base; + + /// Get the block from which the edge originates. + Block *getFrom() const { return getValue().first; } + /// Get the target block. + Block *getTo() const { return getValue().second; } + + /// Print the blocks between the control-flow edge. + void print(raw_ostream &os) const override; + /// Get a fused location of both blocks. + Location getLoc() const override; +}; + //===----------------------------------------------------------------------===// // Executable //===----------------------------------------------------------------------===// /// This is a simple analysis state that represents whether the associated /// program point (either a block or a control-flow edge) is live. -class Executable : public AnalysisState { +class Executable : public AbstractState { public: - using AnalysisState::AnalysisState; - - /// The state is initialized by default. 
- bool isUninitialized() const override { return false; } - - /// The state is always initialized. - ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + template + class Element : public SingleStateElement { + public: + using SingleStateElement::SingleStateElement; + + /// When the state of the program point is changed to live, re-invoke + /// subscribed analyses on the operations in the block and on the block + /// itself. + void onUpdate() override { + if (auto *block = this->point.template dyn_cast()) { + // Re-invoke the analyses on the block itself. + for (DataFlowAnalysis *analysis : subscribers) + this->solver.enqueue({block, analysis}); + // Re-invoke the analyses on all operations in the block. + for (DataFlowAnalysis *analysis : subscribers) + for (Operation &op : *block) + this->solver.enqueue({&op, analysis}); + } else if (auto *programPoint = + this->point.template dyn_cast()) { + // Re-invoke the analysis on the successor block. + if (auto *edge = dyn_cast(programPoint)) + for (DataFlowAnalysis *analysis : subscribers) + this->solver.enqueue({edge->getTo(), analysis}); + } + } + + /// Subscribe an analysis to changes to the liveness. + void blockContentSubscribe(DataFlowAnalysis *analysis) { + subscribers.insert(analysis); + } + + private: + /// A set of analyses that should be updated when this state changes. + SetVector, + SmallPtrSet> + subscribers; + }; + using ElementT = Element; + + /// Optimistically assume the program point is dead. + explicit Executable(ProgramPoint point) : live(false) {} /// Set the state of the program point to live. ChangeResult setToLive(); @@ -46,28 +102,12 @@ /// Get whether the program point is live. bool isLive() const { return live; } - /// Print the liveness. + /// Print the liveness; void print(raw_ostream &os) const override; - /// When the state of the program point is changed to live, re-invoke - /// subscribed analyses on the operations in the block and on the block - /// itself. 
- void onUpdate(DataFlowSolver *solver) const override; - - /// Subscribe an analysis to changes to the liveness. - void blockContentSubscribe(DataFlowAnalysis *analysis) { - subscribers.insert(analysis); - } - private: - /// Whether the program point is live. Optimistically assume that the program - /// point is dead. - bool live = false; - - /// A set of analyses that should be updated when this state changes. - SetVector, - SmallPtrSet> - subscribers; + /// Whether the program point is live. + bool live; }; //===----------------------------------------------------------------------===// @@ -90,15 +130,11 @@ /// /// The state can indicate that it is underdefined, meaning that not all live /// control-flow predecessors can be known. -class PredecessorState : public AnalysisState { +class PredecessorState : public AbstractState { public: - using AnalysisState::AnalysisState; + using ElementT = SingleStateElement; - /// The state is initialized by default. - bool isUninitialized() const override { return false; } - - /// The state is always initialized. - ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + explicit PredecessorState(ProgramPoint point) {} /// Print the known predecessors. void print(raw_ostream &os) const override; @@ -142,28 +178,6 @@ DenseMap successorInputs; }; -//===----------------------------------------------------------------------===// -// CFGEdge -//===----------------------------------------------------------------------===// - -/// This program point represents a control-flow edge between a block and one -/// of its successors. -class CFGEdge - : public GenericProgramPointBase> { -public: - using Base::Base; - - /// Get the block from which the edge originates. - Block *getFrom() const { return getValue().first; } - /// Get the target block. - Block *getTo() const { return getValue().second; } - - /// Print the blocks between the control-flow edge. 
- void print(raw_ostream &os) const override; - /// Get a fused location of both blocks. - Location getLoc() const override; -}; - //===----------------------------------------------------------------------===// // DeadCodeAnalysis //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -16,33 +16,26 @@ #define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" namespace mlir { namespace dataflow { //===----------------------------------------------------------------------===// -// AbstractDenseLattice +// AbstractDenseState //===----------------------------------------------------------------------===// /// This class represents a dense lattice. A dense lattice is attached to /// operations to represent the program state after their execution or to blocks /// to represent the program state at the beginning of the block. A dense /// lattice is propagated through the IR by dense data-flow analysis. -class AbstractDenseLattice : public AnalysisState { -public: - /// A dense lattice can only be created for operations and blocks. - using AnalysisState::AnalysisState; - - /// Join the lattice across control-flow or callgraph edges. - virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0; +using AbstractDenseState = AbstractSparseState; - /// Reset the dense lattice to a pessimistic value. This occurs when the - /// analysis cannot reason about the data-flow. - virtual ChangeResult reset() = 0; +class AbstractDenseElement : public AbstractElement { +public: + using AbstractElement::AbstractElement; - /// Returns true if the lattice state has reached a pessimistic fixpoint. 
That - /// is, no further modifications to the lattice can occur. - virtual bool isAtFixpoint() const = 0; + virtual const AbstractDenseState *get() const override = 0; }; //===----------------------------------------------------------------------===// @@ -78,26 +71,29 @@ /// Propagate the dense lattice before the execution of an operation to the /// lattice after its execution. virtual void visitOperationImpl(Operation *op, - const AbstractDenseLattice &before, - AbstractDenseLattice *after) = 0; + const AbstractDenseState &before, + AbstractDenseElement *after) = 0; - /// Get the dense lattice after the execution of the given program point. - virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; + /// Get the dense element after the execution of the given program point. + virtual AbstractDenseElement *getLattice(ProgramPoint point) = 0; /// Get the dense lattice after the execution of the given program point and /// add it as a dependency to a program point. - const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, - ProgramPoint point); - - /// Mark the dense lattice as having reached its pessimistic fixpoint and - /// propagate an update if it changed. - void reset(AbstractDenseLattice *lattice) { - propagateIfChanged(lattice, lattice->reset()); + const AbstractDenseState *getLatticeFor(ProgramPoint dependent, + ProgramPoint point); + + void update(AbstractDenseElement *element, + function_ref updateFn) { + element->update(this, [updateFn](AbstractState *state) { + return updateFn(static_cast(state)); + }); } - /// Join a lattice with another and propagate an update if it changed. - void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) { - propagateIfChanged(lhs, lhs->join(rhs)); + void markPessimisticFixpoint(AbstractDenseElement *element) { + element->update(this, [](AbstractState *state) { + return static_cast(state) + ->markPessimisticFixpoint(); + }); } private: @@ -116,7 +112,7 @@ /// parent operation itself. 
void visitRegionBranchOperation(ProgramPoint point, RegionBranchOpInterface branch, - AbstractDenseLattice *after); + AbstractDenseElement *after); }; //===----------------------------------------------------------------------===// @@ -128,29 +124,29 @@ /// transfer functions for operations. /// /// `StateT` is expected to be a subclass of `AbstractDenseLattice`. -template +template class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis { public: using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis; /// Visit an operation with the dense lattice before its execution. This /// function is expected to set the dense lattice after its execution. - virtual void visitOperation(Operation *op, const LatticeT &before, - LatticeT *after) = 0; + virtual void visitOperation(Operation *op, const StateT &before, + typename StateT::ElementT *after) = 0; protected: /// Get the dense lattice after this program point. - LatticeT *getLattice(ProgramPoint point) override { - return getOrCreate(point); + typename StateT::ElementT *getLattice(ProgramPoint point) override { + return getOrCreate(point); } private: /// Type-erased wrappers that convert the abstract dense lattice to a derived /// lattice and invoke the virtual hooks operating on the derived lattice. 
- void visitOperationImpl(Operation *op, const AbstractDenseLattice &before, - AbstractDenseLattice *after) override { - visitOperation(op, static_cast(before), - static_cast(after)); + void visitOperationImpl(Operation *op, const AbstractDenseState &before, + AbstractDenseElement *after) override { + visitOperation(op, static_cast(before), + static_cast(after)); } }; diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -15,6 +15,7 @@ #ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H #define MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Interfaces/InferIntRangeInterface.h" @@ -27,7 +28,7 @@ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) /// range that is used to mark the value as unable to be analyzed further, /// where `t` is the type of `value`. - static IntegerValueRange getPessimisticValueState(Value value); + static IntegerValueRange getPessimisticValue(Value value); /// Create an integer value range lattice value. IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} @@ -45,6 +46,10 @@ const IntegerValueRange &rhs) { return lhs.value.rangeUnion(rhs.value); } + static IntegerValueRange meet(const IntegerValueRange &lhs, + const IntegerValueRange &rhs) { + return lhs.value.intersection(rhs.value); + } /// Print the integer value range. void print(raw_ostream &os) const { os << value; } @@ -57,38 +62,48 @@ /// This lattice element represents the integer value range of an SSA value. /// When this lattice is updated, it automatically updates the constant value /// of the SSA value (if the range can be narrowed to one). 
-class IntegerValueRangeLattice : public Lattice { +class IntegerValueRangeState : public OptimisticSparseState { public: - using Lattice::Lattice; - - /// If the range can be narrowed to an integer constant, update the constant - /// value of the SSA value. - void onUpdate(DataFlowSolver *solver) const override; + using OptimisticSparseState::OptimisticSparseState; + using ElementT = SparseElement; }; /// Integer range analysis determines the integer value range of SSA values /// using operations that define `InferIntRangeInterface` and also sets the /// range of iteration indices of loops with known bounds. class IntegerRangeAnalysis - : public SparseDataFlowAnalysis { + : public SparseDataFlowAnalysis { public: using SparseDataFlowAnalysis::SparseDataFlowAnalysis; /// Visit an operation. Invoke the transfer function on each operation that /// implements `InferIntRangeInterface`. - void visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) override; + void + visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; /// Visit block arguments or operation results of an operation with region /// control-flow for which values are not defined by region control-flow. This /// function calls `InferIntRangeInterface` to provide values for block /// arguments or tries to reduce the range on loop induction variables with /// known bounds. 
- void - visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, - ArrayRef argLattices, - unsigned firstIndex) override; + void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef argLattices, + unsigned firstIndex) override; +}; + +class IntegerRangeToConstant : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override; + LogicalResult visit(ProgramPoint point) override; + + bool staticallyProvides(TypeID stateID, ProgramPoint point) const override { + return stateID == TypeID::get() && point.is(); + } }; } // end namespace dataflow diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -23,67 +23,87 @@ namespace dataflow { //===----------------------------------------------------------------------===// -// AbstractSparseLattice +// AbstractSparseState //===----------------------------------------------------------------------===// -/// This class represents an abstract lattice. A lattice contains information -/// about an SSA value and is what's propagated across the IR by sparse -/// data-flow analysis. -class AbstractSparseLattice : public AnalysisState { +class AbstractSparseState : public AbstractState { public: - /// Lattices can only be created for values. - AbstractSparseLattice(Value value) : AnalysisState(value) {} + /// Join the information contained in 'rhs' into this state. Returns + /// if the value of the state changed. + virtual ChangeResult join(const AbstractSparseState &rhs) = 0; - /// Join the information contained in 'rhs' into this lattice. Returns - /// if the value of the lattice changed. 
- virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0; - - /// Returns true if the lattice element is at fixpoint and further calls to - /// `join` will not update the value of the element. + /// Returns true if the lattice state is at fixpoint and further calls to + /// `join` will not update the value of the state. virtual bool isAtFixpoint() const = 0; - /// Mark the lattice element as having reached a pessimistic fixpoint. This - /// means that the lattice may potentially have conflicting value states, and - /// only the most conservative value should be relied on. + /// Mark the lattice state as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have an overdefined or underdefined + /// value state, and only the most conservative value should be relied on. virtual ChangeResult markPessimisticFixpoint() = 0; - /// When the lattice gets updated, propagate an update to users of the value - /// using its use-def chain to subscribed analyses. - void onUpdate(DataFlowSolver *solver) const override; + /// Returns true if the value of this lattice hasn't yet been initialized. + virtual bool isUninitialized() const = 0; +}; + +//===----------------------------------------------------------------------===// +// AbstractSparseElement +//===----------------------------------------------------------------------===// + +class AbstractSparseElement : public AbstractElement { +public: + /// Sparse elements can only be created on SSA values. + explicit AbstractSparseElement(DataFlowSolver &solver, Value value) + : AbstractElement(solver, value) {} - /// Subscribe an analysis to updates of the lattice. When the lattice changes, - /// subscribed analyses are re-invoked on all users of the value. This is - /// more efficient than relying on the dependency map. 
- void useDefSubscribe(DataFlowAnalysis *analysis) { + virtual void useDefSubscribe(DataFlowAnalysis *analysis) = 0; + + virtual const AbstractSparseState *get() const override = 0; +}; + +/// This class represents a sparse analysis element. A sparse element is +/// attached to an SSA value and can track its dependents through the value's +/// use-def chain. This is useful for improving the performance of sparse +/// analyses where users are always dependents of SSA value elements. +template class BaseT> +class SparseElement : public BaseT { +public: + using BaseT::BaseT; + + /// When the sparse element gets updated, propagate an update to users of the + /// value using its use-def chain to subscribed analyses. + void onUpdate() override { + for (Operation *user : this->point.template get().getUsers()) + for (DataFlowAnalysis *analysis : useDefSubscribers) + this->solver.enqueue({user, analysis}); + } + + /// Subscribe an analysis to updates of the sparse element. When the element + /// changes, subscribed analyses are re-invoked on all users of the value. + /// This is more efficient than relying on the dependency map. + void useDefSubscribe(DataFlowAnalysis *analysis) override { useDefSubscribers.insert(analysis); } private: - /// A set of analyses that should be updated when this lattice changes. + /// A set of analyses that should be updated when this element changes. SetVector, SmallPtrSet> useDefSubscribers; }; //===----------------------------------------------------------------------===// -// Lattice +// OptimisticSparseState //===----------------------------------------------------------------------===// -/// This class represents a lattice holding a specific value of type `ValueT`. -/// Lattice values (`ValueT`) are required to adhere to the following: -/// -/// * static ValueT join(const ValueT &lhs, const ValueT &rhs); -/// - This method conservatively joins the information held by `lhs` -/// and `rhs` into a new value. 
This method is required to be monotonic. -/// * bool operator==(const ValueT &rhs) const; -/// +/// This class represents a sparse state that has an optimistic and known value. +/// This class should be used when the overdefined/underdefined value state is +/// not finitely representable. template -class Lattice : public AbstractSparseLattice { +class OptimisticSparseState : public AbstractSparseState { public: - /// Construct a lattice with a known value. - explicit Lattice(Value value) - : AbstractSparseLattice(value), - knownValue(ValueT::getPessimisticValueState(value)) {} + template + explicit OptimisticSparseState(PointT point) + : knownValue(ValueT::getPessimisticValue(point)) {} /// Return the value held by this lattice. This requires that the value is /// initialized. @@ -92,16 +112,11 @@ return *optimisticValue; } const ValueT &getValue() const { - return const_cast *>(this)->getValue(); + return const_cast *>(this)->getValue(); } /// Returns true if the value of this lattice hasn't yet been initialized. bool isUninitialized() const override { return !optimisticValue.hasValue(); } - /// Force the initialization of the element by setting it to its pessimistic - /// fixpoint. - ChangeResult defaultInitialize() override { - return markPessimisticFixpoint(); - } /// Returns true if the lattice has reached a fixpoint. A fixpoint is when /// the information optimistically assumed to be true is the same as the @@ -110,9 +125,8 @@ /// Join the information contained in the 'rhs' lattice into this /// lattice. Returns if the state of the current lattice changed. - ChangeResult join(const AbstractSparseLattice &rhs) override { - const Lattice &rhsLattice = - static_cast &>(rhs); + ChangeResult join(const AbstractSparseState &rhs) override { + auto &rhsLattice = static_cast &>(rhs); // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do. 
if (isAtFixpoint() || rhsLattice.isUninitialized()) @@ -122,6 +136,21 @@ return join(rhsLattice.getValue()); } + ChangeResult meet(const OptimisticSparseState &rhs) { + if (isUninitialized()) + return ChangeResult::NoChange; + if (rhs.isUninitialized()) { + optimisticValue.reset(); + return ChangeResult::Change; + } + ValueT newValue = ValueT::meet(getValue(), rhs.getValue()); + if (newValue == optimisticValue) + return ChangeResult::NoChange; + + optimisticValue = newValue; + return ChangeResult::Change; + } + /// Join the information contained in the 'rhs' value into this /// lattice. Returns if the state of the current lattice changed. ChangeResult join(const ValueT &rhs) { @@ -159,16 +188,14 @@ return ChangeResult::Change; } - /// Print the lattice element. void print(raw_ostream &os) const override { - os << "["; + os << '['; knownValue.print(os); - os << ", "; - if (optimisticValue) + if (optimisticValue) { + os << ", "; optimisticValue->print(os); - else - os << ""; - os << "]"; + } + os << ']'; } private: @@ -206,10 +233,10 @@ /// The operation transfer function. Given the operand lattices, this /// function is expected to set the result lattices. - virtual void - visitOperationImpl(Operation *op, - ArrayRef operandLattices, - ArrayRef resultLattices) = 0; + virtual void visitOperationImpl( + Operation *op, + ArrayRef operandLattices, + ArrayRef resultLattices) = 0; /// Given an operation with region control-flow, the lattices of the operands, /// and a region successor, compute the lattice values for block arguments @@ -217,26 +244,29 @@ /// of loops). virtual void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ArrayRef argLattices, unsigned firstIndex) = 0; + ArrayRef argLattices, + unsigned firstIndex) = 0; /// Get the lattice element of a value. 
- virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; + virtual AbstractSparseElement *getLatticeElement(Value value) = 0; /// Get a read-only lattice element for a value and add it as a dependency to /// a program point. - const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, - Value value); + const AbstractSparseState *getLatticeElementFor(ProgramPoint point, + Value value); /// Mark a lattice element as having reached its pessimistic fixpoint and /// propgate an update if changed. - void markPessimisticFixpoint(AbstractSparseLattice *lattice); + void markPessimisticFixpoint(AbstractSparseElement *element); /// Mark the given lattice elements as having reached their pessimistic /// fixpoints and propagate an update if any changed. - void markAllPessimisticFixpoint(ArrayRef lattices); + void markAllPessimisticFixpoint( + ArrayRef elements); /// Join the lattice element and propagate and update if it changed. - void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); + void join(AbstractSparseElement *lhs, + const AbstractSparseState &rhs); private: /// Recursively initialize the analysis on nested operations and blocks. @@ -255,9 +285,10 @@ /// operation `branch`, which can either be the entry block of one of the /// regions or the parent operation itself, and set either the argument or /// parent result lattices. - void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, - Optional successorIndex, - ArrayRef lattices); + void + visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, + Optional successorIndex, + ArrayRef elements); }; //===----------------------------------------------------------------------===// @@ -267,7 +298,7 @@ /// A sparse (forward) data-flow analysis for propagating SSA value lattices /// across the IR by implementing transfer functions for operations. /// -/// `StateT` is expected to be a subclass of `AbstractSparseLattice`. 
+/// `StateT` is expected to be a subclass of `AbstractSparseState`. template class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis { public: @@ -276,8 +307,9 @@ /// Visit an operation with the lattices of its operands. This function is /// expected to set the lattices of the operation's results. - virtual void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) = 0; + virtual void + visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) = 0; /// Given an operation with possible region control-flow, the lattices of the /// operands, and a region successor, compute the lattice values for block @@ -285,18 +317,21 @@ /// the bounds of loops). By default, this method marks all such lattice /// elements as having reached a pessimistic fixpoint. `firstIndex` is the /// index of the first element of `argLattices` that is set by control-flow. - virtual void visitNonControlFlowArguments(Operation *op, - const RegionSuccessor &successor, - ArrayRef argLattices, - unsigned firstIndex) { + virtual void visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &successor, + ArrayRef argLattices, unsigned firstIndex) { markAllPessimisticFixpoint(argLattices.take_front(firstIndex)); markAllPessimisticFixpoint(argLattices.drop_front( firstIndex + successor.getSuccessorInputs().size())); } protected: + bool staticallyProvides(TypeID stateID, ProgramPoint point) const override { + return stateID == TypeID::get() && point.is(); + } + /// Get the lattice element for a value. - StateT *getLatticeElement(Value value) override { + typename StateT::ElementT *getLatticeElement(Value value) override { return getOrCreate(value); } @@ -309,32 +344,37 @@ /// Mark the lattice elements of a range of values as having reached their /// pessimistic fixpoint. 
- void markAllPessimisticFixpoint(ArrayRef lattices) { + void + markAllPessimisticFixpoint(ArrayRef elements) { AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( - {reinterpret_cast(lattices.begin()), - lattices.size()}); + {reinterpret_cast( + elements.begin()), + elements.size()}); } private: /// Type-erased wrappers that convert the abstract lattice operands to derived /// lattices and invoke the virtual hooks operating on the derived lattices. void visitOperationImpl( - Operation *op, ArrayRef operandLattices, - ArrayRef resultLattices) override { + Operation *op, + ArrayRef operandLattices, + ArrayRef resultLattices) override { visitOperation( op, {reinterpret_cast(operandLattices.begin()), operandLattices.size()}, - {reinterpret_cast(resultLattices.begin()), + {reinterpret_cast( + resultLattices.begin()), resultLattices.size()}); } void visitNonControlFlowArgumentsImpl( Operation *op, const RegionSuccessor &successor, - ArrayRef argLattices, + ArrayRef argLattices, unsigned firstIndex) override { visitNonControlFlowArguments( op, successor, - {reinterpret_cast(argLattices.begin()), + {reinterpret_cast( + argLattices.begin()), argLattices.size()}, firstIndex); } diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -45,9 +45,6 @@ return lhs == ChangeResult::NoChange ? lhs : rhs; } -/// Forward declare the analysis state class. -class AnalysisState; - //===----------------------------------------------------------------------===// // GenericProgramPoint //===----------------------------------------------------------------------===// @@ -178,6 +175,8 @@ // DataFlowSolver //===----------------------------------------------------------------------===// +class AbstractElement; + /// The general data-flow analysis solver. 
This class is responsible for /// orchestrating child data-flow analyses, running the fixed-point iteration /// algorithm, managing analysis state and program point memory, and tracking @@ -202,16 +201,19 @@ /// operation and run the analysis until fixpoint. LogicalResult initializeAndRun(Operation *top); - /// Lookup an analysis state for the given program point. Returns null if one - /// does not exist. template - const StateT *lookupState(PointT point) const { - auto it = analysisStates.find({ProgramPoint(point), TypeID::get()}); - if (it == analysisStates.end()) + const StateT *lookup(PointT point) const { + using ElementT = typename StateT::ElementT; + auto it = elements.find({TypeID::get(), ProgramPoint(point)}); + if (it == elements.end()) return nullptr; - return static_cast(it->second.get()); + return static_cast(*it->second).get(); } + template + typename StateT::ElementT *getOrCreate(PointT point); + +public: /// Get a uniqued program point instance. If one is not present, it is /// created with the provided arguments. template @@ -226,20 +228,9 @@ /// Push a work item onto the worklist. void enqueue(WorkItem item) { worklist.push(std::move(item)); } - /// Get the state associated with the given program point. If it does not - /// exist, create an uninitialized state. - template - StateT *getOrCreateState(PointT point); - - /// Propagate an update to an analysis state if it changed by pushing - /// dependent work items to the back of the queue. - void propagateIfChanged(AnalysisState *state, ChangeResult changed); - - /// Add a dependency to an analysis state on a child analysis and program - /// point. If the state is updated, the child analysis must be invoked on the - /// given program point again. - void addDependency(AnalysisState *state, DataFlowAnalysis *analysis, - ProgramPoint point); + void getStaticProvidersFor( + TypeID stateID, ProgramPoint point, + SmallVectorImpl &staticProviders) const; private: /// The solver's work queue. 
Work items can be inserted to the front of the @@ -254,78 +245,129 @@ /// points. StorageUniquer uniquer; - /// A type-erased map of program points to associated analysis states for - /// first-class program points. - DenseMap, std::unique_ptr> - analysisStates; + /// A type-erased map of program points to associated analysis states. + DenseMap, + std::unique_ptr> + elements; /// Allow the base child analysis class to access the internals of the solver. friend class DataFlowAnalysis; }; //===----------------------------------------------------------------------===// -// AnalysisState +// AbstractElement //===----------------------------------------------------------------------===// -/// Base class for generic analysis states. Analysis states contain data-flow -/// information that are attached to program points and which evolve as the -/// analysis iterates. -/// -/// This class places no restrictions on the semantics of analysis states beyond -/// these requirements. -/// -/// 1. Querying the state of a program point prior to visiting that point -/// results in uninitialized state. Analyses must be aware of unintialized -/// states. -/// 2. Analysis states can reach fixpoints, where subsequent updates will never -/// trigger a change in the state. -/// 3. Analysis states that are uninitialized can be forcefully initialized to a -/// default value. -class AnalysisState { +class AbstractState { public: - virtual ~AnalysisState(); + virtual ~AbstractState(); - /// Create the analysis state at the given program point. - AnalysisState(ProgramPoint point) : point(point) {} + virtual void print(raw_ostream &os) const = 0; +}; - /// Returns true if the analysis state is uninitialized. - virtual bool isUninitialized() const = 0; +/// Subclasses are required to implement `get` and `update`. +class AbstractElement { +public: + virtual ~AbstractElement(); - /// Force an uninitialized analysis state to initialize itself with a default - /// value. 
- virtual ChangeResult defaultInitialize() = 0; + explicit AbstractElement(DataFlowSolver &solver, ProgramPoint point) + : solver(solver), point(point) {} - /// Print the contents of the analysis state. - virtual void print(raw_ostream &os) const = 0; + void addDependency(DataFlowAnalysis *analysis, ProgramPoint point); + + virtual const AbstractState *get() const = 0; + virtual void update(DataFlowAnalysis *provider, + function_ref updateFn) = 0; protected: - /// This function is called by the solver when the analysis state is updated - /// to optionally enqueue more work items. For example, if a state tracks - /// dependents through the IR (e.g. use-def chains), this function can be - /// implemented to push those dependents on the worklist. - virtual void onUpdate(DataFlowSolver *solver) const {} - - /// The dependency relations originating from this analysis state. An entry - /// `state -> (analysis, point)` is created when `analysis` queries `state` - /// when updating `point`. - /// - /// When this state is updated, all dependent child analysis invocations are - /// pushed to the back of the queue. Use a `SetVector` to keep the analysis - /// deterministic. - /// - /// Store the dependents on the analysis state for efficiency. - SetVector dependents; + void propagateUpdate(); + + virtual void onUpdate() {} - /// The program point to which the state belongs. + DataFlowSolver &solver; ProgramPoint point; +private: + SetVector, + llvm::SmallDenseSet> + dependents; + #if LLVM_ENABLE_ABI_BREAKING_CHECKS - /// When compiling with debugging, keep a name for the analysis state. + /// When compiling with debugging, keep a name for the element. StringRef debugName; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - /// Allow the framework to access the dependents. 
- friend class DataFlowSolver; + friend class ::mlir::DataFlowSolver; +}; + +template +class SingleStateElement : public BaseT { +public: + template + explicit SingleStateElement(DataFlowSolver &solver, PointT point) + : BaseT(solver, point), state(point) {} + + const StateT *get() const override { return &state; } + + void update(DataFlowAnalysis *provider, + function_ref updateFn) override { + if (updateFn(&state) == ChangeResult::Change) + BaseT::propagateUpdate(); + } + void update(DataFlowAnalysis *provider, + function_ref updateFn) { + return update(provider, function_ref( + [updateFn](AbstractState *state) { + return updateFn(static_cast(state)); + })); + } + +private: + StateT state; +}; + +/// StateT is required to implement `join` and `meet`. +template +class MultiStateElement : public BaseT { +public: + template + explicit MultiStateElement(DataFlowSolver &solver, PointT point) + : BaseT(solver, point), state(point) { + SmallVector staticProviders; + solver.getStaticProvidersFor(TypeID::get(), point, staticProviders); + for (DataFlowAnalysis *staticProvider : staticProviders) + states.try_emplace(staticProvider, StateT(point)); + } + + const StateT *get() const override { return &state; } + + void update(DataFlowAnalysis *provider, + function_ref updateFn) override { + auto it = states.find(provider); + if (it == states.end()) { + if (updateFn(&state) == ChangeResult::Change) + BaseT::propagateUpdate(); + return; + } + if (updateFn(&it->second) == ChangeResult::NoChange) + return; + StateT newState(it->second); + for (auto &entry : states) + (void)newState.meet(entry.second); + if (state.join(newState) == ChangeResult::Change) + BaseT::propagateUpdate(); + } + void update(DataFlowAnalysis *provider, + function_ref updateFn) { + return update(provider, function_ref( + [updateFn](AbstractState *state) { + return updateFn(static_cast(state)); + })); + } + +private: + StateT state; + llvm::SmallDenseMap states; }; 
//===----------------------------------------------------------------------===// @@ -385,12 +427,13 @@ virtual LogicalResult visit(ProgramPoint point) = 0; protected: - /// Create a dependency between the given analysis state and program point - /// on this analysis. - void addDependency(AnalysisState *state, ProgramPoint point); - - /// Propagate an update to a state if it changed. - void propagateIfChanged(AnalysisState *state, ChangeResult changed); + /// Returns true if this analysis *statically* provides values for the given + /// state kind for the given program point. This means the analysis will + /// always provide values for this state regardless of the state of the + /// analysis. + virtual bool staticallyProvides(TypeID stateID, ProgramPoint point) const { + return false; + } /// Register a custom program point class. template @@ -404,12 +447,18 @@ return solver.getProgramPoint(std::forward(args)...); } + template + void update(PointT point, function_ref updateFn) { + auto *element = getOrCreate(point); + element->update(this, updateFn); + } + /// Get the analysis state assiocated with the program point. The returned /// state is expected to be "write-only", and any updates need to be /// propagated by `propagateIfChanged`. template - StateT *getOrCreate(PointT point) { - return solver.getOrCreateState(point); + typename StateT::ElementT *getOrCreate(PointT point) { + return solver.getOrCreate(point); } /// Get a read-only analysis state for the given point and create a dependency @@ -417,14 +466,15 @@ /// re-invoked on the dependent. template const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) { - StateT *state = getOrCreate(point); - addDependency(state, dependent); - return state; + auto *element = getOrCreate(point); + element->addDependency(this, dependent); + return element->get(); } #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// When compiling with debugging, keep a name for the analyis. 
StringRef debugName; + friend class AbstractElement; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS private: @@ -445,19 +495,21 @@ } template -StateT *DataFlowSolver::getOrCreateState(PointT point) { - std::unique_ptr &state = - analysisStates[{ProgramPoint(point), TypeID::get()}]; - if (!state) { - state = std::unique_ptr(new StateT(point)); -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - state->debugName = llvm::getTypeName(); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS +typename StateT::ElementT *DataFlowSolver::getOrCreate(PointT point) { + using ElementT = typename StateT::ElementT; + static_assert(std::is_base_of::value, + "expected an abstract element"); + std::unique_ptr &element = + elements[{TypeID::get(), ProgramPoint(point)}]; + if (!element) { + element = std::unique_ptr(new ElementT(*this, point)); + element->debugName = llvm::getTypeName(); } - return static_cast(state.get()); + return static_cast(element.get()); } -inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) { +inline raw_ostream &operator<<(raw_ostream &os, + const AbstractState &state) { state.print(os); return os; } diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -30,8 +30,8 @@ //===----------------------------------------------------------------------===// void SparseConstantPropagation::visitOperation( - Operation *op, ArrayRef *> operands, - ArrayRef *> results) { + Operation *op, ArrayRef operands, + ArrayRef results) { LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); // Don't try to simulate the results of a region operation as we can't @@ -39,12 +39,12 @@ // folds as the desire here is for simulated execution, and not general // folding. 
if (op->getNumRegions()) - return; + return markAllPessimisticFixpoint(results); SmallVector constantOperands; constantOperands.reserve(op->getNumOperands()); - for (auto *operandLattice : operands) - constantOperands.push_back(operandLattice->getValue().getConstantValue()); + for (auto *operandState : operands) + constantOperands.push_back(operandState->getValue().getConstantValue()); // Save the original operands and attributes just in case the operation // folds in-place. The constant passed in may not correspond to the real @@ -56,10 +56,8 @@ // fails or was an in-place fold, mark the results as overdefined. SmallVector foldResults; foldResults.reserve(op->getNumResults()); - if (failed(op->fold(constantOperands, foldResults))) { - markAllPessimisticFixpoint(results); - return; - } + if (failed(op->fold(constantOperands, foldResults))) + return markAllPessimisticFixpoint(results); // If the folding was in-place, mark the results as overdefined and reset // the operation. We don't allow in-place folds as the desire here is for @@ -67,25 +65,26 @@ if (foldResults.empty()) { op->setOperands(originalOperands); op->setAttrs(originalAttrs); - return; + return markAllPessimisticFixpoint(results); } // Merge the fold results into the lattice for this operation. assert(foldResults.size() == op->getNumResults() && "invalid result size"); for (const auto it : llvm::zip(results, foldResults)) { - Lattice *lattice = std::get<0>(it); + ConstantValueState::ElementT *element = std::get<0>(it); // Merge in the result of the fold, either a constant or a value. 
OpFoldResult foldResult = std::get<1>(it); if (Attribute attr = foldResult.dyn_cast()) { LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); - propagateIfChanged(lattice, - lattice->join(ConstantValue(attr, op->getDialect()))); + element->update(this, [attr, op](ConstantValueState *state) { + return state->join(ConstantValue(attr, op->getDialect())); + }); } else { LLVM_DEBUG(llvm::dbgs() << "Folded to value: " << foldResult.get() << "\n"); AbstractSparseDataFlowAnalysis::join( - lattice, *getLatticeElement(foldResult.get())); + element, *getLatticeElement(foldResult.get())->get()); } } } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -27,24 +27,6 @@ os << (live ? "live" : "dead"); } -void Executable::onUpdate(DataFlowSolver *solver) const { - if (auto *block = point.dyn_cast()) { - // Re-invoke the analyses on the block itself. - for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({block, analysis}); - // Re-invoke the analyses on all operations in the block. - for (DataFlowAnalysis *analysis : subscribers) - for (Operation &op : *block) - solver->enqueue({&op, analysis}); - } else if (auto *programPoint = point.dyn_cast()) { - // Re-invoke the analysis on the successor block. 
- if (auto *edge = dyn_cast(programPoint)) { - for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({edge->getTo(), analysis}); - } - } -} - //===----------------------------------------------------------------------===// // PredecessorState //===----------------------------------------------------------------------===// @@ -104,8 +86,8 @@ for (Region ®ion : top->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate(®ion.front()); - propagateIfChanged(state, state->setToLive()); + update(®ion.front(), + [](Executable *state) { return state->setToLive(); }); } // Mark as overdefined the predecessors of symbol callables with potentially @@ -132,8 +114,9 @@ // Public symbol callables or those for which we can't see all uses have // potentially unknown callsites. if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { - auto *state = getOrCreate(callable); - propagateIfChanged(state, state->setHasUnknownPredecessors()); + update(callable, [](PredecessorState *state) { + return state->setHasUnknownPredecessors(); + }); } foundSymbolCallable = true; } @@ -149,8 +132,9 @@ // If we couldn't gather the symbol uses, conservatively assume that // we can't track information for any nested symbols. return top->walk([&](CallableOpInterface callable) { - auto *state = getOrCreate(callable); - propagateIfChanged(state, state->setHasUnknownPredecessors()); + update(callable, [](PredecessorState *state) { + return state->setHasUnknownPredecessors(); + }); }); } @@ -160,12 +144,12 @@ // If a callable symbol has a non-call use, then we can't be guaranteed to // know all callsites. 
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef()); - auto *state = getOrCreate(symbol); - propagateIfChanged(state, state->setHasUnknownPredecessors()); + update(symbol, [](PredecessorState *state) { + return state->setHasUnknownPredecessors(); + }); } }; - SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), - walkFn); + SymbolTable::walkSymbolTables(top, !top->getBlock(), walkFn); } /// Returns true if the operation terminates a block. It is insufficient to @@ -198,18 +182,17 @@ } void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { - auto *state = getOrCreate(to); - propagateIfChanged(state, state->setToLive()); - auto *edgeState = getOrCreate(getProgramPoint(from, to)); - propagateIfChanged(edgeState, edgeState->setToLive()); + update(to, [](Executable *state) { return state->setToLive(); }); + update(getProgramPoint(from, to), + [](Executable *state) { return state->setToLive(); }); } void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { for (Region ®ion : op->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate(®ion.front()); - propagateIfChanged(state, state->setToLive()); + update(®ion.front(), + [](Executable *state) { return state->setToLive(); }); } } @@ -221,7 +204,7 @@ return emitError(point.getLoc(), "unknown program point kind"); // If the parent block is not executable, there is nothing to do. - if (!getOrCreate(op->getBlock())->isLive()) + if (!getOrCreate(op->getBlock())->get()->isLive()) return success(); // We have a live call op. Add this as a live predecessor of the callee. @@ -296,25 +279,27 @@ if (isa_and_nonnull(callableOp) && !isExternalCallable(callableOp)) { // Add the live callsite. - auto *callsites = getOrCreate(callableOp); - propagateIfChanged(callsites, callsites->join(call)); + update(callableOp, [call](PredecessorState *state) { + return state->join(call); + }); } else { // Mark this call op's predecessors as overdefined. 
- auto *predecessors = getOrCreate(call); - propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); + update(call, [](PredecessorState *state) { + return state->setHasUnknownPredecessors(); + }); } } /// Get the constant values of the operands of an operation. If any of the /// constant value lattices are uninitialized, return none to indicate the /// analysis should bail out. -static Optional> getOperandValuesImpl( - Operation *op, - function_ref *(Value)> getLattice) { +static Optional> +getOperandValuesImpl(Operation *op, + function_ref getState) { SmallVector operands; operands.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - const Lattice *cv = getLattice(operand); + const ConstantValueState *cv = getState(operand); // If any of the operands' values are uninitialized, bail out. if (cv->isUninitialized()) return {}; @@ -325,11 +310,12 @@ Optional> DeadCodeAnalysis::getOperandValues(Operation *op) { - return getOperandValuesImpl(op, [&](Value value) { - auto *lattice = getOrCreate>(value); - lattice->useDefSubscribe(this); - return lattice; - }); + return getOperandValuesImpl( + op, [&](Value value) -> const ConstantValueState * { + auto *element = getOrCreate(value); + element->useDefSubscribe(this); + return element->get(); + }); } void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { @@ -362,13 +348,13 @@ ? &successor.getSuccessor()->front() : ProgramPoint(branch); // Mark the entry block as executable. - auto *state = getOrCreate(point); - propagateIfChanged(state, state->setToLive()); + update(point, + [](Executable *state) { return state->setToLive(); }); // Add the parent op as a predecessor. 
- auto *predecessors = getOrCreate(point); - propagateIfChanged( - predecessors, - predecessors->join(branch, successor.getSuccessorInputs())); + update( + point, [branch, &successor](PredecessorState *state) { + return state->join(branch, successor.getSuccessorInputs()); + }); } } @@ -385,17 +371,20 @@ // Mark successor region entry blocks as executable and add this op to the // list of predecessors. for (const RegionSuccessor &successor : successors) { - PredecessorState *predecessors; if (Region *region = successor.getSuccessor()) { - auto *state = getOrCreate(®ion->front()); - propagateIfChanged(state, state->setToLive()); - predecessors = getOrCreate(®ion->front()); + update(®ion->front(), + [](Executable *state) { return state->setToLive(); }); + update( + ®ion->front(), [op, &successor](PredecessorState *state) { + return state->join(op, successor.getSuccessorInputs()); + }); } else { // Add this terminator as a predecessor to the parent op. - predecessors = getOrCreate(branch); + update( + branch, [op, &successor](PredecessorState *state) { + return state->join(op, successor.getSuccessorInputs()); + }); } - propagateIfChanged(predecessors, - predecessors->join(op, successor.getSuccessorInputs())); } } @@ -412,12 +401,14 @@ assert(isa(predecessor)); auto *predecessors = getOrCreate(predecessor); if (canResolve) { - propagateIfChanged(predecessors, predecessors->join(op)); + predecessors->update( + this, [op](PredecessorState *state) { return state->join(op); }); } else { // If the terminator is not a return-like, then conservatively assume we // can't resolve the predecessor. 
- propagateIfChanged(predecessors, - predecessors->setHasUnknownPredecessors()); + predecessors->update(this, [](PredecessorState *state) { + return state->setHasUnknownPredecessors(); + }); } } } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -45,8 +45,8 @@ return; // Get the dense lattice to update. - AbstractDenseLattice *after = getLattice(op); - if (after->isAtFixpoint()) + AbstractDenseElement *after = getLattice(op); + if (after->get()->isAtFixpoint()) return; // If this op implements region control-flow, then control-flow dictates its @@ -61,14 +61,17 @@ // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) - return reset(after); - for (Operation *predecessor : predecessors->getKnownPredecessors()) - join(after, *getLatticeFor(op, predecessor)); - return; + return markPessimisticFixpoint(after); + return update(after, [this, predecessors, op](AbstractDenseState *state) { + ChangeResult result = ChangeResult::NoChange; + for (Operation *predecessor : predecessors->getKnownPredecessors()) + result |= state->join(*getLatticeFor(op, predecessor)); + return result; + }); } // Get the dense state before the execution of the op. - const AbstractDenseLattice *before; + const AbstractDenseState *before; if (Operation *prev = op->getPrevNode()) before = getLatticeFor(op, prev); else @@ -87,8 +90,8 @@ return; // Get the dense lattice to update. 
- AbstractDenseLattice *after = getLattice(block); - if (after->isAtFixpoint()) + AbstractDenseElement *after = getLattice(block); + if (after->get()->isAtFixpoint()) return; // The dense lattices of entry blocks are set by region control-flow or the @@ -101,15 +104,17 @@ // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. if (!callsites->allPredecessorsKnown()) - return reset(after); - for (Operation *callsite : callsites->getKnownPredecessors()) { - // Get the dense lattice before the callsite. - if (Operation *prev = callsite->getPrevNode()) - join(after, *getLatticeFor(block, prev)); - else - join(after, *getLatticeFor(block, callsite->getBlock())); - } - return; + return markPessimisticFixpoint(after); + return update(after, [this, callsites, block](AbstractDenseState *state) { + ChangeResult result = ChangeResult::NoChange; + for (Operation *callsite : callsites->getKnownPredecessors()) { + if (Operation *prev = callsite->getPrevNode()) + result |= state->join(*getLatticeFor(block, prev)); + else + result |= state->join(*getLatticeFor(block, callsite->getBlock())); + } + return result; + }); } // Check if we can reason about the control-flow. @@ -117,53 +122,62 @@ return visitRegionBranchOperation(block, branch, after); // Otherwise, we can't reason about the data-flow. - return reset(after); + return markPessimisticFixpoint(after); } // Join the state with the state after the block's predecessors. - for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); - it != e; ++it) { - // Skip control edges that aren't executable. - Block *predecessor = *it; - if (!getOrCreateFor( - block, getProgramPoint(predecessor, block)) - ->isLive()) - continue; - - // Merge in the state from the predecessor's terminator. 
- join(after, *getLatticeFor(block, predecessor->getTerminator())); - } + update(after, [this, block](AbstractDenseState *state) { + ChangeResult result = ChangeResult::NoChange; + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); + it != e; ++it) { + // Skip control edges that aren't executable. + Block *predecessor = *it; + if (!getOrCreateFor( + block, getProgramPoint(predecessor, block)) + ->isLive()) + continue; + + // Merge in the state from the predecessor's terminator. + result |= + state->join(*getLatticeFor(block, predecessor->getTerminator())); + } + return result; + }); } void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation( ProgramPoint point, RegionBranchOpInterface branch, - AbstractDenseLattice *after) { + AbstractDenseElement *after) { // Get the terminator predecessors. const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); - for (Operation *op : predecessors->getKnownPredecessors()) { - const AbstractDenseLattice *before; - // If the predecessor is the parent, get the state before the parent. - if (op == branch) { - if (Operation *prev = op->getPrevNode()) - before = getLatticeFor(point, prev); - else - before = getLatticeFor(point, op->getBlock()); - - // Otherwise, get the state after the terminator. - } else { - before = getLatticeFor(point, op); + update(after, [&](AbstractDenseState *state) { + ChangeResult result = ChangeResult::NoChange; + for (Operation *op : predecessors->getKnownPredecessors()) { + const AbstractDenseState *before; + // If the predecessor is the parent, get the state before the parent. + if (op == branch) { + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(point, prev); + else + before = getLatticeFor(point, op->getBlock()); + + // Otherwise, get the state after the terminator. 
+ } else { + before = getLatticeFor(point, op); + } + result |= state->join(*before); } - join(after, *before); - } + return result; + }); } -const AbstractDenseLattice * +const AbstractDenseState * AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, ProgramPoint point) { - AbstractDenseLattice *state = getLattice(point); - addDependency(state, dependent); - return state; + AbstractDenseElement *element = getLattice(point); + element->addDependency(this, dependent); + return element->get(); } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -23,7 +23,7 @@ using namespace mlir; using namespace mlir::dataflow; -IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) { +IntegerValueRange IntegerValueRange::getPessimisticValue(Value value) { unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType()); APInt umin = APInt::getMinValue(width); APInt umax = APInt::getMaxValue(width); @@ -32,30 +32,9 @@ return {{umin, umax, smin, smax}}; } -void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { - Lattice::onUpdate(solver); - - // If the integer range can be narrowed to a constant, update the constant - // value of the SSA value. 
- Optional constant = getValue().getValue().getConstantValue(); - auto value = point.get(); - auto *cv = solver->getOrCreateState>(value); - if (!constant) - return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint()); - - Dialect *dialect; - if (auto *parent = value.getDefiningOp()) - dialect = parent->getDialect(); - else - dialect = value.getParentBlock()->getParentOp()->getDialect(); - solver->propagateIfChanged( - cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant), - dialect))); -} - void IntegerRangeAnalysis::visitOperation( - Operation *op, ArrayRef operands, - ArrayRef results) { + Operation *op, ArrayRef operands, + ArrayRef results) { // Ignore non-integer outputs - return early if the op has no scalar // integer results bool hasIntegerResult = false; @@ -63,8 +42,9 @@ if (std::get<1>(it).getType().isIntOrIndex()) { hasIntegerResult = true; } else { - propagateIfChanged(std::get<0>(it), - std::get<0>(it)->markPessimisticFixpoint()); + std::get<0>(it)->update(this, [](IntegerValueRangeState *state) { + return state->markPessimisticFixpoint(); + }); } } if (!hasIntegerResult) @@ -76,7 +56,7 @@ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); SmallVector argRanges( - llvm::map_range(operands, [](const IntegerValueRangeLattice *val) { + llvm::map_range(operands, [](const IntegerValueRangeState *val) { return val->getValue().getValue(); })); @@ -87,26 +67,28 @@ assert(llvm::find(op->getResults(), result) != op->result_end()); LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); - IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; - Optional oldRange; - if (!lattice->isUninitialized()) - oldRange = lattice->getValue(); - - ChangeResult changed = lattice->join(attrs); - - // Catch loop results with loop variant bounds and conservatively make - // them [-inf, inf] so we don't circle around infinitely often (because - // the dataflow analysis in MLIR doesn't attempt to work out trip 
counts - // and often can't). - bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { - return op->hasTrait(); - }); - if (isYieldedResult && oldRange.hasValue() && - !(lattice->getValue() == *oldRange)) { - LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); - changed |= lattice->markPessimisticFixpoint(); - } - propagateIfChanged(lattice, changed); + results[result.getResultNumber()]->update( + this, [&](IntegerValueRangeState *state) { + Optional oldRange; + if (!state->isUninitialized()) + oldRange = state->getValue(); + + ChangeResult changed = state->join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often + // (because the dataflow analysis in MLIR doesn't attempt to work out + // trip counts and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedResult && oldRange.hasValue() && + !(state->getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + changed |= state->markPessimisticFixpoint(); + } + return changed; + }); }; inferrable.inferResultRanges(argRanges, joinCallback); @@ -114,7 +96,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, - ArrayRef argLattices, unsigned firstIndex) { + ArrayRef argLattices, + unsigned firstIndex) { if (auto inferrable = dyn_cast(op)) { LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); SmallVector argRanges( @@ -131,25 +114,28 @@ return; LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); - IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; - Optional oldRange; - if (!lattice->isUninitialized()) - oldRange = lattice->getValue(); - - ChangeResult changed = lattice->join(attrs); - - // Catch loop results with loop variant bounds and conservatively make - // them [-inf, inf] so 
we don't circle around infinitely often (because - // the dataflow analysis in MLIR doesn't attempt to work out trip counts - // and often can't). - bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { - return op->hasTrait(); - }); - if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) { - LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); - changed |= lattice->markPessimisticFixpoint(); - } - propagateIfChanged(lattice, changed); + argLattices[arg.getArgNumber()]->update( + this, [&](IntegerValueRangeState *state) { + Optional oldRange; + if (!state->isUninitialized()) + oldRange = state->getValue(); + + ChangeResult changed = state->join(attrs); + + // Catch loop results with loop variant bounds and conservatively + // make them [-inf, inf] so we don't circle around infinitely often + // (because the dataflow analysis in MLIR doesn't attempt to work + // out trip counts and often can't). + bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedValue && oldRange && + !(state->getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + changed |= state->markPessimisticFixpoint(); + } + return changed; + }); }; inferrable.inferResultRanges(argRanges, joinCallback); @@ -168,11 +154,9 @@ loopBound->get().dyn_cast_or_null()) return bound.getValue(); } else if (auto value = loopBound->dyn_cast()) { - const IntegerValueRangeLattice *lattice = - getLatticeElementFor(op, value); - if (lattice != nullptr) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); + const IntegerValueRangeState *state = getLatticeElementFor(op, value); + return getUpper ? 
state->getValue().getValue().smax() + : state->getValue().getValue().smin(); } } // Given the results of getConstant{Lower,Upper}Bound() @@ -192,13 +176,10 @@ Optional lowerBound = loop.getSingleLowerBound(); Optional upperBound = loop.getSingleUpperBound(); Optional step = loop.getSingleStep(); - APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), - /*getUpper=*/false); - APInt max = getLoopBoundFromFold(upperBound, iv->getType(), - /*getUpper=*/true); + APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), false); + APInt max = getLoopBoundFromFold(upperBound, iv->getType(), true); // Assume positivity for uniscoverable steps by way of getUpper = true. - APInt stepVal = - getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true); + APInt stepVal = getLoopBoundFromFold(step, iv->getType(), true); if (stepVal.isNegative()) { std::swap(min, max); @@ -208,12 +189,53 @@ max -= 1; } - IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); + auto *ivEntry = getLatticeElement(*iv); auto ivRange = ConstantIntRanges::fromSigned(min, max); - propagateIfChanged(ivEntry, ivEntry->join(ivRange)); - return; + return ivEntry->update(this, [&ivRange](IntegerValueRangeState *state) { + return state->join(ivRange); + }); } return SparseDataFlowAnalysis::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } + +LogicalResult IntegerRangeToConstant::initialize(Operation *top) { + auto visitValues = [this](ValueRange values) { + for (Value value : values) + (void)visit(value); + }; + top->walk([&](Operation *op) { + visitValues(op->getResults()); + for (Region ®ion : op->getRegions()) + for (Block &block : region) + visitValues(block.getArguments()); + }); + return success(); +} + +LogicalResult IntegerRangeToConstant::visit(ProgramPoint point) { + auto value = point.get(); + auto *rangeState = getOrCreateFor(value, value); + if (rangeState->isUninitialized()) + return success(); + + update(value, [&](ConstantValueState *state) { + const 
ConstantIntRanges &range = rangeState->getValue().getValue(); + // Try to narrow to a constant. + Optional constant = range.getConstantValue(); + if (!constant) + return state->markPessimisticFixpoint(); + + // Find a dialect to materialize the constant. + Dialect *dialect; + if (Operation *op = value.getDefiningOp()) + dialect = op->getDialect(); + else + dialect = value.getParentRegion()->getParentOp()->getDialect(); + + Attribute attr = IntegerAttr::get(value.getType(), *constant); + return state->join(ConstantValue(attr, dialect)); + }); + return success(); +} diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -12,17 +12,6 @@ using namespace mlir; using namespace mlir::dataflow; -//===----------------------------------------------------------------------===// -// AbstractSparseLattice -//===----------------------------------------------------------------------===// - -void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { - // Push all users of the value to the queue. - for (Operation *user : point.get().getUsers()) - for (DataFlowAnalysis *analysis : useDefSubscribers) - solver->enqueue({user, analysis}); -} - //===----------------------------------------------------------------------===// // AbstractSparseDataFlowAnalysis //===----------------------------------------------------------------------===// @@ -80,28 +69,26 @@ return; // If the containing block is not executable, bail out. - if (!getOrCreate(op->getBlock())->isLive()) + if (!getOrCreate(op->getBlock())->get()->isLive()) return; // Get the result lattices. - SmallVector resultLattices; - resultLattices.reserve(op->getNumResults()); + SmallVector resultElements; + resultElements.reserve(op->getNumResults()); // Track whether all results have reached their fixpoint. 
bool allAtFixpoint = true; for (Value result : op->getResults()) { - AbstractSparseLattice *resultLattice = getLatticeElement(result); - allAtFixpoint &= resultLattice->isAtFixpoint(); - resultLattices.push_back(resultLattice); + AbstractSparseElement *resultElement = getLatticeElement(result); + allAtFixpoint &= resultElement->get()->isAtFixpoint(); + resultElements.push_back(resultElement); } // If all result lattices have reached a fixpoint, there is nothing to do. if (allAtFixpoint) return; // The results of a region branch operation are determined by control-flow. - if (auto branch = dyn_cast(op)) { - return visitRegionSuccessors({branch}, branch, - /*successorIndex=*/llvm::None, resultLattices); - } + if (auto branch = dyn_cast(op)) + return visitRegionSuccessors({branch}, branch, llvm::None, resultElements); // The results of a call operation are determined by the callgraph. if (auto call = dyn_cast(op)) { @@ -109,27 +96,27 @@ // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) - return markAllPessimisticFixpoint(resultLattices); + return markAllPessimisticFixpoint(resultElements); for (Operation *predecessor : predecessors->getKnownPredecessors()) - for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) + for (auto it : llvm::zip(predecessor->getOperands(), resultElements)) join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); return; } // Grab the lattice elements of the operands. 
- SmallVector operandLattices; - operandLattices.reserve(op->getNumOperands()); + SmallVector operandStates; + operandStates.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - AbstractSparseLattice *operandLattice = getLatticeElement(operand); - operandLattice->useDefSubscribe(this); + AbstractSparseElement *operandElement = getLatticeElement(operand); + operandElement->useDefSubscribe(this); // If any of the operand states are not initialized, bail out. - if (operandLattice->isUninitialized()) + if (operandElement->get()->isUninitialized()) return; - operandLattices.push_back(operandLattice); + operandStates.push_back(operandElement->get()); } // Invoke the operation transfer function. - visitOperationImpl(op, operandLattices, resultLattices); + visitOperationImpl(op, operandStates, resultElements); } void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { @@ -138,17 +125,17 @@ return; // If the block is not executable, bail out. - if (!getOrCreate(block)->isLive()) + if (!getOrCreate(block)->get()->isLive()) return; // Get the argument lattices. - SmallVector argLattices; - argLattices.reserve(block->getNumArguments()); + SmallVector argElements; + argElements.reserve(block->getNumArguments()); bool allAtFixpoint = true; for (BlockArgument argument : block->getArguments()) { - AbstractSparseLattice *argLattice = getLatticeElement(argument); - allAtFixpoint &= argLattice->isAtFixpoint(); - argLattices.push_back(argLattice); + AbstractSparseElement *argElement = getLatticeElement(argument); + allAtFixpoint &= argElement->get()->isAtFixpoint(); + argElements.push_back(argElement); } // If all argument lattices have reached their fixpoints, then there is // nothing to do. @@ -165,10 +152,10 @@ // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. 
if (!callsites->allPredecessorsKnown()) - return markAllPessimisticFixpoint(argLattices); + return markAllPessimisticFixpoint(argElements); for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast(callsite); - for (auto it : llvm::zip(call.getArgOperands(), argLattices)) + for (auto it : llvm::zip(call.getArgOperands(), argElements)) join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); } return; @@ -177,13 +164,13 @@ // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast(block->getParentOp())) { return visitRegionSuccessors( - block, branch, block->getParent()->getRegionNumber(), argLattices); + block, branch, block->getParent()->getRegionNumber(), argElements); } // Otherwise, we can't reason about the data-flow. return visitNonControlFlowArgumentsImpl(block->getParentOp(), RegionSuccessor(block->getParent()), - argLattices, /*firstIndex=*/0); + argElements, /*firstIndex=*/0); } // Iterate over the predecessors of the non-entry block. @@ -196,7 +183,7 @@ auto *edgeExecutable = getOrCreate(getProgramPoint(predecessor, block)); edgeExecutable->blockContentSubscribe(this); - if (!edgeExecutable->isLive()) + if (!edgeExecutable->get()->isLive()) continue; // Check if we can reason about the data-flow from the predecessor. 
@@ -204,7 +191,7 @@ dyn_cast(predecessor->getTerminator())) { SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex()); - for (auto &it : llvm::enumerate(argLattices)) { + for (auto &it : llvm::enumerate(argElements)) { if (Value operand = operands[it.index()]) { join(it.value(), *getLatticeElementFor(block, operand)); } else { @@ -214,7 +201,7 @@ } } } else { - return markAllPessimisticFixpoint(argLattices); + return markAllPessimisticFixpoint(argElements); } } } @@ -222,7 +209,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( ProgramPoint point, RegionBranchOpInterface branch, Optional successorIndex, - ArrayRef lattices) { + ArrayRef elements) { const auto *predecessors = getOrCreateFor(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); @@ -242,7 +229,7 @@ if (!operands) { // We can't reason about the data-flow. - return markAllPessimisticFixpoint(lattices); + return markAllPessimisticFixpoint(elements); } ValueRange inputs = predecessors->getSuccessorInputs(op); @@ -250,7 +237,7 @@ "expected the same number of successor inputs as operands"); unsigned firstIndex = 0; - if (inputs.size() != lattices.size()) { + if (inputs.size() != elements.size()) { if (auto *op = point.dyn_cast()) { if (!inputs.empty()) firstIndex = inputs.front().cast().getResultNumber(); @@ -258,7 +245,7 @@ branch, RegionSuccessor( branch->getResults().slice(firstIndex, inputs.size())), - lattices, firstIndex); + elements, firstIndex); } else { if (!inputs.empty()) firstIndex = inputs.front().cast().getArgNumber(); @@ -267,36 +254,42 @@ branch, RegionSuccessor(region, region->getArguments().slice( firstIndex, inputs.size())), - lattices, firstIndex); + elements, firstIndex); } } - for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) + for (auto it : llvm::zip(*operands, elements.drop_front(firstIndex))) join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); } } -const 
AbstractSparseLattice * +const AbstractSparseState * AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, Value value) { - AbstractSparseLattice *state = getLatticeElement(value); - addDependency(state, point); - return state; + AbstractSparseElement *element = getLatticeElement(value); + element->addDependency(this, point); + return element->get(); } void AbstractSparseDataFlowAnalysis::markPessimisticFixpoint( - AbstractSparseLattice *lattice) { - propagateIfChanged(lattice, lattice->markPessimisticFixpoint()); + AbstractSparseElement *element) { + element->update(this, [](AbstractState *state) { + return static_cast(state) + ->markPessimisticFixpoint(); + }); } void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( - ArrayRef lattices) { - for (AbstractSparseLattice *lattice : lattices) { - markPessimisticFixpoint(lattice); - } + ArrayRef elements) { + for (AbstractSparseElement *element : elements) + markPessimisticFixpoint(element); } -void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, - const AbstractSparseLattice &rhs) { - propagateIfChanged(lhs, lhs->join(rhs)); +void AbstractSparseDataFlowAnalysis::join( + AbstractSparseElement *lhs, + const AbstractSparseState &rhs) { + lhs->update(this, [&rhs](AbstractState *lhsState) { + return static_cast(lhsState)->join(rhs); + }); } + diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -24,12 +24,6 @@ GenericProgramPoint::~GenericProgramPoint() = default; -//===----------------------------------------------------------------------===// -// AnalysisState -//===----------------------------------------------------------------------===// - -AnalysisState::~AnalysisState() = default; - //===----------------------------------------------------------------------===// // ProgramPoint 
//===----------------------------------------------------------------------===// @@ -58,6 +52,34 @@ return get()->getParent()->getLoc(); } +//===----------------------------------------------------------------------===// +// AbstractState and AbstractElement +//===----------------------------------------------------------------------===// + +AbstractState::~AbstractState() = default; +AbstractElement::~AbstractElement() = default; + +void AbstractElement::addDependency(DataFlowAnalysis *analysis, + ProgramPoint point) { + auto inserted = dependents.insert({point, analysis}); + (void)inserted; + DATAFLOW_DEBUG({ + if (inserted) { + llvm::dbgs() << "Adding dependency from " << debugName << " of " + << this->point << " to " << analysis->debugName << " on " + << point << "\n"; + } + }); +} + +void AbstractElement::propagateUpdate() { + DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << debugName << " of " + << point << "\nValue: " << *get() << "\n"); + for (auto &item : dependents) + solver.enqueue(item); + onUpdate(); +} + //===----------------------------------------------------------------------===// // DataFlowSolver //===----------------------------------------------------------------------===// @@ -94,30 +116,12 @@ return success(); } -void DataFlowSolver::propagateIfChanged(AnalysisState *state, - ChangeResult changed) { - if (changed == ChangeResult::Change) { - DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->point << "\n" - << "Value: " << *state << "\n"); - for (const WorkItem &item : state->dependents) - enqueue(item); - state->onUpdate(this); - } -} - -void DataFlowSolver::addDependency(AnalysisState *state, - DataFlowAnalysis *analysis, - ProgramPoint point) { - auto inserted = state->dependents.insert({point, analysis}); - (void)inserted; - DATAFLOW_DEBUG({ - if (inserted) { - llvm::dbgs() << "Creating dependency between " << state->debugName - << " of " << state->point << "\nand " << 
analysis->debugName - << " on " << point << "\n"; - } - }); +void DataFlowSolver::getStaticProvidersFor( + TypeID stateID, ProgramPoint point, + SmallVectorImpl &staticProviders) const { + for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) + if (analysis.staticallyProvides(stateID, point)) + staticProviders.push_back(&analysis); } //===----------------------------------------------------------------------===// @@ -127,12 +131,3 @@ DataFlowAnalysis::~DataFlowAnalysis() = default; DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {} - -void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) { - solver.addDependency(state, this, point); -} - -void DataFlowAnalysis::propagateIfChanged(AnalysisState *state, - ChangeResult changed) { - solver.propagateIfChanged(state, changed); -} diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp @@ -23,8 +23,8 @@ /// bound on its value (if it is treated as signed) and that bound is /// non-negative. 
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { - auto *result = solver.lookupState(v); - if (!result) + auto *result = solver.lookup(v); + if (!result || result->isUninitialized()) return failure(); const ConstantIntRanges &range = result->getValue().getValue(); return success(range.smin().isNonNegative()); @@ -113,6 +113,7 @@ DataFlowSolver solver; solver.load(); solver.load(); + solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -37,7 +37,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value) { - auto *lattice = solver.lookupState>(value); + auto *lattice = solver.lookup(value); if (!lattice || lattice->isUninitialized()) return failure(); const ConstantValue &latticeValue = lattice->getValue(); diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -29,7 +29,7 @@ os << " "; block.printAsOperand(os); os << " = "; - auto *live = solver.lookupState(&block); + auto *live = solver.lookup(&block); if (live) os << *live; else @@ -39,7 +39,7 @@ os << " from "; pred->printAsOperand(os); os << " = "; - auto *live = solver.lookupState( + auto *live = solver.lookup( solver.getProgramPoint(pred, &block)); if (live) os << *live; @@ -49,12 +49,12 @@ } } if (!region.empty()) { - auto *preds = solver.lookupState(&region.front()); + auto *preds = solver.lookup(&region.front()); if (preds) os << "region_preds: " << *preds << "\n"; } } - auto *preds = solver.lookupState(op); + auto *preds = solver.lookup(op); if (preds) os << "op_preds: " << *preds << "\n"; }); @@ -79,9 +79,10 @@ Operation *op = point.get();
Attribute value; if (matchPattern(op, m_Constant(&value))) { - auto *constant = getOrCreate>(op->getResult(0)); - propagateIfChanged( - constant, constant->join(ConstantValue(value, op->getDialect()))); + auto *constant = getOrCreate(op->getResult(0)); + constant->update(this, [value, op](ConstantValueState *state) { + return state->join(ConstantValue(value, op->getDialect())); + }); return success(); } markAllPessimisticFixpoint(op->getResults()); @@ -94,9 +95,10 @@ /// pessimistic fixpoint. void markAllPessimisticFixpoint(ValueRange values) { for (Value value : values) { - auto *constantValue = getOrCreate>(value); - propagateIfChanged(constantValue, - constantValue->markPessimisticFixpoint()); + auto *constant = getOrCreate(value); + constant->update(this, [](ConstantValueState *state) { + return state->markPessimisticFixpoint(); + }); } } }; diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp @@ -20,9 +20,7 @@ class UnderlyingValue { public: /// The pessimistic underlying value of a value is itself. - static UnderlyingValue getPessimisticValueState(Value value) { - return {value}; - } + static UnderlyingValue getPessimisticValue(Value value) { return {value}; } /// Create an underlying value state with a known underlying value. UnderlyingValue(Value underlyingValue = {}) @@ -51,21 +49,20 @@ /// This lattice represents, for a given memory resource, the potential last /// operations that modified the resource. 
-class LastModification : public AbstractDenseLattice { +class LastModification : public AbstractDenseState { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) - using AbstractDenseLattice::AbstractDenseLattice; + using ElementT = SingleStateElement; + + explicit LastModification(ProgramPoint point) {} /// The lattice is always initialized. bool isUninitialized() const override { return false; } - /// Initialize the lattice. Does nothing. - ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } - /// Mark the lattice as having reached its pessimistic fixpoint. That is, the /// last modifications of all memory resources are unknown. - ChangeResult reset() override { + ChangeResult markPessimisticFixpoint() override { if (lastMods.empty()) return ChangeResult::NoChange; lastMods.clear(); @@ -76,7 +73,7 @@ bool isAtFixpoint() const override { return false; } /// Join the last modifications. - ChangeResult join(const AbstractDenseLattice &lattice) override { + ChangeResult join(const AbstractDenseState &lattice) override { const auto &rhs = static_cast(lattice); ChangeResult result = ChangeResult::NoChange; for (const auto &mod : rhs.lastMods) { @@ -135,13 +132,15 @@ /// its reaching definitions is set to empty. If the operation writes to a /// resource, then its reaching definition is set to the written value. void visitOperation(Operation *op, const LastModification &before, - LastModification *after) override; + LastModification::ElementT *after) override; }; /// Define the lattice class explicitly to provide a type ID. 
-struct UnderlyingValueLattice : public Lattice { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) - using Lattice::Lattice; +struct UnderlyingValueState : public OptimisticSparseState { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueState) + + using OptimisticSparseState::OptimisticSparseState; + using ElementT = SparseElement; }; /// An analysis that uses forwarding of values along control-flow and callgraph @@ -149,14 +148,14 @@ /// analysis exists so that the test analysis and pass can test the behaviour of /// the dense data-flow analysis on the callgraph. class UnderlyingValueAnalysis - : public SparseDataFlowAnalysis { + : public SparseDataFlowAnalysis { public: using SparseDataFlowAnalysis::SparseDataFlowAnalysis; /// The underlying value of the results of an operation are not known. - void visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) override { + void + visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override { markAllPessimisticFixpoint(results); } }; @@ -165,8 +164,8 @@ /// Look for the most underlying value of a value. static Value getMostUnderlyingValue( Value value, - function_ref getUnderlyingValueFn) { - const UnderlyingValueLattice *underlying; + function_ref getUnderlyingValueFn) { + const UnderlyingValueState *underlying; do { underlying = getUnderlyingValueFn(value); if (!underlying || underlying->isUninitialized()) @@ -181,38 +180,42 @@ void LastModifiedAnalysis::visitOperation(Operation *op, const LastModification &before, - LastModification *after) { + LastModification::ElementT *after) { auto memory = dyn_cast(op); // If we can't reason about the memory effects, then conservatively assume we // can't deduce anything about the last modifications. 
if (!memory) - return reset(after); + return markPessimisticFixpoint(after); SmallVector effects; memory.getEffects(effects); - ChangeResult result = after->join(before); - for (const auto &effect : effects) { - Value value = effect.getValue(); + after->update(this, [&](LastModification *state) { + ChangeResult result = state->join(before); + for (const auto &effect : effects) { + Value value = effect.getValue(); - // If we see an effect on anything other than a value, assume we can't - // deduce anything about the last modifications. - if (!value) - return reset(after); + // If we see an effect on anything other than a value, assume we can't + // deduce anything about the last modifications. + if (!value) { + result |= state->markPessimisticFixpoint(); + break; + } - value = getMostUnderlyingValue(value, [&](Value value) { - return getOrCreateFor(op, value); - }); - if (!value) - return; + value = getMostUnderlyingValue(value, [&](Value value) { + return getOrCreateFor(op, value); + }); + if (!value) + return ChangeResult::NoChange; - // Nothing to do for reads. - if (isa(effect.getEffect())) - continue; + // Nothing to do for reads. 
+ if (isa(effect.getEffect())) + continue; - result |= after->set(value, op); - } - propagateIfChanged(after, result); + result |= state->set(value, op); + } + return result; + }); } namespace { @@ -240,13 +243,12 @@ if (!tag) return; os << "test_tag: " << tag.getValue() << ":\n"; - const LastModification *lastMods = - solver.lookupState(op); + const auto *lastMods = solver.lookup(op); assert(lastMods && "expected a dense lattice"); for (auto &it : llvm::enumerate(op->getOperands())) { os << " operand #" << it.index() << "\n"; Value value = getMostUnderlyingValue(it.value(), [&](Value value) { - return solver.lookupState(value); + return solver.lookup(value); }); assert(value && "expected an underlying value"); if (Optional> lastMod = diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -14,17 +14,15 @@ namespace { /// This analysis state represents an integer that is XOR'd with other states. -class FooState : public AnalysisState { +class FooState : public AbstractState { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) - using AnalysisState::AnalysisState; + using ElementT = SingleStateElement; - /// Default-initialize the state to zero. - ChangeResult defaultInitialize() override { return join(0); } + explicit FooState(ProgramPoint point) {} - /// Returns true if the state is uninitialized. - bool isUninitialized() const override { return !state; } + bool isUninitialized() const { return !state; } /// Print the integer value or "none" if uninitialized. void print(raw_ostream &os) const override { @@ -99,7 +97,8 @@ return top->emitError("expected a single region top-level op"); // Initialize the top-level state. 
- getOrCreate(&top->getRegion(0).front())->join(0); + update(&top->getRegion(0).front(), + [](FooState *state) { return state->join(0); }); // Visit all nested blocks and operations. for (Block &block : top->getRegion(0)) { @@ -130,35 +129,37 @@ // This is the initial state. Let the framework default-initialize it. return; } - FooState *state = getOrCreate(block); - ChangeResult result = ChangeResult::NoChange; - for (Block *pred : block->getPredecessors()) { - // Join the state at the terminators of all predecessors. - const FooState *predState = - getOrCreateFor(block, pred->getTerminator()); - result |= state->join(*predState); - } - propagateIfChanged(state, result); + update(block, [&](FooState *state) { + ChangeResult result = ChangeResult::NoChange; + for (Block *pred : block->getPredecessors()) { + // Join the state at the terminators of all predecessors. + const FooState *predState = + getOrCreateFor(block, pred->getTerminator()); + result |= state->join(*predState); + } + return result; + }); } void FooAnalysis::visitOperation(Operation *op) { - FooState *state = getOrCreate(op); - ChangeResult result = ChangeResult::NoChange; - - // Copy the state across the operation. - const FooState *prevState; - if (Operation *prev = op->getPrevNode()) - prevState = getOrCreateFor(op, prev); - else - prevState = getOrCreateFor(op, op->getBlock()); - result |= state->set(*prevState); - - // Modify the state with the attribute, if specified. - if (auto attr = op->getAttrOfType("foo")) { - uint64_t value = attr.getUInt(); - result |= state->join(value); - } - propagateIfChanged(state, result); + update(op, [&](FooState *state) { + ChangeResult result = ChangeResult::NoChange; + + // Copy the state across the operation. 
+ const FooState *prevState; + if (Operation *prev = op->getPrevNode()) + prevState = getOrCreateFor(op, prev); + else + prevState = getOrCreateFor(op, op->getBlock()); + result |= state->set(*prevState); + + // Modify the state with the attribute, if specified. + if (auto attr = op->getAttrOfType("foo")) { + uint64_t value = attr.getUInt(); + result |= state->join(value); + } + return result; + }); } void TestFooAnalysisPass::runOnOperation() { @@ -175,7 +176,7 @@ auto tag = op->getAttrOfType("tag"); if (!tag) return; - const FooState *state = solver.lookupState(op); + const FooState *state = solver.lookup(op); assert(state && !state->isUninitialized()); os << tag.getValue() << " -> " << state->getValue() << "\n"; }); diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp --- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp +++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp @@ -23,23 +23,15 @@ /// Patterned after SCCP static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, OperationFolder &folder, Value value) { - auto *maybeInferredRange = - solver.lookupState(value); - if (!maybeInferredRange || maybeInferredRange->isUninitialized()) - return failure(); - const ConstantIntRanges &inferredRange = - maybeInferredRange->getValue().getValue(); - Optional maybeConstValue = inferredRange.getConstantValue(); - if (!maybeConstValue.hasValue()) + auto *constantState = solver.lookup(value); + if (!constantState || constantState->isUninitialized() || + !constantState->getValue().getConstantValue()) return failure(); - Operation *maybeDefiningOp = value.getDefiningOp(); - Dialect *valueDialect = - maybeDefiningOp ? 
maybeDefiningOp->getDialect() - : value.getParentRegion()->getParentOp()->getDialect(); - Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); - Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, - value.getType(), value.getLoc()); + const ConstantValue &constantValue = constantState->getValue(); + Value constant = folder.getOrCreateConstant( + b, constantValue.getConstantDialect(), constantValue.getConstantValue(), + value.getType(), value.getLoc()); if (!constant) return failure(); @@ -106,6 +98,7 @@ DataFlowSolver solver; solver.load(); solver.load(); + solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); rewrite(solver, op->getContext(), op->getRegions());