diff --git a/mlir/docs/Tutorials/DataFlowAnalysis.md b/mlir/docs/Tutorials/DataFlowAnalysis.md new file mode 100644 --- /dev/null +++ b/mlir/docs/Tutorials/DataFlowAnalysis.md @@ -0,0 +1,293 @@ +# Writing DataFlow Analyses in MLIR + +Writing dataflow analyses in MLIR, or indeed any compiler, can often seem quite +daunting and/or complex. A dataflow analysis generally involves propagating +information about the IR across various different types of control flow +constructs, of which MLIR has many (Block-based branches, Region-based branches, +CallGraph, etc.), and it isn't always clear how best to go about performing the +propagation. To help with writing these types of analyses in MLIR, this document +details several utilities that simplify the process and make it a bit more +approachable. + +## Forward Dataflow Analysis + +One type of dataflow analysis is a forward propagation analysis. This type of +analysis, as the name may suggest, propagates information forward (e.g. from +definitions to uses). To provide a bit of concrete context, let's go over +writing a simple forward dataflow analysis in MLIR. Let's say for this analysis +that we want to propagate information about a special "metadata" dictionary +attribute. The contents of this attribute are simply a set of metadata that +describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will +collect the `metadata` for operations in the IR and propagate them across the IR. + +### Lattices + +Before going into how one might set up the analysis itself, it is important to +first introduce the concept of a `Lattice` and how we will use it for the +analysis. A lattice represents all of the possible values or results of the +analysis for a given value. A lattice element holds the set of information +computed by the analysis for a given value, and is what gets propagated across +the IR. For our analysis, this would correspond to the `metadata` dictionary +attribute.
+ +Regardless of the value held within, every type of lattice contains two special +element states: + +* `uninitialized` + + - The element has not been initialized. + +* `top`/`overdefined`/`unknown` + + - The element encompasses every possible value. + - This is a very conservative state, and essentially means "I can't make + any assumptions about the value, it could be anything" + +These two states are important when merging, or `join`ing as we will refer to it +further in this document, information as part of the analysis. Lattice elements +are `join`ed whenever there are two different source points, such as an argument +to a block with multiple predecessors. One important note about the `join` +operation, is that it is required to be monotonic (see the `join` method in the +example below for more information). This ensures that `join`ing elements is +consistent. The two special states mentioned above have unique properties during +a `join`: + +* `uninitialized` + + - If one of the elements is `uninitialized`, the other element is used. + - `uninitialized` in the context of a `join` essentially means "take the + other thing". + +* `top`/`overdefined`/`unknown` + + - If one of the elements being joined is `overdefined`, the result is + `overdefined`. + +For our analysis in MLIR, we will need to define a class representing the value +held by an element of the lattice used by our dataflow analysis: + +```c++ +/// The value of our lattice represents the inner structure of a DictionaryAttr, +/// for the `metadata`. +struct MetadataLatticeValue { + MetadataLatticeValue() = default; + /// Compute a lattice value from the provided dictionary. + MetadataLatticeValue(DictionaryAttr attr) + : metadata(attr.begin(), attr.end()) {} + + /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown` + /// state, for our value type. The resultant state should not assume any + /// information about the state of the IR. 
+ static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) { + // The `top`/`overdefined`/`unknown` state is when we know nothing about any + // metadata, i.e. an empty dictionary. + return MetadataLatticeValue(); + } + /// Return a pessimistic value state for our value type using only information + /// about the state of the provided IR. This is similar to the above method, + /// but may produce a slightly more refined result. This is okay, as the + /// information is already encoded as fact in the IR. + static MetadataLatticeValue getPessimisticValueState(Value value) { + // Check to see if the parent operation has metadata. + if (Operation *parentOp = value.getDefiningOp()) { + if (auto metadata = parentOp->getAttrOfType("metadata")) + return MetadataLatticeValue(metadata); + + // If no metadata is present, fall back to the + // `top`/`overdefined`/`unknown` state. + } + return MetadataLatticeValue(); + } + + /// This method conservatively joins the information held by `lhs` and `rhs` + /// into a new value. This method is required to be monotonic. `monotonicity` + /// is implied by the satisfaction of the following axioms: + /// * idempotence: join(x,x) == x + /// * commutativity: join(x,y) == join(y,x) + /// * associativity: join(x,join(y,z)) == join(join(x,y),z) + /// + /// When the above axioms are satisfied, we achieve `monotonicity`: + /// * monotonicity: join(x, join(x,y)) == join(x,y) + static MetadataLatticeValue join(const MetadataLatticeValue &lhs, + const MetadataLatticeValue &rhs) { + // To join `lhs` and `rhs` we will define a simple policy, which is that we + // only keep information that is the same. This means that we only keep + // facts that are true in both. + MetadataLatticeValue result; + for (const auto &lhsIt : lhs.metadata) { + // As noted above, we only merge if the values are the same.
+ auto it = rhs.metadata.find(lhsIt.first); + if (it == rhs.metadata.end() || it->second != lhsIt.second) + continue; + result.metadata.insert(lhsIt); + } + return result; + } + + /// A simple comparator that checks to see if this value is equal to the one + /// provided. + bool operator==(const MetadataLatticeValue &rhs) const { + if (metadata.size() != rhs.metadata.size()) + return false; + // Check that the 'rhs' maps each key to the same (non-null) value. + return llvm::all_of(metadata, [&](auto &it) { + return rhs.metadata.lookup(it.first) == it.second; + }); + } + + /// Our value represents the combined metadata, which is originally a + /// DictionaryAttr, so we use a map. + DenseMap metadata; +}; +``` + +One interesting thing to note above is that we don't have an explicit method for +the `uninitialized` state. This state is handled by the `LatticeElement` class, +which manages a lattice value for a given IR entity. A quick overview of this +class, and the API that will be interesting to us while writing our analysis, is +shown below: + +```c++ +/// This class represents a lattice element holding a specific value of type +/// `ValueT`. +template +class LatticeElement ... { +public: + /// Return the value held by this element. This requires that a value is + /// known, i.e. not `uninitialized`. + ValueT &getValue(); + const ValueT &getValue() const; + + /// Join the information contained in the 'rhs' element into this + /// element. Returns if the state of the current element changed. + ChangeResult join(const LatticeElement &rhs); + + /// Join the information contained in the 'rhs' value into this + /// lattice. Returns if the state of the current lattice changed. + ChangeResult join(const ValueT &rhs); + + /// Mark the lattice element as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have conflicting value states, and + /// only the conservatively known value state should be relied on.
+ ChangeResult markPessimisticFixpoint(); +}; +``` + +With our lattice defined, we can now define the driver that will compute and +propagate our lattice across the IR. + +### ForwardDataFlowAnalysis Driver + +The `ForwardDataFlowAnalysis` class represents the driver of the dataflow +analysis, and performs all of the related analysis computation. When defining +our analysis, we will inherit from this class and implement some of its hooks. +Before that, let's look at a quick overview of this class and some of the +important API for our analysis: + +```c++ +/// This class represents the main driver of the forward dataflow analysis. It +/// takes as a template parameter the value type of lattice being computed. +template +class ForwardDataFlowAnalysis : ... { +public: + ForwardDataFlowAnalysis(MLIRContext *context); + + /// Compute the analysis on operations rooted under the given top-level + /// operation. Note that the top-level operation is not visited. + void run(Operation *topLevelOp); + + /// Return the lattice element attached to the given value. If a lattice has + /// not been added for the given value, a new 'uninitialized' value is + /// inserted and returned. + LatticeElement &getLatticeElement(Value value); + + /// Return the lattice element attached to the given value, or nullptr if no + /// lattice element for the value has yet been created. + LatticeElement *lookupLatticeElement(Value value); + + /// Mark all of the lattice elements for the given range of Values as having + /// reached a pessimistic fixpoint. + ChangeResult markAllPessimisticFixpoint(ValueRange values); + +protected: + /// Visit the given operation, and join any necessary analysis state + /// into the lattice elements for the results and block arguments owned by + /// this operation using the provided set of operand lattice elements + /// (all pointer values are guaranteed to be non-null). Returns if any result + /// or block argument value lattice elements changed during the visit.
 The + /// lattice element for a result or block argument value can be obtained, and + /// join'ed into, by using `getLatticeElement`. + virtual ChangeResult visitOperation( + Operation *op, ArrayRef *> operands) = 0; +}; +``` + +NOTE: Some API has been redacted for our example. The `ForwardDataFlowAnalysis` +contains various other hooks that allow for injecting custom behavior when +applicable. + +The main API that we are responsible for defining is the `visitOperation` +method. This method is responsible for computing new lattice elements for the +results and block arguments owned by the given operation. This is where we will +inject the lattice element computation logic, also known as the transfer +function for the operation, that is specific to our analysis. A simple +implementation for our example is shown below: + +```c++ +class MetadataAnalysis : public ForwardDataFlowAnalysis { +public: + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + + ChangeResult visitOperation( + Operation *op, ArrayRef *> operands) override { + DictionaryAttr metadata = op->getAttrOfType("metadata"); + + // If we have no metadata for this operation, we will conservatively mark + // all of the results as having reached a pessimistic fixpoint. + if (!metadata) + return markAllPessimisticFixpoint(op->getResults()); + + // Otherwise, we will compute a lattice value for the metadata and join it + // into the current lattice element for all of our results. + MetadataLatticeValue latticeValue(metadata); + ChangeResult result = ChangeResult::NoChange; + for (Value value : op->getResults()) { + // We grab the lattice element for `value` via `getLatticeElement` and + // then join it with the lattice value for this operation's metadata. Note + // that during the analysis phase, it is fine to freely create a new + // lattice element for a value. This is why we don't use the + // `lookupLatticeElement` method here.
+ result |= getLatticeElement(value).join(latticeValue); + } + return result; + } +}; +``` + +With that, we have all of the necessary components to compute our analysis. +After the analysis has been computed, we can grab any computed information for +values by using `lookupLatticeElement`. We use this function over +`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g. +if the value is in an unreachable block, and we don't want to create a new +uninitialized lattice element in this case. See below for a quick example: + +```c++ +void MyPass::runOnOperation() { + MetadataAnalysis analysis(&getContext()); + analysis.run(getOperation()); + ... +} + +void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) { + LatticeElement *lattice = analysis.lookupLatticeElement(value); + + // If we don't have an element, the `value` wasn't visited during our analysis + // meaning that it could be dead. We need to treat this conservatively. + if (!lattice) + return; + + // Our lattice element has a value, use it: + MetadataLatticeValue &latticeValue = lattice->getValue(); + ... +} +``` diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h @@ -0,0 +1,401 @@ +//===- DataFlowAnalysis.h - General DataFlow Analysis Utilities -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains utilities and algorithms that perform abstract dataflow +// analysis over the IR.
 These allow for users to hook into various analysis +// propagation algorithms without needing to reinvent the traversal over the +// different types of control structures present within MLIR, such as regions, +// the callgraph, etc. A few of the main entry points are detailed below: +// +// ForwardDataFlowAnalysis: +// This class provides support for defining dataflow algorithms that are +// forward, sparse, pessimistic (except along unreached backedges) and +// context-insensitive for the interprocedural aspects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H +#define MLIR_ANALYSIS_DATAFLOWANALYSIS_H + +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/Support/Allocator.h" + +namespace mlir { +//===----------------------------------------------------------------------===// +// ChangeResult +//===----------------------------------------------------------------------===// + +/// A result type used to indicate if a change happened. Boolean operations on +/// ChangeResult behave as though `Change` is truthy. +enum class ChangeResult { + NoChange, + Change, +}; +inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) { + return lhs == ChangeResult::Change ? lhs : rhs; +} +inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) { + lhs = lhs | rhs; + return lhs; +} +inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) { + return lhs == ChangeResult::NoChange ? lhs : rhs; +} + +//===----------------------------------------------------------------------===// +// AbstractLatticeElement +//===----------------------------------------------------------------------===// + +namespace detail { +/// This class represents an abstract lattice. A lattice is what gets propagated +/// across the IR, and contains the information for a specific Value.
+class AbstractLatticeElement { +public: + virtual ~AbstractLatticeElement(); + + /// Returns true if the value of this lattice is uninitialized, meaning that + /// it hasn't yet been initialized. + virtual bool isUninitialized() const = 0; + + /// Join the information contained in 'rhs' into this lattice. Returns + /// whether the value of the lattice changed. + virtual ChangeResult join(const AbstractLatticeElement &rhs) = 0; + + /// Mark the lattice element as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have conflicting value states, and + /// only the most conservative value should be relied on. + virtual ChangeResult markPessimisticFixpoint() = 0; + + /// Mark the lattice element as having reached an optimistic fixpoint. This + /// means that we optimistically assume the current value is the true state. + virtual void markOptimisticFixpoint() = 0; + + /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the + /// information optimistically assumed to be true is the same as the + /// information known to be true. + virtual bool isAtFixpoint() const = 0; +}; +} // namespace detail + +//===----------------------------------------------------------------------===// +// LatticeElement +//===----------------------------------------------------------------------===// + +/// This class represents a lattice holding a specific value of type `ValueT`. +/// Lattice values (`ValueT`) are required to adhere to the following: +/// * static ValueT join(const ValueT &lhs, const ValueT &rhs); +/// - This method conservatively joins the information held by `lhs` +/// and `rhs` into a new value. This method is required to be monotonic. +/// * static ValueT getPessimisticValueState(MLIRContext *context); +/// - This method computes a pessimistic/conservative value state assuming +/// no information about the state of the IR.
+/// * static ValueT getPessimisticValueState(Value value); +/// - This method computes a pessimistic/conservative value state for +/// `value` assuming only information present in the current IR. +/// * bool operator==(const ValueT &rhs) const; +/// +template +class LatticeElement final : public detail::AbstractLatticeElement { +public: + LatticeElement() = delete; + LatticeElement(const ValueT &knownValue) : knownValue(knownValue) {} + + /// Return the value held by this lattice. This requires that the value is + /// initialized. + ValueT &getValue() { + assert(!isUninitialized() && "expected known lattice element"); + return *optimisticValue; + } + const ValueT &getValue() const { + assert(!isUninitialized() && "expected known lattice element"); + return *optimisticValue; + } + + /// Returns true if the value of this lattice hasn't yet been initialized. + bool isUninitialized() const final { return !optimisticValue.hasValue(); } + + /// Join the information contained in the 'rhs' lattice into this + /// lattice. Returns if the state of the current lattice changed. + ChangeResult join(const detail::AbstractLatticeElement &rhs) final { + const LatticeElement &rhsLattice = + static_cast &>(rhs); + + // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do. + if (isAtFixpoint() || rhsLattice.isUninitialized()) + return ChangeResult::NoChange; + + // Join the rhs value into this lattice. + return join(rhsLattice.getValue()); + } + + /// Join the information contained in the 'rhs' value into this + /// lattice. Returns if the state of the current lattice changed. + ChangeResult join(const ValueT &rhs) { + // If the current lattice is uninitialized, copy the rhs value. + if (isUninitialized()) { + optimisticValue = rhs; + return ChangeResult::Change; + } + + // Otherwise, join rhs with the current optimistic value. 
+ ValueT newValue = ValueT::join(*optimisticValue, rhs); + assert(ValueT::join(newValue, *optimisticValue) == newValue && + "expected `join` to be monotonic"); + assert(ValueT::join(newValue, rhs) == newValue && + "expected `join` to be monotonic"); + + // Update the current optimistic value if something changed. + if (newValue == optimisticValue) + return ChangeResult::NoChange; + + optimisticValue = newValue; + return ChangeResult::Change; + } + + /// Mark the lattice element as having reached a pessimistic fixpoint. This + /// means that the lattice may potentially have conflicting value states, and + /// only the conservatively known value state should be relied on. + ChangeResult markPessimisticFixpoint() final { + if (isAtFixpoint()) + return ChangeResult::NoChange; + + // For this fixed point, we take whatever we knew to be true and set that to + // our optimistic value. + optimisticValue = knownValue; + return ChangeResult::Change; + } + + /// Mark the lattice element as having reached an optimistic fixpoint. This + /// means that we optimistically assume the current value is the true state. + void markOptimisticFixpoint() final { + assert(!isUninitialized() && "expected an initialized value"); + knownValue = *optimisticValue; + } + + /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the + /// information optimistically assumed to be true is the same as the + /// information known to be true. + bool isAtFixpoint() const final { return optimisticValue == knownValue; } + +private: + /// The value that is conservatively known to be true. + ValueT knownValue; + /// The currently computed value that is optimistically assumed to be true, or + /// None if the lattice element is uninitialized.
+ Optional optimisticValue; +}; + +//===----------------------------------------------------------------------===// +// ForwardDataFlowAnalysisBase +//===----------------------------------------------------------------------===// + +namespace detail { +/// This class is the non-templated virtual base class for the +/// ForwardDataFlowAnalysis. This class provides opaque hooks to the main +/// algorithm. +class ForwardDataFlowAnalysisBase { +public: + virtual ~ForwardDataFlowAnalysisBase(); + + /// Initialize and compute the analysis on operations rooted under the given + /// top-level operation. Note that the top-level operation is not visited. + void run(Operation *topLevelOp); + + /// Return the lattice element attached to the given value. If a lattice has + /// not been added for the given value, a new 'uninitialized' value is + /// inserted and returned. + AbstractLatticeElement &getLatticeElement(Value value); + + /// Return the lattice element attached to the given value, or nullptr if no + /// lattice for the value has yet been created. + AbstractLatticeElement *lookupLatticeElement(Value value); + + /// Visit the given operation, and join any necessary analysis state + /// into the lattices for the results and block arguments owned by this + /// operation using the provided set of operand lattice elements (all pointer + /// values are guaranteed to be non-null). Returns if any result or block + /// argument value lattices changed during the visit. The lattice for a result + /// or block argument value can be obtained and join'ed into by using + /// `getLatticeElement`. + virtual ChangeResult + visitOperation(Operation *op, + ArrayRef operands) = 0; + + /// Given a BranchOpInterface, and the current lattice elements that + /// correspond to the branch operands (all pointer values are guaranteed to be + /// non-null), try to compute a specific set of successors that would be + /// selected for the branch.
 Returns failure if not computable, or if all of + /// the successors would be chosen. If a subset of successors can be selected, + /// `successors` is populated. + virtual LogicalResult + getSuccessorsForOperands(BranchOpInterface branch, + ArrayRef operands, + SmallVectorImpl &successors) = 0; + + /// Given a RegionBranchOpInterface, and the current lattice elements that + /// correspond to the branch operands (all pointer values are guaranteed to be + /// non-null), compute a specific set of region successors that would be + /// selected. + virtual void + getSuccessorsForOperands(RegionBranchOpInterface branch, + Optional sourceIndex, + ArrayRef operands, + SmallVectorImpl &successors) = 0; + + /// Create a new uninitialized lattice element. An optional value is provided + /// which, if valid, should be used to initialize the known conservative state + /// of the lattice. + virtual AbstractLatticeElement *createLatticeElement(Value value = {}) = 0; + +private: + /// A map from SSA value to lattice element. + DenseMap latticeValues; +}; +} // namespace detail + +//===----------------------------------------------------------------------===// +// ForwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// This class provides a general forward dataflow analysis driver +/// utilizing the lattice classes defined above, to enable the easy definition +/// of dataflow analysis algorithms. More specifically this driver is useful for +/// defining analyses that are forward, sparse, pessimistic (except along +/// unreached backedges) and context-insensitive for the interprocedural +/// aspects. +template +class ForwardDataFlowAnalysis : public detail::ForwardDataFlowAnalysisBase { +public: + ForwardDataFlowAnalysis(MLIRContext *context) : context(context) {} + + /// Return the MLIR context used when constructing this analysis.
+ MLIRContext *getContext() { return context; } + + /// Compute the analysis on operations rooted under the given top-level + /// operation. Note that the top-level operation is not visited. + void run(Operation *topLevelOp) { + detail::ForwardDataFlowAnalysisBase::run(topLevelOp); + } + + /// Return the lattice element attached to the given value, or nullptr if no + /// lattice for the value has yet been created. + LatticeElement *lookupLatticeElement(Value value) { + return static_cast *>( + detail::ForwardDataFlowAnalysisBase::lookupLatticeElement(value)); + } + +protected: + /// Return the lattice element attached to the given value. If a lattice has + /// not been added for the given value, a new 'uninitialized' value is + /// inserted and returned. + LatticeElement &getLatticeElement(Value value) { + return static_cast &>( + detail::ForwardDataFlowAnalysisBase::getLatticeElement(value)); + } + + /// Mark all of the lattices for the given range of Values as having reached a + /// pessimistic fixpoint. + ChangeResult markAllPessimisticFixpoint(ValueRange values) { + ChangeResult result = ChangeResult::NoChange; + for (Value value : values) + result |= getLatticeElement(value).markPessimisticFixpoint(); + return result; + } + + /// Visit the given operation, and join any necessary analysis state + /// into the lattices for the results and block arguments owned by this + /// operation using the provided set of operand lattice elements (all pointer + /// values are guaranteed to be non-null). Returns if any result or block + /// argument value lattices changed during the visit. The lattice for a result + /// or block argument value can be obtained by using + /// `getLatticeElement`. 
+ virtual ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) = 0; + + /// Given a BranchOpInterface, and the current lattice elements that + /// correspond to the branch operands (all pointer values are guaranteed to be + /// non-null), try to compute a specific set of successors that would be + /// selected for the branch. Returns failure if not computable, or if all of + /// the successors would be chosen. If a subset of successors can be selected, + /// `successors` is populated. + virtual LogicalResult + getSuccessorsForOperands(BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) { + return failure(); + } + + /// Given a RegionBranchOpInterface, and the current lattice elements that + /// correspond to the branch operands (all pointer values are guaranteed to be + /// non-null), compute a specific set of region successors that would be + /// selected. + virtual void + getSuccessorsForOperands(RegionBranchOpInterface branch, + Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) { + SmallVector constantOperands(operands.size()); + branch.getSuccessorRegions(sourceIndex, constantOperands, successors); + } + +private: + /// Type-erased wrappers that convert the abstract lattice operands to derived + /// lattices and invoke the virtual hooks operating on the derived lattices. 
+ ChangeResult + visitOperation(Operation *op, + ArrayRef operands) final { + LatticeElement *const *derivedOperandBase = + reinterpret_cast *const *>(operands.data()); + return visitOperation( + op, llvm::makeArrayRef(derivedOperandBase, operands.size())); + } + LogicalResult + getSuccessorsForOperands(BranchOpInterface branch, + ArrayRef operands, + SmallVectorImpl &successors) final { + LatticeElement *const *derivedOperandBase = + reinterpret_cast *const *>(operands.data()); + return getSuccessorsForOperands( + branch, llvm::makeArrayRef(derivedOperandBase, operands.size()), + successors); + } + void + getSuccessorsForOperands(RegionBranchOpInterface branch, + Optional sourceIndex, + ArrayRef operands, + SmallVectorImpl &successors) final { + LatticeElement *const *derivedOperandBase = + reinterpret_cast *const *>(operands.data()); + getSuccessorsForOperands( + branch, sourceIndex, + llvm::makeArrayRef(derivedOperandBase, operands.size()), successors); + } + + /// Create a new uninitialized lattice element. An optional value is provided, + /// which if valid, should be used to initialize the known conservative state + /// of the lattice. + detail::AbstractLatticeElement *createLatticeElement(Value value) final { + ValueT knownValue = value ? ValueT::getPessimisticValueState(value) + : ValueT::getPessimisticValueState(context); + return new (allocator.Allocate()) LatticeElement(knownValue); + } + + /// An allocator used for new lattice elements. + llvm::SpecificBumpPtrAllocator> allocator; + + /// The MLIRContext of this solver. 
+ MLIRContext *context; +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DATAFLOWANALYSIS_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ AffineStructures.cpp BufferAliasAnalysis.cpp CallGraph.cpp + DataFlowAnalysis.cpp LinearTransform.cpp Liveness.cpp LoopAnalysis.cpp @@ -20,6 +21,7 @@ AliasAnalysis.cpp BufferAliasAnalysis.cpp CallGraph.cpp + DataFlowAnalysis.cpp Liveness.cpp NumberOfExecutions.cpp SliceAnalysis.cpp diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -0,0 +1,780 @@ +//===- DataFlowAnalysis.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/SmallPtrSet.h" + +using namespace mlir; +using namespace mlir::detail; + +namespace { +/// This class contains various state used when computing the lattice elements +/// of a callable operation. +class CallableLatticeState { +public: + /// Build a lattice state with a given callable region, and a specified number + /// of results to be initialized to the default lattice element. 
+ CallableLatticeState(ForwardDataFlowAnalysisBase &analysis, + Region *callableRegion, unsigned numResults) + : callableArguments(callableRegion->getArguments()), + resultLatticeElements(numResults) { + for (AbstractLatticeElement *&it : resultLatticeElements) + it = analysis.createLatticeElement(); + } + + /// Returns the arguments to the callable region. + Block::BlockArgListType getCallableArguments() const { + return callableArguments; + } + + /// Returns the lattice element for the results of the callable region. + auto getResultLatticeElements() { + return llvm::make_pointee_range(resultLatticeElements); + } + + /// Add a call to this callable. This is only used if the callable defines a + /// symbol. + void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } + + /// Return the calls that reference this callable. This is only used + /// if the callable defines a symbol. + ArrayRef getSymbolCalls() const { return symbolCalls; } + +private: + /// The arguments of the callable region. + Block::BlockArgListType callableArguments; + + /// The lattice state for each of the results of this region. The return + /// values of the callable aren't SSA values, so we need to track them + /// separately. + SmallVector resultLatticeElements; + + /// The calls referencing this callable if this callable defines a symbol. + /// This removes the need to recompute symbol references during propagation. + /// Value based references are trivial to resolve, so they can be done + /// in-place. + SmallVector symbolCalls; +}; + +/// This class represents the solver for a forward dataflow analysis. This class +/// acts as the propagation engine for computing which lattice elements. +class ForwardDataFlowSolver { +public: + /// Initialize the solver with the given top-level operation. + ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op); + + /// Run the solver until it converges. 
+ void solve(); + +private: + /// Initialize the set of symbol defining callables that can have their + /// arguments and results tracked. 'op' is the top-level operation that the + /// solver is operating on. + void initializeSymbolCallables(Operation *op); + + /// Visit the users of the given IR that reside within executable blocks. + template + void visitUsers(T &value) { + for (Operation *user : value.getUsers()) + if (isBlockExecutable(user->getBlock())) + visitOperation(user); + } + + /// Visit the given operation and compute any necessary lattice state. + void visitOperation(Operation *op); + + /// Visit the given call operation and compute any necessary lattice state. + void visitCallOperation(CallOpInterface op); + + /// Visit the given callable operation and compute any necessary lattice + /// state. + void visitCallableOperation(Operation *op); + + /// Visit the given region branch operation, which defines regions, and + /// compute any necessary lattice state. This also resolves the lattice state + /// of both the operation results and any nested regions. + void visitRegionBranchOperation( + RegionBranchOpInterface branch, + ArrayRef operandLattices); + + /// Visit the given set of region successors, computing any necessary lattice + /// state. The provided function returns the input operands to the region at + /// the given index. If the index is 'None', the input operands correspond to + /// the parent operation results. + void visitRegionSuccessors( + Operation *parentOp, ArrayRef regionSuccessors, + function_ref)> getInputsForRegion); + + /// Visit the given terminator operation and compute any necessary lattice + /// state. + void + visitTerminatorOperation(Operation *op, + ArrayRef operandLattices); + + /// Visit the given terminator operation that exits a callable region. These + /// are terminators with no CFG successors. 
+ void visitCallableTerminatorOperation( + Operation *callable, Operation *terminator, + ArrayRef operandLattices); + + /// Visit the given block: its arguments (if not an entry block), then ops. + void visitBlock(Block *block); + + /// Visit argument #'i' of the given block, joining the operand lattices + /// forwarded along its executable predecessor edges. + void visitBlockArgument(Block *block, int i); + + /// Mark the entry block of the given region as executable. Returns NoChange + /// if the block was already marked executable. If `markPessimisticFixpoint` + /// is true, the arguments of the entry block are also marked as having + /// reached the pessimistic fixpoint. + ChangeResult markEntryBlockExecutable(Region *region, + bool markPessimisticFixpoint); + + /// Mark the given block as executable. Returns NoChange if the block was + /// already marked executable. + ChangeResult markBlockExecutable(Block *block); + + /// Returns true if the given block is executable. + bool isBlockExecutable(Block *block) const; + + /// Mark the 'from'->'to' edge executable; queues 'to' or re-visits its args. + void markEdgeExecutable(Block *from, Block *to); + + /// Return true if the edge between 'from' and 'to' is executable. + bool isEdgeExecutable(Block *from, Block *to) const; + + /// Mark the given value as having reached the pessimistic fixpoint. This + /// means that we cannot further refine the state of this value. + void markPessimisticFixpoint(Value value); + + /// Mark all of the given values as having reached the pessimistic fixpoint.
+ template + void markAllPessimisticFixpoint(ValuesT values) { + for (auto value : values) + markPessimisticFixpoint(value); + } + template + void markAllPessimisticFixpoint(Operation *op, ValuesT values) { + markAllPessimisticFixpoint(values); + opWorklist.push_back(op); + } + template + void markAllPessimisticFixpointAndVisitUsers(ValuesT values) { + for (auto value : values) { + AbstractLatticeElement &lattice = analysis.getLatticeElement(value); + if (lattice.markPessimisticFixpoint() == ChangeResult::Change) + visitUsers(value); + } + } + + /// Returns true if the given value was marked as having reached the + /// pessimistic fixpoint. + bool isAtFixpoint(Value value) const; + + /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' + /// corresponds to the parent operation of the lattice for 'to'. + void join(Operation *owner, AbstractLatticeElement &to, + const AbstractLatticeElement &from); + + /// A reference to the dataflow analysis being computed. + ForwardDataFlowAnalysisBase &analysis; + + /// The set of blocks that are known to execute, or are intrinsically live. + SmallPtrSet executableBlocks; + + /// The set of control flow edges that are known to execute. + DenseSet> executableEdges; + + /// A worklist containing blocks that need to be processed. + SmallVector blockWorklist; + + /// A worklist of operations that need to be processed. + SmallVector opWorklist; + + /// The callable operations that have their argument/result state tracked. + DenseMap callableLatticeState; + + /// A map between a call operation and the resolved symbol callable. This + /// avoids re-resolving symbol references during propagation. Value based + /// callables are trivial to resolve, so they can be done in-place. + DenseMap callToSymbolCallable; + + /// A symbol table used for O(1) symbol lookups during simplification. 
+ SymbolTableCollection symbolTable; +}; +} // end anonymous namespace + +ForwardDataFlowSolver::ForwardDataFlowSolver( + ForwardDataFlowAnalysisBase &analysis, Operation *op) + : analysis(analysis) { + /// Initialize the solver with the regions within this operation. + for (Region ®ion : op->getRegions()) { + // Mark the entry block as executable. The values passed to these regions + // are also invisible, so mark any arguments as reaching the pessimistic + // fixpoint. + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + } + initializeSymbolCallables(op); +} + +void ForwardDataFlowSolver::solve() { + while (!blockWorklist.empty() || !opWorklist.empty()) { + // Process any operations in the op worklist. + while (!opWorklist.empty()) + visitUsers(*opWorklist.pop_back_val()); + + // Process any blocks in the block worklist. + while (!blockWorklist.empty()) + visitBlock(blockWorklist.pop_back_val()); + } +} + +void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) { + // Initialize the set of symbol callables that can have their state tracked. + // This tracks which symbol callable operations we can propagate within and + // out of. + auto walkFn = [&](Operation *symTable, bool allUsesVisible) { + Region &symbolTableRegion = symTable->getRegion(0); + Block *symbolTableBlock = &symbolTableRegion.front(); + for (auto callable : symbolTableBlock->getOps()) { + // We won't be able to track external callables. + Region *callableRegion = callable.getCallableRegion(); + if (!callableRegion) + continue; + // We only care about symbol defining callables here. + auto symbol = dyn_cast(callable.getOperation()); + if (!symbol) + continue; + callableLatticeState.try_emplace(callable, analysis, callableRegion, + callable.getCallableResults().size()); + + // If not all of the uses of this symbol are visible, we can't track the + // state of the arguments. 
+ if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { + for (Region ®ion : callable->getRegions()) + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + } + } + if (callableLatticeState.empty()) + return; + + // After computing the valid callables, walk any symbol uses to check + // for non-call references. We won't be able to track the lattice state + // for arguments to these callables, as we can't guarantee that we can see + // all of its calls. + Optional uses = + SymbolTable::getSymbolUses(&symbolTableRegion); + if (!uses) { + // If we couldn't gather the symbol uses, conservatively assume that + // we can't track information for any nested symbols. + op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); + return; + } + + for (const SymbolTable::SymbolUse &use : *uses) { + // If the use is a call, track it to avoid the need to recompute the + // reference later. + if (auto callOp = dyn_cast(use.getUser())) { + Operation *symCallable = callOp.resolveCallable(&symbolTable); + auto callableLatticeIt = callableLatticeState.find(symCallable); + if (callableLatticeIt != callableLatticeState.end()) { + callToSymbolCallable.try_emplace(callOp, symCallable); + + // We only need to record the call in the lattice if it produces any + // values. + if (callOp->getNumResults()) + callableLatticeIt->second.addSymbolCall(callOp); + } + continue; + } + // This use isn't a call, so don't we know all of the callers. + auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); + auto it = callableLatticeState.find(symbol); + if (it != callableLatticeState.end()) { + for (Region ®ion : it->first->getRegions()) + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + } + } + }; + SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), + walkFn); +} + +void ForwardDataFlowSolver::visitOperation(Operation *op) { + // Collect all of the lattice elements feeding into this operation. 
If any are + // not yet resolved, bail out and wait for them to resolve. + SmallVector operandLattices; + operandLattices.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + AbstractLatticeElement *operandLattice = + analysis.lookupLatticeElement(operand); + if (!operandLattice) + return; + operandLattices.push_back(operandLattice); + } + + // If this is a terminator operation, process any control flow lattice state. + if (op->hasTrait()) + visitTerminatorOperation(op, operandLattices); + + // Process call operations. The call visitor processes result values, so we + // can exit afterwards. + if (CallOpInterface call = dyn_cast(op)) + return visitCallOperation(call); + + // Process callable operations. These are specially handled region operations + // that track dataflow via calls. + if (isa(op)) { + // If this callable has a tracked lattice state, it will be visited by calls + // that reference it instead. This way, we don't assume that it is + // executable unless there is a proper reference to it. + if (callableLatticeState.count(op)) + return; + return visitCallableOperation(op); + } + + // Process region holding operations. + if (op->getNumRegions()) { + // Check to see if we can reason about the internal control flow of this + // region operation. + if (auto branch = dyn_cast(op)) + return visitRegionBranchOperation(branch, operandLattices); + + // If we can't, conservatively mark all regions as executable. + // TODO: Let the `visitOperation` method decide how to propagate + // information to the block arguments. + for (Region ®ion : op->getRegions()) + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + } + + // If this op produces no results, it can't produce any constants. + if (op->getNumResults() == 0) + return; + + // If all of the results of this operation are already resolved, bail out + // early. 
+ auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); }; + if (llvm::all_of(op->getResults(), isAtFixpointFn)) + return; + + // Visit the current operation. + if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change) + opWorklist.push_back(op); + + // `visitOperation` is required to define all of the result lattices. + assert(llvm::none_of( + op->getResults(), + [&](Value value) { + return analysis.getLatticeElement(value).isUninitialized(); + }) && + "expected `visitOperation` to define all result lattices"); +} + +void ForwardDataFlowSolver::visitCallableOperation(Operation *op) { + // Mark the regions as executable. If we aren't tracking lattice state for + // this callable, mark all of the region arguments as having reached a + // fixpoint. + bool isTrackingLatticeState = callableLatticeState.count(op); + for (Region ®ion : op->getRegions()) + markEntryBlockExecutable(®ion, !isTrackingLatticeState); + + // TODO: Add support for non-symbol callables when necessary. If the callable + // has non-call uses we would mark as having reached pessimistic fixpoint, + // otherwise allow for propagating the return values out. + markAllPessimisticFixpoint(op, op->getResults()); +} + +void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) { + ResultRange callResults = op->getResults(); + + // Resolve the callable operation for this call. + Operation *callableOp = nullptr; + if (Value callableValue = op.getCallableForCallee().dyn_cast()) + callableOp = callableValue.getDefiningOp(); + else + callableOp = callToSymbolCallable.lookup(op); + + // The callable of this call can't be resolved, mark any results overdefined. + if (!callableOp) + return markAllPessimisticFixpoint(op, callResults); + + // If this callable is tracking state, merge the argument operands with the + // arguments of the callable. 
+ auto callableLatticeIt = callableLatticeState.find(callableOp); + if (callableLatticeIt == callableLatticeState.end()) + return markAllPessimisticFixpoint(op, callResults); + + OperandRange callOperands = op.getArgOperands(); + auto callableArgs = callableLatticeIt->second.getCallableArguments(); + for (auto it : llvm::zip(callOperands, callableArgs)) { + BlockArgument callableArg = std::get<1>(it); + AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg); + AbstractLatticeElement &operandValue = + analysis.getLatticeElement(std::get<0>(it)); + if (argValue.join(operandValue) == ChangeResult::Change) + visitUsers(callableArg); + } + + // Visit the callable. + visitCallableOperation(callableOp); + + // Merge in the lattice state for the callable results as well. + auto callableResults = callableLatticeIt->second.getResultLatticeElements(); + for (auto it : llvm::zip(callResults, callableResults)) + join(/*owner=*/op, + /*to=*/analysis.getLatticeElement(std::get<0>(it)), + /*from=*/std::get<1>(it)); +} + +void ForwardDataFlowSolver::visitRegionBranchOperation( + RegionBranchOpInterface branch, + ArrayRef operandLattices) { + // Check to see which regions are executable. + SmallVector successors; + analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None, + operandLattices, successors); + + // If the interface identified that no region will be executed. Mark + // any results of this operation as overdefined, as we can't reason about + // them. + // TODO: If we had an interface to detect pass through operands, we could + // resolve some results based on the lattice state of the operands. We could + // also allow for the parent operation to have itself as a region successor. 
+ if (successors.empty()) + return markAllPessimisticFixpoint(branch, branch->getResults()); + return visitRegionSuccessors( + branch, successors, [&](Optional index) { + assert(index && "expected valid region index"); + return branch.getSuccessorEntryOperands(*index); + }); +} + +void ForwardDataFlowSolver::visitRegionSuccessors( + Operation *parentOp, ArrayRef regionSuccessors, + function_ref)> getInputsForRegion) { + for (const RegionSuccessor &it : regionSuccessors) { + Region *region = it.getSuccessor(); + ValueRange succArgs = it.getSuccessorInputs(); + + // Parent-op successor: `continue`, not `return`, so later successors run. + if (!region) { + ResultRange results = parentOp->getResults(); + if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); })) + continue; + + // Mark the results outside of the input range as pessimistic fixpoints. + // TODO: This isn't exactly ideal. There may be situations in which a + // region operation can provide information for certain results that + // aren't part of the control flow. + if (succArgs.size() != results.size()) { + opWorklist.push_back(parentOp); + if (succArgs.empty()) { + markAllPessimisticFixpoint(results); + continue; + } + unsigned firstResIdx = succArgs[0].cast().getResultNumber(); + markAllPessimisticFixpoint(results.take_front(firstResIdx)); + markAllPessimisticFixpoint( + results.drop_front(firstResIdx + succArgs.size())); + } + + // Update the lattice for any operation results. + OperandRange operands = getInputsForRegion(/*index=*/llvm::None); + for (auto it : llvm::zip(succArgs, operands)) + join(parentOp, analysis.getLatticeElement(std::get<0>(it)), + analysis.getLatticeElement(std::get<1>(it))); + continue; + } + assert(!region->empty() && "expected region to be non-empty"); + Block *entryBlock = ®ion->front(); + markBlockExecutable(entryBlock); + + // If all of the arguments have already reached a fixpoint, the arguments + // have already been fully resolved.
+ Block::BlockArgListType arguments = entryBlock->getArguments(); + if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) + continue; + + // Mark any arguments that do not receive inputs as having reached a + // pessimistic fixpoint, we won't be able to discern if they are constant. + // TODO: This isn't exactly ideal. There may be situations in which a + // region operation can provide information for certain results that + // aren't part of the control flow. + if (succArgs.size() != arguments.size()) { + if (succArgs.empty()) { + markAllPessimisticFixpoint(arguments); + continue; + } + + unsigned firstArgIdx = succArgs[0].cast().getArgNumber(); + markAllPessimisticFixpointAndVisitUsers( + arguments.take_front(firstArgIdx)); + markAllPessimisticFixpointAndVisitUsers( + arguments.drop_front(firstArgIdx + succArgs.size())); + } + + // Update the lattice of arguments that have inputs from the predecessor. + OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); + for (auto it : llvm::zip(succArgs, succOperands)) { + AbstractLatticeElement &argValue = + analysis.getLatticeElement(std::get<0>(it)); + AbstractLatticeElement &operandValue = + analysis.getLatticeElement(std::get<1>(it)); + if (argValue.join(operandValue) == ChangeResult::Change) + visitUsers(std::get<0>(it)); + } + } +} + +void ForwardDataFlowSolver::visitTerminatorOperation( + Operation *op, ArrayRef operandLattices) { + // If this operation has no successors, we treat it as an exiting terminator. + if (op->getNumSuccessors() == 0) { + Region *parentRegion = op->getParentRegion(); + Operation *parentOp = parentRegion->getParentOp(); + + // Check to see if this is a terminator for a callable region. + if (isa(parentOp)) + return visitCallableTerminatorOperation(parentOp, op, operandLattices); + + // Otherwise, check to see if the parent tracks region control flow. 
+ auto regionInterface = dyn_cast(parentOp); + if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) + return; + + // Query the set of successors of the current region using the current + // optimistic lattice state. + SmallVector regionSuccessors; + analysis.getSuccessorsForOperands(regionInterface, + parentRegion->getRegionNumber(), + operandLattices, regionSuccessors); + if (regionSuccessors.empty()) + return; + + // If this terminator is not "region-like", conservatively mark all of the + // successor values as having reached the pessimistic fixpoint. + if (!op->hasTrait()) { + for (auto &it : regionSuccessors) + markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs()); + return; + } + + // Otherwise, propagate the operand lattice states to the successors. + OperandRange operands = op->getOperands(); + return visitRegionSuccessors(parentOp, regionSuccessors, + [&](Optional) { return operands; }); + } + + // Try to resolve to a specific set of successors with the current optimistic + // lattice state. + Block *block = op->getBlock(); + if (auto branch = dyn_cast(op)) { + SmallVector successors; + if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices, + successors))) { + for (Block *succ : successors) + markEdgeExecutable(block, succ); + return; + } + } + + // Otherwise, conservatively treat all edges as executable. + for (Block *succ : op->getSuccessors()) + markEdgeExecutable(block, succ); +} + +void ForwardDataFlowSolver::visitCallableTerminatorOperation( + Operation *callable, Operation *terminator, + ArrayRef operandLattices) { + // If there are no exiting values, we have nothing to track. + if (terminator->getNumOperands() == 0) + return; + + // If this callable isn't tracking any lattice state there is nothing to do. 
+ auto latticeIt = callableLatticeState.find(callable); + if (latticeIt == callableLatticeState.end()) + return; + assert(callable->getNumResults() == 0 && "expected symbol callable"); + + // If this terminator is not "return-like", conservatively mark all of the + // call-site results as having reached the pessimistic fixpoint. + auto callableResultLattices = latticeIt->second.getResultLatticeElements(); + if (!terminator->hasTrait()) { + for (auto &it : callableResultLattices) + it.markPessimisticFixpoint(); + for (Operation *call : latticeIt->second.getSymbolCalls()) + markAllPessimisticFixpoint(call, call->getResults()); + return; + } + + // Merge the lattice state for terminator operands into the results. + ChangeResult result = ChangeResult::NoChange; + for (auto it : llvm::zip(operandLattices, callableResultLattices)) + result |= std::get<1>(it).join(*std::get<0>(it)); + if (result == ChangeResult::NoChange) + return; + + // If any of the result lattices changed, update the callers. + for (Operation *call : latticeIt->second.getSymbolCalls()) + for (auto it : llvm::zip(call->getResults(), callableResultLattices)) + join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it)); +} + +void ForwardDataFlowSolver::visitBlock(Block *block) { + // If the block is not the entry block we need to compute the lattice state + // for the block arguments. Entry block argument lattices are computed + // elsewhere, such as when visiting the parent operation. + if (!block->isEntryBlock()) { + for (int i : llvm::seq(0, block->getNumArguments())) + visitBlockArgument(block, i); + } + + // Visit all of the operations within the block. 
+ for (Operation &op : *block) + visitOperation(&op); +} + +void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) { + BlockArgument arg = block->getArgument(i); + AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg); + if (argLattice.isAtFixpoint()) + return; + + ChangeResult updatedLattice = ChangeResult::NoChange; + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + Block *pred = *it; + + // We only care about this predecessor if it is going to execute. + if (!isEdgeExecutable(pred, block)) + continue; + + // Try to get the operand forwarded by the predecessor. If we can't reason + // about the terminator of the predecessor, mark as having reached a + // fixpoint. + Optional branchOperands; + if (auto branch = dyn_cast(pred->getTerminator())) + branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); + if (!branchOperands) { + updatedLattice |= argLattice.markPessimisticFixpoint(); + break; + } + + // If the operand hasn't been resolved, it is uninitialized and can merge + // with anything. + AbstractLatticeElement *operandLattice = + analysis.lookupLatticeElement((*branchOperands)[i]); + if (!operandLattice) + continue; + + // Otherwise, join the operand lattice into the argument lattice. + updatedLattice |= argLattice.join(*operandLattice); + if (argLattice.isAtFixpoint()) + break; + } + + // If the lattice changed, visit users of the argument. 
+ if (updatedLattice == ChangeResult::Change) + visitUsers(arg); +} + +ChangeResult +ForwardDataFlowSolver::markEntryBlockExecutable(Region *region, + bool markPessimisticFixpoint) { + if (!region->empty()) { + if (markPessimisticFixpoint) + markAllPessimisticFixpoint(region->front().getArguments()); + return markBlockExecutable(®ion->front()); + } + return ChangeResult::NoChange; +} + +ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) { + bool marked = executableBlocks.insert(block).second; + if (marked) + blockWorklist.push_back(block); + return marked ? ChangeResult::Change : ChangeResult::NoChange; +} + +bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const { + return executableBlocks.count(block); +} + +void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) { + if (!executableEdges.insert(std::make_pair(from, to)).second) + return; + + // Mark the destination as executable, and reprocess its arguments if it was + // already executable. 
+ if (markBlockExecutable(to) == ChangeResult::NoChange) { + for (int i : llvm::seq(0, to->getNumArguments())) + visitBlockArgument(to, i); + } +} + +bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const { + return executableEdges.count(std::make_pair(from, to)); +} + +void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) { + analysis.getLatticeElement(value).markPessimisticFixpoint(); +} + +bool ForwardDataFlowSolver::isAtFixpoint(Value value) const { + if (auto *lattice = analysis.lookupLatticeElement(value)) + return lattice->isAtFixpoint(); + return false; +} + +void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to, + const AbstractLatticeElement &from) { + if (to.join(from) == ChangeResult::Change) + opWorklist.push_back(owner); +} + +//===----------------------------------------------------------------------===// +// AbstractLatticeElement +//===----------------------------------------------------------------------===// + +AbstractLatticeElement::~AbstractLatticeElement() {} + +//===----------------------------------------------------------------------===// +// ForwardDataFlowAnalysisBase +//===----------------------------------------------------------------------===// + +ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() {} + +AbstractLatticeElement & +ForwardDataFlowAnalysisBase::getLatticeElement(Value value) { + AbstractLatticeElement *&latticeValue = latticeValues[value]; + if (!latticeValue) + latticeValue = createLatticeElement(value); + return *latticeValue; +} + +AbstractLatticeElement * +ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) { + return latticeValues.lookup(value); +} + +void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) { + // Run the main dataflow solver. 
+ ForwardDataFlowSolver solver(*this, topLevelOp); + solver.solve(); + + // Any values that are still uninitialized now go to a pessimistic fixpoint, + // otherwise we assume an optimistic fixpoint has been reached. + for (auto &it : latticeValues) + if (it.second->isUninitialized()) + it.second->markPessimisticFixpoint(); + else + it.second->markOptimisticFixpoint(); +} diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Analysis/DataFlowAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -25,326 +26,173 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// SCCP Analysis +//===----------------------------------------------------------------------===// + namespace { -/// This class represents a single lattice value. A lattive value corresponds to -/// the various different states that a value in the SCCP dataflow analysis can -/// take. See 'Kind' below for more details on the different states a value can -/// take. -class LatticeValue { - enum Kind { - /// A value with a yet to be determined value. This state may be changed to - /// anything. - Unknown, - - /// A value that is known to be a constant. This state may be changed to - /// overdefined. - Constant, - - /// A value that cannot statically be determined to be a constant. This - /// state cannot be changed. - Overdefined - }; +struct SCCPLatticeValue { + SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr) + : constant(constant), constantDialect(dialect) {} -public: - /// Initialize a lattice value with "Unknown". 
- LatticeValue() - : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {} - /// Initialize a lattice value with a constant. - LatticeValue(Attribute attr, Dialect *dialect) - : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {} - - /// Returns true if this lattice value is unknown. - bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; } - - /// Mark the lattice value as overdefined. - void markOverdefined() { - constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined); - constantDialect = nullptr; + /// The pessimistic state of SCCP is non-constant. + static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) { + return SCCPLatticeValue(); } - - /// Returns true if the lattice is overdefined. - bool isOverdefined() const { - return constantAndTag.getInt() == Kind::Overdefined; + static SCCPLatticeValue getPessimisticValueState(Value value) { + return SCCPLatticeValue(); } - /// Mark the lattice value as constant. - void markConstant(Attribute value, Dialect *dialect) { - constantAndTag.setPointerAndInt(value, Kind::Constant); - constantDialect = dialect; + /// Equivalence for SCCP only accounts for the constant, not the originating + /// dialect. + bool operator==(const SCCPLatticeValue &rhs) const { + return constant == rhs.constant; } - /// If this lattice is constant, return the constant. Returns nullptr - /// otherwise. - Attribute getConstant() const { return constantAndTag.getPointer(); } - - /// If this lattice is constant, return the dialect to use when materializing - /// the constant. - Dialect *getConstantDialect() const { - assert(getConstant() && "expected valid constant"); - return constantDialect; - } - - /// Merge in the value of the 'rhs' lattice into this one. Returns true if the - /// lattice value changed. - bool meet(const LatticeValue &rhs) { - // If we are already overdefined, or rhs is unknown, there is nothing to do. 
- if (isOverdefined() || rhs.isUnknown()) - return false; - // If we are unknown, just take the value of rhs. - if (isUnknown()) { - constantAndTag = rhs.constantAndTag; - constantDialect = rhs.constantDialect; - return true; - } - - // Otherwise, if this value doesn't match rhs go straight to overdefined. - if (constantAndTag != rhs.constantAndTag) { - markOverdefined(); - return true; - } - return false; + /// To join the state of two values, we simply check for equivalence. + static SCCPLatticeValue join(const SCCPLatticeValue &lhs, + const SCCPLatticeValue &rhs) { + return lhs == rhs ? lhs : SCCPLatticeValue(); } -private: - /// The attribute value if this is a constant and the tag for the element - /// kind. - llvm::PointerIntPair constantAndTag; + /// The constant attribute value. + Attribute constant; - /// The dialect the constant originated from. This is only valid if the - /// lattice is a constant. This is not used as part of the key, and is only - /// needed to materialize the held constant if necessary. + /// The dialect the constant originated from. This is not used as part of the + /// key, and is only needed to materialize the held constant if necessary. Dialect *constantDialect; }; -/// This class contains various state used when computing the lattice of a -/// callable operation. -class CallableLatticeState { -public: - /// Build a lattice state with a given callable region, and a specified number - /// of results to be initialized to the default lattice value (Unknown). - CallableLatticeState(Region *callableRegion, unsigned numResults) - : callableArguments(callableRegion->getArguments()), - resultLatticeValues(numResults) {} - - /// Returns the arguments to the callable region. - Block::BlockArgListType getCallableArguments() const { - return callableArguments; - } - - /// Returns the lattice value for the results of the callable region. 
- MutableArrayRef getResultLatticeValues() { - return resultLatticeValues; - } - - /// Add a call to this callable. This is only used if the callable defines a - /// symbol. - void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } - - /// Return the calls that reference this callable. This is only used - /// if the callable defines a symbol. - ArrayRef getSymbolCalls() const { return symbolCalls; } - -private: - /// The arguments of the callable region. - Block::BlockArgListType callableArguments; +struct SCCPAnalysis : public ForwardDataFlowAnalysis { + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + ~SCCPAnalysis() override = default; - /// The lattice state for each of the results of this region. The return - /// values of the callable aren't SSA values, so we need to track them - /// separately. - SmallVector resultLatticeValues; - - /// The calls referencing this callable if this callable defines a symbol. - /// This removes the need to recompute symbol references during propagation. - /// Value based references are trivial to resolve, so they can be done - /// in-place. - SmallVector symbolCalls; -}; - -/// This class represents the solver for the SCCP analysis. This class acts as -/// the propagation engine for computing which values form constants. -class SCCPSolver { -public: - /// Initialize the solver with the given top-level operation. - SCCPSolver(Operation *op); - - /// Run the solver until it converges. - void solve(); - - /// Rewrite the given regions using the computing analysis. This replaces the - /// uses of all values that have been computed to be constant, and erases as - /// many newly dead operations. - void rewrite(MLIRContext *context, MutableArrayRef regions); - -private: - /// Initialize the set of symbol defining callables that can have their - /// arguments and results tracked. 'op' is the top-level operation that SCCP - /// is operating on. 
- void initializeSymbolCallables(Operation *op); - - /// Replace the given value with a constant if the corresponding lattice - /// represents a constant. Returns success if the value was replaced, failure - /// otherwise. - LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder, - Value value); - - /// Visit the users of the given IR that reside within executable blocks. - template - void visitUsers(T &value) { - for (Operation *user : value.getUsers()) - if (isBlockExecutable(user->getBlock())) - visitOperation(user); - } + ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) final { + // Don't try to simulate the results of a region operation as we can't + // guarantee that folding will be out-of-place. We don't allow in-place + // folds as the desire here is for simulated execution, and not general + // folding. + if (op->getNumRegions()) + return markAllPessimisticFixpoint(op->getResults()); + + SmallVector constantOperands( + llvm::map_range(operands, [](LatticeElement *value) { + return value->getValue().constant; + })); + + // Save the original operands and attributes just in case the operation + // folds in-place. The constant passed in may not correspond to the real + // runtime value, so in-place updates are not allowed. + SmallVector originalOperands(op->getOperands()); + DictionaryAttr originalAttrs = op->getAttrDictionary(); + + // Simulate the result of folding this operation to a constant. If folding + // fails or was an in-place fold, mark the results as overdefined. + SmallVector foldResults; + foldResults.reserve(op->getNumResults()); + if (failed(op->fold(constantOperands, foldResults))) + return markAllPessimisticFixpoint(op->getResults()); + + // If the folding was in-place, mark the results as overdefined and reset + // the operation. We don't allow in-place folds as the desire here is for + // simulated execution, and not general folding. 
+ if (foldResults.empty()) { + op->setOperands(originalOperands); + op->setAttrs(originalAttrs); + return markAllPessimisticFixpoint(op->getResults()); + } - /// Visit the given operation and compute any necessary lattice state. - void visitOperation(Operation *op); - - /// Visit the given call operation and compute any necessary lattice state. - void visitCallOperation(CallOpInterface op); - - /// Visit the given callable operation and compute any necessary lattice - /// state. - void visitCallableOperation(Operation *op); - - /// Visit the given operation, which defines regions, and compute any - /// necessary lattice state. This also resolves the lattice state of both the - /// operation results and any nested regions. - void visitRegionOperation(Operation *op, - ArrayRef constantOperands); - - /// Visit the given set of region successors, computing any necessary lattice - /// state. The provided function returns the input operands to the region at - /// the given index. If the index is 'None', the input operands correspond to - /// the parent operation results. - void visitRegionSuccessors( - Operation *parentOp, ArrayRef regionSuccessors, - function_ref)> getInputsForRegion); - - /// Visit the given terminator operation and compute any necessary lattice - /// state. - void visitTerminatorOperation(Operation *op, - ArrayRef constantOperands); - - /// Visit the given terminator operation that exits a callable region. These - /// are terminators with no CFG successors. - void visitCallableTerminatorOperation(Operation *callable, - Operation *terminator); - - /// Visit the given block and compute any necessary lattice state. - void visitBlock(Block *block); - - /// Visit argument #'i' of the given block and compute any necessary lattice - /// state. - void visitBlockArgument(Block *block, int i); - - /// Mark the entry block of the given region as executable. Returns false if - /// the block was already marked executable. 
If `markArgsOverdefined` is true, - /// the arguments of the entry block are also set to overdefined. - bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined); - - /// Mark the given block as executable. Returns false if the block was already - /// marked executable. - bool markBlockExecutable(Block *block); - - /// Returns true if the given block is executable. - bool isBlockExecutable(Block *block) const; - - /// Mark the edge between 'from' and 'to' as executable. - void markEdgeExecutable(Block *from, Block *to); - - /// Return true if the edge between 'from' and 'to' is executable. - bool isEdgeExecutable(Block *from, Block *to) const; - - /// Mark the given value as overdefined. This means that we cannot refine a - /// specific constant for this value. - void markOverdefined(Value value); - - /// Mark all of the given values as overdefined. - template - void markAllOverdefined(ValuesT values) { - for (auto value : values) - markOverdefined(value); - } - template - void markAllOverdefined(Operation *op, ValuesT values) { - markAllOverdefined(values); - opWorklist.push_back(op); - } - template - void markAllOverdefinedAndVisitUsers(ValuesT values) { - for (auto value : values) { - auto &lattice = latticeValues[value]; - if (!lattice.isOverdefined()) { - lattice.markOverdefined(); - visitUsers(value); - } + // Merge the fold results into the lattice for this operation. + assert(foldResults.size() == op->getNumResults() && "invalid result size"); + Dialect *dialect = op->getDialect(); + ChangeResult result = ChangeResult::NoChange; + for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { + LatticeElement &lattice = + getLatticeElement(op->getResult(i)); + + // Merge in the result of the fold, either a constant or a value. 
+ OpFoldResult foldResult = foldResults[i]; + if (Attribute attr = foldResult.dyn_cast()) + result |= lattice.join(SCCPLatticeValue(attr, dialect)); + else + result |= lattice.join(getLatticeElement(foldResult.get())); + } + return result; + } + + /// Implementation of `getSuccessorsForOperands` that uses constant operands + /// to potentially remove dead successors. + LogicalResult getSuccessorsForOperands( + BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) final { + SmallVector constantOperands( + llvm::map_range(operands, [](LatticeElement *value) { + return value->getValue().constant; + })); + if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { + successors.push_back(singleSucc); + return success(); } + return failure(); } - /// Returns true if the given value was marked as overdefined. - bool isOverdefined(Value value) const; - - /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' - /// corresponds to the parent operation of 'to'. - void meet(Operation *owner, LatticeValue &to, const LatticeValue &from); - - /// The lattice for each SSA value. - DenseMap latticeValues; - - /// The set of blocks that are known to execute, or are intrinsically live. - SmallPtrSet executableBlocks; - - /// The set of control flow edges that are known to execute. - DenseSet> executableEdges; - - /// A worklist containing blocks that need to be processed. - SmallVector blockWorklist; - - /// A worklist of operations that need to be processed. - SmallVector opWorklist; - - /// The callable operations that have their argument/result state tracked. - DenseMap callableLatticeState; - - /// A map between a call operation and the resolved symbol callable. This - /// avoids re-resolving symbol references during propagation. Value based - /// callables are trivial to resolve, so they can be done in-place. - DenseMap callToSymbolCallable; - - /// A symbol table used for O(1) symbol lookups during simplification. 
- SymbolTableCollection symbolTable; + /// Implementation of `getSuccessorsForOperands` that uses constant operands + /// to potentially remove dead region successors. + void getSuccessorsForOperands( + RegionBranchOpInterface branch, Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) final { + SmallVector constantOperands( + llvm::map_range(operands, [](LatticeElement *value) { + return value->getValue().constant; + })); + branch.getSuccessorRegions(sourceIndex, constantOperands, successors); + } }; -} // end anonymous namespace +} // namespace -SCCPSolver::SCCPSolver(Operation *op) { - /// Initialize the solver with the regions within this operation. - for (Region ®ion : op->getRegions()) { - // Mark the entry block as executable. The values passed to these regions - // are also invisible, so mark any arguments as overdefined. - markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); - } - initializeSymbolCallables(op); -} +//===----------------------------------------------------------------------===// +// SCCP Rewrites +//===----------------------------------------------------------------------===// -void SCCPSolver::solve() { - while (!blockWorklist.empty() || !opWorklist.empty()) { - // Process any operations in the op worklist. - while (!opWorklist.empty()) - visitUsers(*opWorklist.pop_back_val()); +/// Replace the given value with a constant if the corresponding lattice +/// represents a constant. Returns success if the value was replaced, failure +/// otherwise. +static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, + OpBuilder &builder, + OperationFolder &folder, Value value) { + LatticeElement *lattice = + analysis.lookupLatticeElement(value); + if (!lattice) + return failure(); + SCCPLatticeValue &latticeValue = lattice->getValue(); + if (!latticeValue.constant) + return failure(); - // Process any blocks in the block worklist. 
- while (!blockWorklist.empty()) - visitBlock(blockWorklist.pop_back_val()); - } + // Attempt to materialize a constant for the given value. + Dialect *dialect = latticeValue.constantDialect; + Value constant = folder.getOrCreateConstant( + builder, dialect, latticeValue.constant, value.getType(), value.getLoc()); + if (!constant) + return failure(); + + value.replaceAllUsesWith(constant); + return success(); } -void SCCPSolver::rewrite(MLIRContext *context, - MutableArrayRef initialRegions) { - SmallVector worklist; +/// Rewrite the given regions using the computing analysis. This replaces the +/// uses of all values that have been computed to be constant, and erases as +/// many newly dead operations. +static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, + MutableArrayRef initialRegions) { + SmallVector worklist; auto addToWorklist = [&](MutableArrayRef regions) { for (Region ®ion : regions) - for (Block &block : region) - if (isBlockExecutable(&block)) - worklist.push_back(&block); + for (Block &block : llvm::reverse(region)) + worklist.push_back(&block); }; // An operation folder used to create and unique constants. @@ -355,18 +203,14 @@ while (!worklist.empty()) { Block *block = worklist.pop_back_val(); - // Replace any block arguments with constants. - builder.setInsertionPointToStart(block); - for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(builder, folder, arg); - for (Operation &op : llvm::make_early_inc_range(*block)) { builder.setInsertionPoint(&op); // Replace any result with constants. bool replacedAll = op.getNumResults() != 0; for (Value res : op.getResults()) - replacedAll &= succeeded(replaceWithConstant(builder, folder, res)); + replacedAll &= + succeeded(replaceWithConstant(analysis, builder, folder, res)); // If all of the results of the operation were replaced, try to erase // the operation completely. @@ -379,532 +223,14 @@ // Add any the regions of this operation to the worklist. 
addToWorklist(op.getRegions()); } - } -} - -void SCCPSolver::initializeSymbolCallables(Operation *op) { - // Initialize the set of symbol callables that can have their state tracked. - // This tracks which symbol callable operations we can propagate within and - // out of. - auto walkFn = [&](Operation *symTable, bool allUsesVisible) { - Region &symbolTableRegion = symTable->getRegion(0); - Block *symbolTableBlock = &symbolTableRegion.front(); - for (auto callable : symbolTableBlock->getOps()) { - // We won't be able to track external callables. - Region *callableRegion = callable.getCallableRegion(); - if (!callableRegion) - continue; - // We only care about symbol defining callables here. - auto symbol = dyn_cast(callable.getOperation()); - if (!symbol) - continue; - callableLatticeState.try_emplace(callable, callableRegion, - callable.getCallableResults().size()); - - // If not all of the uses of this symbol are visible, we can't track the - // state of the arguments. - if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { - for (Region ®ion : callable->getRegions()) - markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); - } - } - if (callableLatticeState.empty()) - return; - - // After computing the valid callables, walk any symbol uses to check - // for non-call references. We won't be able to track the lattice state - // for arguments to these callables, as we can't guarantee that we can see - // all of its calls. - Optional uses = - SymbolTable::getSymbolUses(&symbolTableRegion); - if (!uses) { - // If we couldn't gather the symbol uses, conservatively assume that - // we can't track information for any nested symbols. - op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); - return; - } - - for (const SymbolTable::SymbolUse &use : *uses) { - // If the use is a call, track it to avoid the need to recompute the - // reference later. 
- if (auto callOp = dyn_cast(use.getUser())) { - Operation *symCallable = callOp.resolveCallable(&symbolTable); - auto callableLatticeIt = callableLatticeState.find(symCallable); - if (callableLatticeIt != callableLatticeState.end()) { - callToSymbolCallable.try_emplace(callOp, symCallable); - - // We only need to record the call in the lattice if it produces any - // values. - if (callOp->getNumResults()) - callableLatticeIt->second.addSymbolCall(callOp); - } - continue; - } - // This use isn't a call, so don't we know all of the callers. - auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); - auto it = callableLatticeState.find(symbol); - if (it != callableLatticeState.end()) { - for (Region ®ion : it->first->getRegions()) - markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); - } - } - }; - SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), - walkFn); -} - -LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder, - OperationFolder &folder, - Value value) { - auto it = latticeValues.find(value); - auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant(); - if (!attr) - return failure(); - - // Attempt to materialize a constant for the given value. - Dialect *dialect = it->second.getConstantDialect(); - Value constant = folder.getOrCreateConstant(builder, dialect, attr, - value.getType(), value.getLoc()); - if (!constant) - return failure(); - - value.replaceAllUsesWith(constant); - latticeValues.erase(it); - return success(); -} -void SCCPSolver::visitOperation(Operation *op) { - // Collect all of the constant operands feeding into this operation. If any - // are not ready to be resolved, bail out and wait for them to resolve. - SmallVector operandConstants; - operandConstants.reserve(op->getNumOperands()); - for (Value operand : op->getOperands()) { - // Make sure all of the operands are resolved first. 
- auto &operandLattice = latticeValues[operand]; - if (operandLattice.isUnknown()) - return; - operandConstants.push_back(operandLattice.getConstant()); - } - - // If this is a terminator operation, process any control flow lattice state. - if (op->hasTrait()) - visitTerminatorOperation(op, operandConstants); - - // Process call operations. The call visitor processes result values, so we - // can exit afterwards. - if (CallOpInterface call = dyn_cast(op)) - return visitCallOperation(call); - - // Process callable operations. These are specially handled region operations - // that track dataflow via calls. - if (isa(op)) { - // If this callable has a tracked lattice state, it will be visited by calls - // that reference it instead. This way, we don't assume that it is - // executable unless there is a proper reference to it. - if (callableLatticeState.count(op)) - return; - return visitCallableOperation(op); - } - - // Process region holding operations. The region visitor processes result - // values, so we can exit afterwards. - if (op->getNumRegions()) - return visitRegionOperation(op, operandConstants); - - // If this op produces no results, it can't produce any constants. - if (op->getNumResults() == 0) - return; - - // If all of the results of this operation are already overdefined, bail out - // early. - auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); }; - if (llvm::all_of(op->getResults(), isOverdefinedFn)) - return; - - // Save the original operands and attributes just in case the operation folds - // in-place. The constant passed in may not correspond to the real runtime - // value, so in-place updates are not allowed. - SmallVector originalOperands(op->getOperands()); - DictionaryAttr originalAttrs = op->getAttrDictionary(); - - // Simulate the result of folding this operation to a constant. If folding - // fails or was an in-place fold, mark the results as overdefined. 
- SmallVector foldResults; - foldResults.reserve(op->getNumResults()); - if (failed(op->fold(operandConstants, foldResults))) - return markAllOverdefined(op, op->getResults()); - - // If the folding was in-place, mark the results as overdefined and reset the - // operation. We don't allow in-place folds as the desire here is for - // simulated execution, and not general folding. - if (foldResults.empty()) { - op->setOperands(originalOperands); - op->setAttrs(originalAttrs); - return markAllOverdefined(op, op->getResults()); - } - - // Merge the fold results into the lattice for this operation. - assert(foldResults.size() == op->getNumResults() && "invalid result size"); - Dialect *opDialect = op->getDialect(); - for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { - LatticeValue &resultLattice = latticeValues[op->getResult(i)]; - - // Merge in the result of the fold, either a constant or a value. - OpFoldResult foldResult = foldResults[i]; - if (Attribute foldAttr = foldResult.dyn_cast()) - meet(op, resultLattice, LatticeValue(foldAttr, opDialect)); - else - meet(op, resultLattice, latticeValues[foldResult.get()]); - } -} - -void SCCPSolver::visitCallableOperation(Operation *op) { - // Mark the regions as executable. If we aren't tracking lattice state for - // this callable, mark all of the region arguments as overdefined. - bool isTrackingLatticeState = callableLatticeState.count(op); - for (Region ®ion : op->getRegions()) - markEntryBlockExecutable(®ion, !isTrackingLatticeState); - - // TODO: Add support for non-symbol callables when necessary. If the callable - // has non-call uses we would mark overdefined, otherwise allow for - // propagating the return values out. - markAllOverdefined(op, op->getResults()); -} - -void SCCPSolver::visitCallOperation(CallOpInterface op) { - ResultRange callResults = op->getResults(); - - // Resolve the callable operation for this call. 
- Operation *callableOp = nullptr; - if (Value callableValue = op.getCallableForCallee().dyn_cast()) - callableOp = callableValue.getDefiningOp(); - else - callableOp = callToSymbolCallable.lookup(op); - - // The callable of this call can't be resolved, mark any results overdefined. - if (!callableOp) - return markAllOverdefined(op, callResults); - - // If this callable is tracking state, merge the argument operands with the - // arguments of the callable. - auto callableLatticeIt = callableLatticeState.find(callableOp); - if (callableLatticeIt == callableLatticeState.end()) - return markAllOverdefined(op, callResults); - - OperandRange callOperands = op.getArgOperands(); - auto callableArgs = callableLatticeIt->second.getCallableArguments(); - for (auto it : llvm::zip(callOperands, callableArgs)) { - BlockArgument callableArg = std::get<1>(it); - if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)])) - visitUsers(callableArg); - } - - // Visit the callable. - visitCallableOperation(callableOp); - - // Merge in the lattice state for the callable results as well. - auto callableResults = callableLatticeIt->second.getResultLatticeValues(); - for (auto it : llvm::zip(callResults, callableResults)) - meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)], - /*from=*/std::get<1>(it)); -} - -void SCCPSolver::visitRegionOperation(Operation *op, - ArrayRef constantOperands) { - // Check to see if we can reason about the internal control flow of this - // region operation. - auto regionInterface = dyn_cast(op); - if (!regionInterface) { - // If we can't, conservatively mark all regions as executable. - for (Region ®ion : op->getRegions()) - markEntryBlockExecutable(®ion, /*markArgsOverdefined=*/true); - - // Don't try to simulate the results of a region operation as we can't - // guarantee that folding will be out-of-place. We don't allow in-place - // folds as the desire here is for simulated execution, and not general - // folding. 
- return markAllOverdefined(op, op->getResults()); - } - - // Check to see which regions are executable. - SmallVector successors; - regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands, - successors); - - // If the interface identified that no region will be executed. Mark - // any results of this operation as overdefined, as we can't reason about - // them. - // TODO: If we had an interface to detect pass through operands, we could - // resolve some results based on the lattice state of the operands. We could - // also allow for the parent operation to have itself as a region successor. - if (successors.empty()) - return markAllOverdefined(op, op->getResults()); - return visitRegionSuccessors(op, successors, [&](Optional index) { - assert(index && "expected valid region index"); - return regionInterface.getSuccessorEntryOperands(*index); - }); -} - -void SCCPSolver::visitRegionSuccessors( - Operation *parentOp, ArrayRef regionSuccessors, - function_ref)> getInputsForRegion) { - for (const RegionSuccessor &it : regionSuccessors) { - Region *region = it.getSuccessor(); - ValueRange succArgs = it.getSuccessorInputs(); - - // Check to see if this is the parent operation. - if (!region) { - ResultRange results = parentOp->getResults(); - if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); })) - continue; - - // Mark the results outside of the input range as overdefined. - if (succArgs.size() != results.size()) { - opWorklist.push_back(parentOp); - if (succArgs.empty()) - return markAllOverdefined(results); - - unsigned firstResIdx = succArgs[0].cast().getResultNumber(); - markAllOverdefined(results.take_front(firstResIdx)); - markAllOverdefined(results.drop_front(firstResIdx + succArgs.size())); - } - - // Update the lattice for any operation results. 
- OperandRange operands = getInputsForRegion(/*index=*/llvm::None); - for (auto it : llvm::zip(succArgs, operands)) - meet(parentOp, latticeValues[std::get<0>(it)], - latticeValues[std::get<1>(it)]); - return; - } - assert(!region->empty() && "expected region to be non-empty"); - Block *entryBlock = ®ion->front(); - markBlockExecutable(entryBlock); - - // If all of the arguments are already overdefined, the arguments have - // already been fully resolved. - auto arguments = entryBlock->getArguments(); - if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); })) - continue; - - // Mark any arguments that do not receive inputs as overdefined, we won't be - // able to discern if they are constant. - if (succArgs.size() != arguments.size()) { - if (succArgs.empty()) { - markAllOverdefined(arguments); - continue; - } - - unsigned firstArgIdx = succArgs[0].cast().getArgNumber(); - markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx)); - markAllOverdefinedAndVisitUsers( - arguments.drop_front(firstArgIdx + succArgs.size())); - } - - // Update the lattice for arguments that have inputs from the predecessor. - OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); - for (auto it : llvm::zip(succArgs, succOperands)) { - LatticeValue &argLattice = latticeValues[std::get<0>(it)]; - if (argLattice.meet(latticeValues[std::get<1>(it)])) - visitUsers(std::get<0>(it)); - } - } -} - -void SCCPSolver::visitTerminatorOperation( - Operation *op, ArrayRef constantOperands) { - // If this operation has no successors, we treat it as an exiting terminator. - if (op->getNumSuccessors() == 0) { - Region *parentRegion = op->getParentRegion(); - Operation *parentOp = parentRegion->getParentOp(); - - // Check to see if this is a terminator for a callable region. - if (isa(parentOp)) - return visitCallableTerminatorOperation(parentOp, op); - - // Otherwise, check to see if the parent tracks region control flow. 
- auto regionInterface = dyn_cast(parentOp); - if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) - return; - - // Query the set of successors from the current region. - SmallVector regionSuccessors; - regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(), - constantOperands, regionSuccessors); - if (regionSuccessors.empty()) - return; - - // If this terminator is not "region-like", conservatively mark all of the - // successor values as overdefined. - if (!op->hasTrait()) { - for (auto &it : regionSuccessors) - markAllOverdefinedAndVisitUsers(it.getSuccessorInputs()); - return; - } - - // Otherwise, propagate the operand lattice states to each of the - // successors. - OperandRange operands = op->getOperands(); - return visitRegionSuccessors(parentOp, regionSuccessors, - [&](Optional) { return operands; }); - } - - // Try to resolve to a specific successor with the constant operands. - if (auto branch = dyn_cast(op)) { - if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { - markEdgeExecutable(op->getBlock(), singleSucc); - return; - } - } - - // Otherwise, conservatively treat all edges as executable. - Block *block = op->getBlock(); - for (Block *succ : op->getSuccessors()) - markEdgeExecutable(block, succ); -} - -void SCCPSolver::visitCallableTerminatorOperation(Operation *callable, - Operation *terminator) { - // If there are no exiting values, we have nothing to track. - if (terminator->getNumOperands() == 0) - return; - - // If this callable isn't tracking any lattice state there is nothing to do. - auto latticeIt = callableLatticeState.find(callable); - if (latticeIt == callableLatticeState.end()) - return; - assert(callable->getNumResults() == 0 && "expected symbol callable"); - - // If this terminator is not "return-like", conservatively mark all of the - // call-site results as overdefined. 
- auto callableResultLattices = latticeIt->second.getResultLatticeValues(); - if (!terminator->hasTrait()) { - for (auto &it : callableResultLattices) - it.markOverdefined(); - for (Operation *call : latticeIt->second.getSymbolCalls()) - markAllOverdefined(call, call->getResults()); - return; - } - - // Merge the terminator operands into the results. - bool anyChanged = false; - for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices)) - anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]); - if (!anyChanged) - return; - - // If any of the result lattices changed, update the callers. - for (Operation *call : latticeIt->second.getSymbolCalls()) - for (auto it : llvm::zip(call->getResults(), callableResultLattices)) - meet(call, latticeValues[std::get<0>(it)], std::get<1>(it)); -} - -void SCCPSolver::visitBlock(Block *block) { - // If the block is not the entry block we need to compute the lattice state - // for the block arguments. Entry block argument lattices are computed - // elsewhere, such as when visiting the parent operation. - if (!block->isEntryBlock()) { - for (int i : llvm::seq(0, block->getNumArguments())) - visitBlockArgument(block, i); - } - - // Visit all of the operations within the block. - for (Operation &op : *block) - visitOperation(&op); -} - -void SCCPSolver::visitBlockArgument(Block *block, int i) { - BlockArgument arg = block->getArgument(i); - LatticeValue &argLattice = latticeValues[arg]; - if (argLattice.isOverdefined()) - return; - - bool updatedLattice = false; - for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { - Block *pred = *it; - - // We only care about this predecessor if it is going to execute. - if (!isEdgeExecutable(pred, block)) - continue; - - // Try to get the operand forwarded by the predecessor. If we can't reason - // about the terminator of the predecessor, mark overdefined. 
- Optional branchOperands; - if (auto branch = dyn_cast(pred->getTerminator())) - branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); - if (!branchOperands) { - updatedLattice = true; - argLattice.markOverdefined(); - break; - } - - // If the operand hasn't been resolved, it is unknown which can merge with - // anything. - auto operandLattice = latticeValues.find((*branchOperands)[i]); - if (operandLattice == latticeValues.end()) - continue; - - // Otherwise, meet the two lattice values. - updatedLattice |= argLattice.meet(operandLattice->second); - if (argLattice.isOverdefined()) - break; - } - - // If the lattice was updated, visit any executable users of the argument. - if (updatedLattice) - visitUsers(arg); -} - -bool SCCPSolver::markEntryBlockExecutable(Region *region, - bool markArgsOverdefined) { - if (!region->empty()) { - if (markArgsOverdefined) - markAllOverdefined(region->front().getArguments()); - return markBlockExecutable(®ion->front()); - } - return false; -} - -bool SCCPSolver::markBlockExecutable(Block *block) { - bool marked = executableBlocks.insert(block).second; - if (marked) - blockWorklist.push_back(block); - return marked; -} - -bool SCCPSolver::isBlockExecutable(Block *block) const { - return executableBlocks.count(block); -} - -void SCCPSolver::markEdgeExecutable(Block *from, Block *to) { - if (!executableEdges.insert(std::make_pair(from, to)).second) - return; - // Mark the destination as executable, and reprocess its arguments if it was - // already executable. - if (!markBlockExecutable(to)) { - for (int i : llvm::seq(0, to->getNumArguments())) - visitBlockArgument(to, i); + // Replace any block arguments with constants. 
+ builder.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + (void)replaceWithConstant(analysis, builder, folder, arg); } } -bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const { - return executableEdges.count(std::make_pair(from, to)); -} - -void SCCPSolver::markOverdefined(Value value) { - latticeValues[value].markOverdefined(); -} - -bool SCCPSolver::isOverdefined(Value value) const { - auto it = latticeValues.find(value); - return it != latticeValues.end() && it->second.isOverdefined(); -} - -void SCCPSolver::meet(Operation *owner, LatticeValue &to, - const LatticeValue &from) { - if (to.meet(from)) - opWorklist.push_back(owner); -} - //===----------------------------------------------------------------------===// // SCCP Pass //===----------------------------------------------------------------------===// @@ -918,12 +244,9 @@ void SCCP::runOnOperation() { Operation *op = getOperation(); - // Solve for SCCP constraints within nested regions. - SCCPSolver solver(op); - solver.solve(); - - // Cleanup any operations using the solver analysis. - solver.rewrite(&getContext(), op->getRegions()); + SCCPAnalysis analysis(op->getContext()); + analysis.run(op); + rewrite(analysis, op->getContext(), op->getRegions()); } std::unique_ptr mlir::createSCCPPass() {