diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h @@ -60,6 +60,24 @@ Dialect *dialect; }; +//===----------------------------------------------------------------------===// +// SparseConstantPropagation +//===----------------------------------------------------------------------===// + +/// This analysis implements sparse constant propagation, which attempts to +/// determine constant-valued results for operations using constant-valued +/// operands, by speculatively folding operations. When combined with dead-code +/// analysis, this becomes sparse conditional constant propagation (SCCP). +class SparseConstantPropagation + : public SparseDataFlowAnalysis> { +public: + using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + + void visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + } // end namespace dataflow } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -89,6 +89,9 @@ /// the predecessor to its entry block, and the exiting terminator or a callable /// operation can be the predecessor of the call operation. /// +/// The state can optionally contain information about which values are +/// propagated from each predecessor to the successor point. +/// /// The state can indicate that it is underdefined, meaning that not all live /// control-flow predecessors can be known. class PredecessorState : public AnalysisState { @@ -118,12 +121,17 @@ return knownPredecessors.getArrayRef(); } - /// Add a known predecessor. 
- ChangeResult join(Operation *predecessor) { - return knownPredecessors.insert(predecessor) ? ChangeResult::Change - : ChangeResult::NoChange; + /// Get the successor inputs from a predecessor. + ValueRange getSuccessorInputs(Operation *predecessor) const { + return successorInputs.lookup(predecessor); } + /// Add a known predecessor. + ChangeResult join(Operation *predecessor); + + /// Add a known predecessor with successor inputs. + ChangeResult join(Operation *predecessor, ValueRange inputs); + private: /// Whether all predecessors are known. Optimistically assume that we know /// all predecessors. @@ -133,6 +141,9 @@ SetVector, SmallPtrSet> knownPredecessors; + + /// The successor inputs when branching from a given predecessor. + DenseMap successorInputs; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -16,6 +16,7 @@ #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" namespace mlir { @@ -179,6 +180,137 @@ Optional optimisticValue; }; +//===----------------------------------------------------------------------===// +// AbstractSparseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for sparse (forward) data-flow analyses. A sparse analysis +/// implements a transfer function on operations from the lattices of the +/// operands to the lattices of the results. This analysis will propagate +/// lattices across control-flow edges and the callgraph using liveness +/// information. 
+class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis { +public: + /// Initialize the analysis by visiting every owner of an SSA value: all + /// operations and blocks. + LogicalResult initialize(Operation *top) override; + + /// Visit a program point. If this is a block and all control-flow + /// predecessors or callsites are known, then the argument lattices are + /// propagated from them. If this is a call operation or an operation with + /// region control-flow, then its result lattices are set accordingly. + /// Otherwise, the operation transfer function is invoked. + LogicalResult visit(ProgramPoint point) override; + +protected: + explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver); + + /// The operation transfer function. Given the operand lattices, this + /// function is expected to set the result lattices. + virtual void + visitOperationImpl(Operation *op, + ArrayRef operandLattices, + ArrayRef resultLattices) = 0; + + /// Get the lattice element of a value. + virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; + + /// Get a read-only lattice element for a value and add it as a dependency to + /// a program point. + const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, + Value value); + + /// Mark the given lattice elements as having reached their pessimistic + /// fixpoints and propagate an update if any changed. + void markAllPessimisticFixpoint(ArrayRef lattices); + + /// Join the lattice element and propagate an update if it changed. + void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); + +private: + /// Recursively initialize the analysis on nested operations and blocks. + LogicalResult initializeRecursively(Operation *op); + + /// Visit an operation. If this is a call operation or an operation with + /// region control-flow, then its result lattices are set accordingly. + /// Otherwise, the operation transfer function is invoked.
+ void visitOperation(Operation *op); + + /// Visit a block to compute the lattice values of its arguments. If this is + /// an entry block, then the argument values are determined from the block's + /// "predecessors" as set by `PredecessorState`. The predecessors can be + /// region terminators or callable callsites. Otherwise, the values are + /// determined from block predecessors. + void visitBlock(Block *block); + + /// Visit a program point `point` with predecessors within a region branch + /// operation `branch`, which can either be the entry block of one of the + /// regions or the parent operation itself, and set either the argument or + /// parent result lattices. + void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, + Optional successorIndex, + ArrayRef lattices); +}; + +//===----------------------------------------------------------------------===// +// SparseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// A sparse (forward) data-flow analysis for propagating SSA value lattices +/// across the IR by implementing transfer functions for operations. +/// +/// `StateT` is expected to be a subclass of `AbstractSparseLattice`. +template +class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis { + static_assert( + std::is_base_of::value, + "analysis state class expected to subclass AbstractSparseLattice"); + +public: + explicit SparseDataFlowAnalysis(DataFlowSolver &solver) + : AbstractSparseDataFlowAnalysis(solver) {} + + /// Visit an operation with the lattices of its operands. This function is + /// expected to set the lattices of the operation's results. + virtual void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) = 0; + +protected: + /// Get the lattice element for a value. 
+ StateT *getLatticeElement(Value value) override { + return getOrCreate(value); + } + + /// Get the lattice element for a value and create a dependency on the + /// provided program point. + const StateT *getLatticeElementFor(ProgramPoint point, Value value) { + return static_cast( + AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value)); + } + + /// Mark the lattice elements of a range of values as having reached their + /// pessimistic fixpoint. + void markAllPessimisticFixpoint(ArrayRef lattices) { + AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( + {reinterpret_cast(lattices.begin()), + lattices.size()}); + } + +private: + /// Type-erased wrappers that convert the abstract lattice operands to derived + /// lattices and invoke the virtual hooks operating on the derived lattices. + void visitOperationImpl( + Operation *op, ArrayRef operandLattices, + ArrayRef resultLattices) override { + visitOperation( + op, + {reinterpret_cast(operandLattices.begin()), + operandLattices.size()}, + {reinterpret_cast(resultLattices.begin()), + resultLattices.size()}); + } +}; + } // end namespace dataflow } // end namespace mlir diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -7,6 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "constant-propagation" using namespace mlir; using namespace mlir::dataflow; @@ -20,3 +24,68 @@ return constant.print(os); os << ""; } + +//===----------------------------------------------------------------------===// +// SparseConstantPropagation 
+//===----------------------------------------------------------------------===// + +void SparseConstantPropagation::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); + + // Don't try to simulate the results of a region operation as we can't + // guarantee that folding will be out-of-place. We don't allow in-place + // folds as the desire here is for simulated execution, and not general + // folding. + if (op->getNumRegions()) + return markAllPessimisticFixpoint(results); + + SmallVector constantOperands; + constantOperands.reserve(op->getNumOperands()); + for (auto *operandLattice : operands) + constantOperands.push_back(operandLattice->getValue().getConstantValue()); + + // Save the original operands and attributes just in case the operation + // folds in-place. The constant passed in may not correspond to the real + // runtime value, so in-place updates are not allowed. + SmallVector originalOperands(op->getOperands()); + DictionaryAttr originalAttrs = op->getAttrDictionary(); + + // Simulate the result of folding this operation to a constant. If folding + // fails or was an in-place fold, mark the results as overdefined. + SmallVector foldResults; + foldResults.reserve(op->getNumResults()); + if (failed(op->fold(constantOperands, foldResults))) { + markAllPessimisticFixpoint(results); + return; + } + + // If the folding was in-place, mark the results as overdefined and reset + // the operation. We don't allow in-place folds as the desire here is for + // simulated execution, and not general folding. + if (foldResults.empty()) { + op->setOperands(originalOperands); + op->setAttrs(originalAttrs); + return markAllPessimisticFixpoint(results); + } + + // Merge the fold results into the lattice for this operation.
+ assert(foldResults.size() == op->getNumResults() && "invalid result size"); + for (const auto it : llvm::zip(results, foldResults)) { + Lattice *lattice = std::get<0>(it); + + // Merge in the result of the fold, either a constant or a value. + OpFoldResult foldResult = std::get<1>(it); + if (Attribute attr = foldResult.dyn_cast()) { + LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + propagateIfChanged(lattice, + lattice->join(ConstantValue(attr, op->getDialect()))); + } else { + LLVM_DEBUG(llvm::dbgs() + << "Folded to value: " << foldResult.get() << "\n"); + AbstractSparseDataFlowAnalysis::join( + lattice, *getLatticeElement(foldResult.get())); + } + } +} diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -59,6 +59,23 @@ os << " " << *op << "\n"; } +ChangeResult PredecessorState::join(Operation *predecessor) { + return knownPredecessors.insert(predecessor) ? ChangeResult::Change + : ChangeResult::NoChange; +} + +ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) { + ChangeResult result = join(predecessor); + if (!inputs.empty()) { + ValueRange &curInputs = successorInputs[predecessor]; + if (curInputs != inputs) { + curInputs = inputs; + result |= ChangeResult::Change; + } + } + return result; +} + //===----------------------------------------------------------------------===// // CFGEdge //===----------------------------------------------------------------------===// @@ -333,14 +350,18 @@ SmallVector successors; branch.getSuccessorRegions(/*index=*/{}, *operands, successors); for (const RegionSuccessor &successor : successors) { + // The successor can be either an entry block or the parent operation. + ProgramPoint point = successor.getSuccessor() + ? 
&successor.getSuccessor()->front() + : ProgramPoint(branch); // Mark the entry block as executable. - Region *region = successor.getSuccessor(); - assert(region && "expected a region successor"); - auto *state = getOrCreate(&region->front()); + auto *state = getOrCreate(point); propagateIfChanged(state, state->setToLive()); // Add the parent op as a predecessor. - auto *predecessors = getOrCreate(&region->front()); - propagateIfChanged(predecessors, predecessors->join(branch)); + auto *predecessors = getOrCreate(point); + propagateIfChanged( + predecessors, + predecessors->join(branch, successor.getSuccessorInputs())); } } @@ -366,7 +387,8 @@ // Add this terminator as a predecessor to the parent op. predecessors = getOrCreate(branch); } - propagateIfChanged(predecessors, predecessors->join(op)); + propagateIfChanged(predecessors, + predecessors->join(op, successor.getSuccessorInputs())); } } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -7,6 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" using namespace mlir; using namespace mlir::dataflow; @@ -21,3 +24,265 @@ for (DataFlowAnalysis *analysis : useDefSubscribers) solver->enqueue({user, analysis}); } + +//===----------------------------------------------------------------------===// +// AbstractSparseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis( + DataFlowSolver &solver) + : DataFlowAnalysis(solver) { + registerPointKind(); +} + +LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top)
{ + // Mark the entry block arguments as having reached their pessimistic + // fixpoints. + for (Region &region : top->getRegions()) { + if (region.empty()) + continue; + for (Value argument : region.front().getArguments()) + markAllPessimisticFixpoint(getLatticeElement(argument)); + } + + return initializeRecursively(top); +} + +LogicalResult +AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) { + // Initialize the analysis by visiting every owner of an SSA value (all + // operations and blocks). + visitOperation(op); + for (Region &region : op->getRegions()) { + for (Block &block : region) { + getOrCreate(&block)->blockContentSubscribe(this); + visitBlock(&block); + for (Operation &nestedOp : block) + if (failed(initializeRecursively(&nestedOp))) + return failure(); + } + } + + return success(); +} + +LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) { + if (Operation *op = point.dyn_cast()) + visitOperation(op); + else if (Block *block = point.dyn_cast()) + visitBlock(block); + else + return failure(); + return success(); +} + +void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) { + // Exit early on operations with no results. + if (op->getNumResults() == 0) + return; + + // If the containing block is not executable, bail out. + if (!getOrCreate(op->getBlock())->isLive()) + return; + + // Get the result lattices. + SmallVector resultLattices; + resultLattices.reserve(op->getNumResults()); + // Track whether all results have reached their fixpoint. + bool allAtFixpoint = true; + for (Value result : op->getResults()) { + AbstractSparseLattice *resultLattice = getLatticeElement(result); + allAtFixpoint &= resultLattice->isAtFixpoint(); + resultLattices.push_back(resultLattice); + } + // If all result lattices have reached a fixpoint, there is nothing to do. + if (allAtFixpoint) + return; + + // The results of a region branch operation are determined by control-flow.
+ if (auto branch = dyn_cast(op)) { + return visitRegionSuccessors({branch}, branch, + /*successorIndex=*/llvm::None, resultLattices); + } + + // The results of a call operation are determined by the callgraph. + if (auto call = dyn_cast(op)) { + const auto *predecessors = getOrCreateFor(op, call); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) + return markAllPessimisticFixpoint(resultLattices); + for (Operation *predecessor : predecessors->getKnownPredecessors()) + for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) + join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); + return; + } + + // Grab the lattice elements of the operands. + SmallVector operandLattices; + operandLattices.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + AbstractSparseLattice *operandLattice = getLatticeElement(operand); + operandLattice->useDefSubscribe(this); + // If any of the operand states are not initialized, bail out. + if (operandLattice->isUninitialized()) + return; + operandLattices.push_back(operandLattice); + } + + // Invoke the operation transfer function. + visitOperationImpl(op, operandLattices, resultLattices); +} + +void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) { + // Exit early on blocks with no arguments. + if (block->getNumArguments() == 0) + return; + + // If the block is not executable, bail out. + if (!getOrCreate(block)->isLive()) + return; + + // Get the argument lattices. + SmallVector argLattices; + argLattices.reserve(block->getNumArguments()); + bool allAtFixpoint = true; + for (BlockArgument argument : block->getArguments()) { + AbstractSparseLattice *argLattice = getLatticeElement(argument); + allAtFixpoint &= argLattice->isAtFixpoint(); + argLattices.push_back(argLattice); + } + // If all argument lattices have reached their fixpoints, then there is + // nothing to do. 
+ if (allAtFixpoint) + return; + + // The argument lattices of entry blocks are set by region control-flow or the + // callgraph. + if (block->isEntryBlock()) { + // Check if this block is the entry block of a callable region. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + const auto *callsites = getOrCreateFor(block, callable); + // If not all callsites are known, conservatively mark all lattices as + // having reached their pessimistic fixpoints. + if (!callsites->allPredecessorsKnown()) + return markAllPessimisticFixpoint(argLattices); + for (Operation *callsite : callsites->getKnownPredecessors()) { + auto call = cast(callsite); + for (auto it : llvm::zip(call.getArgOperands(), argLattices)) + join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); + } + return; + } + + // Check if the lattices can be determined from region control flow. + if (auto branch = dyn_cast(block->getParentOp())) { + return visitRegionSuccessors( + block, branch, block->getParent()->getRegionNumber(), argLattices); + } + + // Otherwise, we can't reason about the data-flow. + return markAllPessimisticFixpoint(argLattices); + } + + // Iterate over the predecessors of the non-entry block. + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); + it != e; ++it) { + Block *predecessor = *it; + + // If the edge from the predecessor block to the current block is not live, + // skip this predecessor. + auto *edgeExecutable = + getOrCreate(getProgramPoint(predecessor, block)); + edgeExecutable->blockContentSubscribe(this); + if (!edgeExecutable->isLive()) + continue; + + // Check if we can reason about the data-flow from the predecessor.
+ if (auto branch = + dyn_cast(predecessor->getTerminator())) { + SuccessorOperands operands = + branch.getSuccessorOperands(it.getSuccessorIndex()); + for (auto &it : llvm::enumerate(argLattices)) { + if (Value operand = operands[it.index()]) { + join(it.value(), *getLatticeElementFor(block, operand)); + } else { + // Conservatively mark internally produced arguments as having reached + // their pessimistic fixpoint. + markAllPessimisticFixpoint(it.value()); + } + } + } else { + return markAllPessimisticFixpoint(argLattices); + } + } +} + +void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( + ProgramPoint point, RegionBranchOpInterface branch, + Optional successorIndex, + ArrayRef lattices) { + const auto *predecessors = getOrCreateFor(point, point); + assert(predecessors->allPredecessorsKnown() && + "unexpected unresolved region successors"); + + for (Operation *op : predecessors->getKnownPredecessors()) { + // Get the incoming successor operands. + Optional operands; + + // Check if the predecessor is the parent op. + if (op == branch) { + operands = branch.getSuccessorEntryOperands(successorIndex); + // Otherwise, try to deduce the operands from a region return-like op. + } else { + assert(op->hasTrait() && "expected a terminator"); + if (isRegionReturnLike(op)) + operands = getRegionBranchSuccessorOperands(op, successorIndex); + } + + if (!operands) { + // We can't reason about the data-flow. + return markAllPessimisticFixpoint(lattices); + } + + ValueRange inputs = predecessors->getSuccessorInputs(op); + assert(inputs.size() == operands->size() && + "expected the same number of successor inputs as operands"); + + // TODO: This was updated to be exposed upstream. 
+ unsigned firstIndex = 0; + if (inputs.size() != lattices.size()) { + if (inputs.empty()) { + markAllPessimisticFixpoint(lattices); + return; + } + firstIndex = inputs.front().cast().getArgNumber(); + markAllPessimisticFixpoint(lattices.take_front(firstIndex)); + markAllPessimisticFixpoint( + lattices.drop_front(firstIndex + inputs.size())); + } + + for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) + join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it))); + } +} + +const AbstractSparseLattice * +AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, + Value value) { + AbstractSparseLattice *state = getLatticeElement(value); + addDependency(state, point); + return state; +} + +void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( + ArrayRef lattices) { + for (AbstractSparseLattice *lattice : lattices) + propagateIfChanged(lattice, lattice->markPessimisticFixpoint()); +} + +void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, + const AbstractSparseLattice &rhs) { + propagateIfChanged(lhs, lhs->join(rhs)); +} diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -15,150 +15,17 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "sccp" using namespace mlir; - -//===----------------------------------------------------------------------===// -// SCCP Analysis 
-//===----------------------------------------------------------------------===// - -namespace { -struct SCCPLatticeValue { - SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr) - : constant(constant), constantDialect(dialect) {} - - /// The pessimistic state of SCCP is non-constant. - static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) { - return SCCPLatticeValue(); - } - static SCCPLatticeValue getPessimisticValueState(Value value) { - return SCCPLatticeValue(); - } - - /// Equivalence for SCCP only accounts for the constant, not the originating - /// dialect. - bool operator==(const SCCPLatticeValue &rhs) const { - return constant == rhs.constant; - } - - /// To join the state of two values, we simply check for equivalence. - static SCCPLatticeValue join(const SCCPLatticeValue &lhs, - const SCCPLatticeValue &rhs) { - return lhs == rhs ? lhs : SCCPLatticeValue(); - } - - /// The constant attribute value. - Attribute constant; - - /// The dialect the constant originated from. This is not used as part of the - /// key, and is only needed to materialize the held constant if necessary. - Dialect *constantDialect; -}; - -struct SCCPAnalysis : public ForwardDataFlowAnalysis { - using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; - ~SCCPAnalysis() override = default; - - ChangeResult - visitOperation(Operation *op, - ArrayRef *> operands) final { - - LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n"); - - // Don't try to simulate the results of a region operation as we can't - // guarantee that folding will be out-of-place. We don't allow in-place - // folds as the desire here is for simulated execution, and not general - // folding. 
- if (op->getNumRegions()) - return markAllPessimisticFixpoint(op->getResults()); - - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - - // Save the original operands and attributes just in case the operation - // folds in-place. The constant passed in may not correspond to the real - // runtime value, so in-place updates are not allowed. - SmallVector originalOperands(op->getOperands()); - DictionaryAttr originalAttrs = op->getAttrDictionary(); - - // Simulate the result of folding this operation to a constant. If folding - // fails or was an in-place fold, mark the results as overdefined. - SmallVector foldResults; - foldResults.reserve(op->getNumResults()); - if (failed(op->fold(constantOperands, foldResults))) - return markAllPessimisticFixpoint(op->getResults()); - - // If the folding was in-place, mark the results as overdefined and reset - // the operation. We don't allow in-place folds as the desire here is for - // simulated execution, and not general folding. - if (foldResults.empty()) { - op->setOperands(originalOperands); - op->setAttrs(originalAttrs); - return markAllPessimisticFixpoint(op->getResults()); - } - - // Merge the fold results into the lattice for this operation. - assert(foldResults.size() == op->getNumResults() && "invalid result size"); - Dialect *dialect = op->getDialect(); - ChangeResult result = ChangeResult::NoChange; - for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { - LatticeElement &lattice = - getLatticeElement(op->getResult(i)); - - // Merge in the result of the fold, either a constant or a value. 
- OpFoldResult foldResult = foldResults[i]; - if (Attribute attr = foldResult.dyn_cast()) - result |= lattice.join(SCCPLatticeValue(attr, dialect)); - else - result |= lattice.join(getLatticeElement(foldResult.get())); - } - return result; - } - - /// Implementation of `getSuccessorsForOperands` that uses constant operands - /// to potentially remove dead successors. - LogicalResult getSuccessorsForOperands( - BranchOpInterface branch, - ArrayRef *> operands, - SmallVectorImpl &successors) final { - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) { - successors.push_back(singleSucc); - return success(); - } - return failure(); - } - - /// Implementation of `getSuccessorsForOperands` that uses constant operands - /// to potentially remove dead region successors. - void getSuccessorsForOperands( - RegionBranchOpInterface branch, Optional sourceIndex, - ArrayRef *> operands, - SmallVectorImpl &successors) final { - SmallVector constantOperands( - llvm::map_range(operands, [](LatticeElement *value) { - return value->getValue().constant; - })); - branch.getSuccessorRegions(sourceIndex, constantOperands, successors); - } -}; -} // namespace +using namespace mlir::dataflow; //===----------------------------------------------------------------------===// // SCCP Rewrites @@ -167,21 +34,21 @@ /// Replace the given value with a constant if the corresponding lattice /// represents a constant. Returns success if the value was replaced, failure /// otherwise. 
-static LogicalResult replaceWithConstant(SCCPAnalysis &analysis, +static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &builder, OperationFolder &folder, Value value) { - LatticeElement *lattice = - analysis.lookupLatticeElement(value); + auto *lattice = solver.lookupState>(value); if (!lattice) return failure(); - SCCPLatticeValue &latticeValue = lattice->getValue(); - if (!latticeValue.constant) + const ConstantValue &latticeValue = lattice->getValue(); + if (!latticeValue.getConstantValue()) return failure(); // Attempt to materialize a constant for the given value. - Dialect *dialect = latticeValue.constantDialect; - Value constant = folder.getOrCreateConstant( - builder, dialect, latticeValue.constant, value.getType(), value.getLoc()); + Dialect *dialect = latticeValue.getConstantDialect(); + Value constant = folder.getOrCreateConstant(builder, dialect, + latticeValue.getConstantValue(), + value.getType(), value.getLoc()); if (!constant) return failure(); @@ -192,7 +59,7 @@ /// Rewrite the given regions using the computing analysis. This replaces the /// uses of all values that have been computed to be constant, and erases as /// many newly dead operations. -static void rewrite(SCCPAnalysis &analysis, MLIRContext *context, +static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef initialRegions) { SmallVector worklist; auto addToWorklist = [&](MutableArrayRef regions) { @@ -216,7 +83,7 @@ bool replacedAll = op.getNumResults() != 0; for (Value res : op.getResults()) replacedAll &= - succeeded(replaceWithConstant(analysis, builder, folder, res)); + succeeded(replaceWithConstant(solver, builder, folder, res)); // If all of the results of the operation were replaced, try to erase // the operation completely. @@ -233,7 +100,7 @@ // Replace any block arguments with constants. 
builder.setInsertionPointToStart(block); for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(analysis, builder, folder, arg); + (void)replaceWithConstant(solver, builder, folder, arg); } } @@ -250,9 +117,12 @@ void SCCP::runOnOperation() { Operation *op = getOperation(); - SCCPAnalysis analysis(op->getContext()); - analysis.run(op); - rewrite(analysis, op->getContext(), op->getRegions()); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + rewrite(solver, op->getContext(), op->getRegions()); } std::unique_ptr mlir::createSCCPPass() {