diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h @@ -46,8 +46,8 @@ /// Print the constant value. void print(raw_ostream &os) const; - /// The pessimistic value state of the constant value is unknown. - static ConstantValue getPessimisticValueState(Value value) { return {}; } + /// The state where the constant value is unknown. + static ConstantValue getUnknownConstant() { return {}; } /// The union with another constant value is null if they are different, and /// the same if they are the same. @@ -79,6 +79,8 @@ void visitOperation(Operation *op, ArrayRef *> operands, ArrayRef *> results) override; + + void setToEntryState(Lattice *lattice) override; }; } // end namespace dataflow diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -38,10 +38,6 @@ /// Join the lattice across control-flow or callgraph edges. virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0; - - /// Reset the dense lattice to a pessimistic value. This occurs when the - /// analysis cannot reason about the data-flow. - virtual ChangeResult reset() = 0; }; //===----------------------------------------------------------------------===// @@ -88,11 +84,9 @@ const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, ProgramPoint point); - /// Mark the dense lattice as having reached its pessimistic fixpoint and - /// propagate an update if it changed. - void reset(AbstractDenseLattice *lattice) { - propagateIfChanged(lattice, lattice->reset()); - } + /// Set the dense lattice at control flow entry point and propagate an update + /// if it changed. 
+ virtual void setToEntryState(AbstractDenseLattice *lattice) = 0; /// Join a lattice with another and propagate an update if it changed. void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) { @@ -147,6 +141,11 @@ return getOrCreate(point); } + virtual void setToEntryState(LatticeT *lattice) = 0; + void setToEntryState(AbstractDenseLattice *lattice) override { + setToEntryState(static_cast(lattice)); + } + private: /// Type-erased wrappers that convert the abstract dense lattice to a derived /// lattice and invoke the virtual hooks operating on the derived lattice. diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -27,7 +27,7 @@ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) /// range that is used to mark the value as unable to be analyzed further, /// where `t` is the type of `value`. - static IntegerValueRange getPessimisticValueState(Value value); + static IntegerValueRange getMaxRange(Value value); /// Create an integer value range lattice value. IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} @@ -74,6 +74,12 @@ public: using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + void setToEntryState(IntegerValueRangeLattice *lattice) override { + propagateIfChanged( + lattice, + lattice->setValue(IntegerValueRange::getMaxRange(lattice->getPoint()))); + } + /// Visit an operation. Invoke the transfer function on each operation that /// implements `InferIntRangeInterface`. 
void visitOperation(Operation *op, diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -34,15 +34,15 @@ /// Lattices can only be created for values. AbstractSparseLattice(Value value) : AnalysisState(value) {} + /// Return the program point this lattice is located at. + Value getPoint() const { + return AnalysisState::getPoint().get(); + } + /// Join the information contained in 'rhs' into this lattice. Returns /// if the value of the lattice changed. virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0; - /// Mark the lattice element as having reached a pessimistic fixpoint. This - /// means that the lattice may potentially have conflicting value states, and - /// only the most conservative value should be relied on. - virtual ChangeResult markPessimisticFixpoint() = 0; - /// When the lattice gets updated, propagate an update to users of the value /// using its use-def chain to subscribed analyses. void onUpdate(DataFlowSolver *solver) const override; @@ -76,23 +76,32 @@ template class Lattice : public AbstractSparseLattice { public: - /// Construct a lattice with a known value. - explicit Lattice(Value value) - : AbstractSparseLattice(value), - knownValue(ValueT::getPessimisticValueState(value)) {} + using AbstractSparseLattice::AbstractSparseLattice; + + /// Return the program point this lattice is located at. + Value getPoint() const { return point.get(); } /// Return the value held by this lattice. This requires that the value is /// initialized. ValueT &getValue() { assert(!isUninitialized() && "expected known lattice element"); - return *optimisticValue; + return *value; } const ValueT &getValue() const { return const_cast *>(this)->getValue(); } + /// Set the value held by this lattice. 
+ ChangeResult setValue(ValueT &&value) { + if (this->value == value) + return ChangeResult::NoChange; + + this->value = std::move(value); + return ChangeResult::Change; + } + /// Returns true if the value of this lattice hasn't yet been initialized. - bool isUninitialized() const override { return !optimisticValue.has_value(); } + bool isUninitialized() const override { return !value.has_value(); } /// Join the information contained in the 'rhs' lattice into this /// lattice. Returns if the state of the current lattice changed. @@ -113,56 +122,37 @@ ChangeResult join(const ValueT &rhs) { // If the current lattice is uninitialized, copy the rhs value. if (isUninitialized()) { - optimisticValue = rhs; + value = rhs; return ChangeResult::Change; } // Otherwise, join rhs with the current optimistic value. - ValueT newValue = ValueT::join(*optimisticValue, rhs); - assert(ValueT::join(newValue, *optimisticValue) == newValue && + ValueT newValue = ValueT::join(*value, rhs); + assert(ValueT::join(newValue, *value) == newValue && "expected `join` to be monotonic"); assert(ValueT::join(newValue, rhs) == newValue && "expected `join` to be monotonic"); // Update the current optimistic value if something changed. - if (newValue == optimisticValue) + if (newValue == value) return ChangeResult::NoChange; - optimisticValue = newValue; - return ChangeResult::Change; - } - - /// Mark the lattice element as having reached a pessimistic fixpoint. This - /// means that the lattice may potentially have conflicting value states, - /// and only the conservatively known value state should be relied on. - ChangeResult markPessimisticFixpoint() override { - if (optimisticValue == knownValue) - return ChangeResult::NoChange; - - // For this fixed point, we take whatever we knew to be true and set that - // to our optimistic value. - optimisticValue = knownValue; + value = newValue; return ChangeResult::Change; } /// Print the lattice element. 
void print(raw_ostream &os) const override { - os << "["; - knownValue.print(os); - os << ", "; - if (optimisticValue) - optimisticValue->print(os); + if (value) + value->print(os); else os << ""; - os << "]"; } private: - /// The value that is conservatively known to be true. - ValueT knownValue; /// The currently computed value that is optimistically assumed to be true, /// or None if the lattice element is uninitialized. - Optional optimisticValue; + Optional value; }; //===----------------------------------------------------------------------===// @@ -213,9 +203,9 @@ const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, Value value); - /// Mark the given lattice elements as having reached their pessimistic - /// fixpoints and propagate an update if any changed. - void markAllPessimisticFixpoint(ArrayRef lattices); + /// Set the given lattice element(s) at control flow entry point(s). + virtual void setToEntryState(AbstractSparseLattice *lattice) = 0; + void setAllToEntryStates(ArrayRef lattices); /// Join the lattice element and propagate and update if it changed. void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs); @@ -278,8 +268,8 @@ const RegionSuccessor &successor, ArrayRef argLattices, unsigned firstIndex) { - markAllPessimisticFixpoint(argLattices.take_front(firstIndex)); - markAllPessimisticFixpoint(argLattices.drop_front( + setAllToEntryStates(argLattices.take_front(firstIndex)); + setAllToEntryStates(argLattices.drop_front( firstIndex + successor.getSuccessorInputs().size())); } @@ -296,10 +286,10 @@ AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value)); } - /// Mark the lattice elements of a range of values as having reached their - /// pessimistic fixpoint. - void markAllPessimisticFixpoint(ArrayRef lattices) { - AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( + /// Set the given lattice element(s) at control flow entry point(s). 
+ virtual void setToEntryState(StateT *lattice) = 0; + void setAllToEntryStates(ArrayRef lattices) { + AbstractSparseDataFlowAnalysis::setAllToEntryStates( {reinterpret_cast(lattices.begin()), lattices.size()}); } @@ -327,6 +317,9 @@ argLattices.size()}, firstIndex); } + void setToEntryState(AbstractSparseLattice *lattice) override { + return setToEntryState(reinterpret_cast(lattice)); + } }; } // end namespace dataflow diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -288,6 +288,9 @@ /// Create the analysis state at the given program point. AnalysisState(ProgramPoint point) : point(point) {} + /// Returns the program point this state is located at. + ProgramPoint getPoint() const { return point; } + /// Returns true if the analysis state is uninitialized. virtual bool isUninitialized() const = 0; diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -57,7 +57,7 @@ SmallVector foldResults; foldResults.reserve(op->getNumResults()); if (failed(op->fold(constantOperands, foldResults))) { - markAllPessimisticFixpoint(results); + setAllToEntryStates(results); return; } @@ -67,7 +67,7 @@ if (foldResults.empty()) { op->setOperands(originalOperands); op->setAttrs(originalAttrs); - markAllPessimisticFixpoint(results); + setAllToEntryStates(results); return; } @@ -90,3 +90,9 @@ } } } + +void SparseConstantPropagation::setToEntryState( + Lattice *lattice) { + propagateIfChanged(lattice, + lattice->setValue(ConstantValue::getUnknownConstant())); +} diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -62,7 +62,7 @@ // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) - return reset(after); + return setToEntryState(after); for (Operation *predecessor : predecessors->getKnownPredecessors()) join(after, *getLatticeFor(op, predecessor)); return; @@ -100,7 +100,7 @@ // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. if (!callsites->allPredecessorsKnown()) - return reset(after); + return setToEntryState(after); for (Operation *callsite : callsites->getKnownPredecessors()) { // Get the dense lattice before the callsite. if (Operation *prev = callsite->getPrevNode()) @@ -116,7 +116,7 @@ return visitRegionBranchOperation(block, branch, after); // Otherwise, we can't reason about the data-flow. - return reset(after); + return setToEntryState(after); } // Join the state with the state after the block's predecessors. 
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -23,7 +23,7 @@ using namespace mlir; using namespace mlir::dataflow; -IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) { +IntegerValueRange IntegerValueRange::getMaxRange(Value value) { unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType()); APInt umin = APInt::getMinValue(width); APInt umax = APInt::getMaxValue(width); @@ -41,7 +41,8 @@ auto value = point.get(); auto *cv = solver->getOrCreateState>(value); if (!constant) - return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint()); + return solver->propagateIfChanged( + cv, cv->setValue(ConstantValue::getUnknownConstant())); Dialect *dialect; if (auto *parent = value.getDefiningOp()) @@ -60,11 +61,13 @@ // integer results bool hasIntegerResult = false; for (auto it : llvm::zip(results, op->getResults())) { - if (std::get<1>(it).getType().isIntOrIndex()) { + Value value = std::get<1>(it); + if (value.getType().isIntOrIndex()) { hasIntegerResult = true; } else { - propagateIfChanged(std::get<0>(it), - std::get<0>(it)->markPessimisticFixpoint()); + IntegerValueRangeLattice *lattice = std::get<0>(it); + propagateIfChanged( + lattice, lattice->setValue(IntegerValueRange::getMaxRange(value))); } } if (!hasIntegerResult) @@ -72,7 +75,7 @@ auto inferrable = dyn_cast(op); if (!inferrable) - return markAllPessimisticFixpoint(results); + return setAllToEntryStates(results); LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); SmallVector argRanges( @@ -104,7 +107,7 @@ if (isYieldedResult && oldRange.has_value() && !(lattice->getValue() == *oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); - changed |= lattice->markPessimisticFixpoint(); + changed |= 
lattice->setValue(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; @@ -146,7 +149,7 @@ }); if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); - changed |= lattice->markPessimisticFixpoint(); + changed |= lattice->setValue(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -41,7 +41,7 @@ if (region.empty()) continue; for (Value argument : region.front().getArguments()) - markAllPessimisticFixpoint(getLatticeElement(argument)); + setAllToEntryStates(getLatticeElement(argument)); } return initializeRecursively(top); @@ -104,7 +104,7 @@ // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) - return markAllPessimisticFixpoint(resultLattices); + return setAllToEntryStates(resultLattices); for (Operation *predecessor : predecessors->getKnownPredecessors()) for (auto it : llvm::zip(predecessor->getOperands(), resultLattices)) join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it))); @@ -154,7 +154,7 @@ // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. 
if (!callsites->allPredecessorsKnown()) - return markAllPessimisticFixpoint(argLattices); + return setAllToEntryStates(argLattices); for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast(callsite); for (auto it : llvm::zip(call.getArgOperands(), argLattices)) @@ -197,13 +197,13 @@ if (Value operand = operands[it.index()]) { join(it.value(), *getLatticeElementFor(block, operand)); } else { - // Conservatively mark internally produced arguments as having reached - // their pessimistic fixpoint. - markAllPessimisticFixpoint(it.value()); + // Conservatively consider internally produced arguments as entry + // points. + setAllToEntryStates(it.value()); } } } else { - return markAllPessimisticFixpoint(argLattices); + return setAllToEntryStates(argLattices); } } } @@ -231,7 +231,7 @@ if (!operands) { // We can't reason about the data-flow. - return markAllPessimisticFixpoint(lattices); + return setAllToEntryStates(lattices); } ValueRange inputs = predecessors->getSuccessorInputs(op); @@ -273,10 +273,10 @@ return state; } -void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint( +void AbstractSparseDataFlowAnalysis::setAllToEntryStates( ArrayRef lattices) { for (AbstractSparseLattice *lattice : lattices) - propagateIfChanged(lattice, lattice->markPessimisticFixpoint()); + setToEntryState(lattice); } void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs, diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -84,19 +84,18 @@ constant, constant->join(ConstantValue(value, op->getDialect()))); return success(); } - markAllPessimisticFixpoint(op->getResults()); + setAllToUnknownConstants(op->getResults()); for (Region &region : op->getRegions()) - markAllPessimisticFixpoint(region.getArguments()); +
setAllToUnknownConstants(region.getArguments()); return success(); } - /// Mark the constant values of all given values as having reached a - /// pessimistic fixpoint. - void markAllPessimisticFixpoint(ValueRange values) { + /// Set all given values as not constants. + void setAllToUnknownConstants(ValueRange values) { for (Value value : values) { - auto *constantValue = getOrCreate>(value); - propagateIfChanged(constantValue, - constantValue->markPessimisticFixpoint()); + auto *constant = getOrCreate>(value); + propagateIfChanged( + constant, constant->setValue(ConstantValue::getUnknownConstant())); } } }; diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp @@ -20,11 +20,6 @@ /// This lattice represents a single underlying value for an SSA value. class UnderlyingValue { public: - /// The pessimistic underlying value of a value is itself. - static UnderlyingValue getPessimisticValueState(Value value) { - return {value}; - } - /// Create an underlying value state with a known underlying value. UnderlyingValue(Value underlyingValue = {}) : underlyingValue(underlyingValue) {} @@ -61,9 +56,8 @@ /// The lattice is always initialized. bool isUninitialized() const override { return false; } - /// Mark the lattice as having reached its pessimistic fixpoint. That is, the - /// last modifications of all memory resources are unknown. - ChangeResult reset() override { + /// Clear all modifications. + ChangeResult reset() { if (lastMods.empty()) return ChangeResult::NoChange; lastMods.clear(); @@ -131,6 +125,12 @@ /// resource, then its reaching definition is set to the written value. 
void visitOperation(Operation *op, const LastModification &before, LastModification *after) override; + + /// At an entry point, the last modifications of all memory resources are + /// unknown. + void setToEntryState(LastModification *lattice) override { + propagateIfChanged(lattice, lattice->reset()); + } }; /// Define the lattice class explicitly to provide a type ID. @@ -152,7 +152,13 @@ void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override { - markAllPessimisticFixpoint(results); + setAllToEntryStates(results); + } + + /// At an entry point, the underlying value of a value is itself. + void setToEntryState(UnderlyingValueLattice *lattice) override { + propagateIfChanged(lattice, + lattice->setValue(UnderlyingValue{lattice->getPoint()})); } }; } // end anonymous namespace @@ -181,7 +187,7 @@ // If we can't reason about the memory effects, then conservatively assume we // can't deduce anything about the last modifications. if (!memory) - return reset(after); + return setToEntryState(after); SmallVector effects; memory.getEffects(effects); @@ -193,7 +199,7 @@ // If we see an effect on anything other than a value, assume we can't // deduce anything about the last modifications. if (!value) - return reset(after); + return setToEntryState(after); value = getMostUnderlyingValue(value, [&](Value value) { return getOrCreateFor(op, value);