diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h @@ -28,15 +28,24 @@ /// This lattice value represents a known constant value of a lattice. class ConstantValue { public: + /// Construct a constant value as uninitialized. + explicit ConstantValue() = default; + /// Construct a constant value with a known constant. - ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr) - : constant(knownValue), dialect(dialect) {} + explicit ConstantValue(Attribute constant, Dialect *dialect) + : constant(constant), dialect(dialect) {} - /// Get the constant value. Returns null if no value was determined. + /// Get the constant value (null if unknown). The value must be initialized. - Attribute getConstantValue() const { return constant; } + Attribute getConstantValue() const { + assert(!isUninitialized()); + return *constant; + } /// Get the dialect instance that can be used to materialize the constant. - Dialect *getConstantDialect() const { return dialect; } + Dialect *getConstantDialect() const { + assert(!isUninitialized()); + return dialect; + } /// Compare the constant values. bool operator==(const ConstantValue &rhs) const { @@ -46,20 +55,35 @@ /// Print the constant value. void print(raw_ostream &os) const; + /// The state where the constant value is uninitialized. This happens when the + /// state hasn't been set during the analysis. + static ConstantValue getUninitialized() { return ConstantValue{}; } + + /// Whether the state is uninitialized. + bool isUninitialized() const { return !constant.has_value(); } + /// The state where the constant value is unknown. 
- static ConstantValue getUnknownConstant() { return {}; } + static ConstantValue getUnknownConstant() { + return ConstantValue{/*constant=*/nullptr, /*dialect=*/nullptr}; + } - /// The union with another constant value is null if they are different, and - /// the same if they are the same. + /// Join two constant values: an uninitialized side yields the other; equal + /// values are kept, and different values become unknown. static ConstantValue join(const ConstantValue &lhs, const ConstantValue &rhs) { - return lhs == rhs ? lhs : ConstantValue(); + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + if (lhs == rhs) + return lhs; + return getUnknownConstant(); } private: /// The constant value. - Attribute constant; - /// An dialect instance that can be used to materialize the constant. + Optional constant; + /// A dialect instance that can be used to materialize the constant. - Dialect *dialect; + Dialect *dialect = nullptr; }; diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -38,9 +38,6 @@ public: using AnalysisState::AnalysisState; - /// The state is initialized by default. - bool isUninitialized() const override { return false; } - /// Set the state of the program point to live. ChangeResult setToLive(); @@ -95,9 +92,6 @@ public: using AnalysisState::AnalysisState; - /// The state is initialized by default. - bool isUninitialized() const override { return false; } - /// Print the known predecessors. void print(raw_ostream &os) const override; diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -30,10 +30,18 @@ static IntegerValueRange getMaxRange(Value value); /// Create an integer value range lattice value. 
- IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {} + IntegerValueRange(Optional value = None) + : value(std::move(value)) {} + + /// Whether the range is uninitialized. This happens when the state hasn't + /// been set during the analysis. + bool isUninitialized() const { return !value.has_value(); } /// Get the known integer value range. - const ConstantIntRanges &getValue() const { return value; } + const ConstantIntRanges &getValue() const { + assert(!isUninitialized()); + return *value; + } /// Compare two ranges. bool operator==(const IntegerValueRange &rhs) const { @@ -43,7 +51,11 @@ /// Take the union of two ranges. static IntegerValueRange join(const IntegerValueRange &lhs, const IntegerValueRange &rhs) { - return lhs.value.rangeUnion(rhs.value); + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + return IntegerValueRange{lhs.getValue().rangeUnion(rhs.getValue())}; } /// Print the integer value range. @@ -51,7 +63,7 @@ private: /// The known integer value range. - ConstantIntRanges value; + Optional value; }; /// This lattice element represents the integer value range of an SSA value. diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -81,27 +81,17 @@ - /// Return the value held by this lattice. This requires that the value is - /// initialized. + /// Return the value held by this lattice. Whether it has been initialized + /// is tracked by the value type itself. - ValueT &getValue() { - assert(!isUninitialized() && "expected known lattice element"); - return *value; - } + ValueT &getValue() { return value; } const ValueT &getValue() const { return const_cast *>(this)->getValue(); } - /// Returns true if the value of this lattice hasn't yet been initialized. - bool isUninitialized() const override { return !value.has_value(); } - /// Join the information contained in the 'rhs' lattice into this /// lattice. 
Returns if the state of the current lattice changed. ChangeResult join(const AbstractSparseLattice &rhs) override { const Lattice &rhsLattice = static_cast &>(rhs); - // If rhs is uninitialized, there is nothing to do. - if (rhsLattice.isUninitialized()) - return ChangeResult::NoChange; - // Join the rhs value into this lattice. return join(rhsLattice.getValue()); } @@ -109,15 +99,9 @@ /// Join the information contained in the 'rhs' value into this /// lattice. Returns if the state of the current lattice changed. ChangeResult join(const ValueT &rhs) { - // If the current lattice is uninitialized, copy the rhs value. - if (isUninitialized()) { - value = rhs; - return ChangeResult::Change; - } - - // Otherwise, join rhs with the current optimistic value. + // Join rhs with the current optimistic value. - ValueT newValue = ValueT::join(*value, rhs); - assert(ValueT::join(newValue, *value) == newValue && + ValueT newValue = ValueT::join(value, rhs); + assert(ValueT::join(newValue, value) == newValue && "expected `join` to be monotonic"); assert(ValueT::join(newValue, rhs) == newValue && "expected `join` to be monotonic"); @@ -131,17 +115,11 @@ } /// Print the lattice element. - void print(raw_ostream &os) const override { - if (value) - value->print(os); - else - os << ""; - } + void print(raw_ostream &os) const override { value.print(os); } private: - /// The currently computed value that is optimistically assumed to be true, - /// or None if the lattice element is uninitialized. - Optional value; + /// The currently computed value that is optimistically assumed to be true. + ValueT value; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -291,9 +291,6 @@ /// Returns the program point this static is located at. 
ProgramPoint getPoint() const { return point; } - /// Returns true if the analysis state is uninitialized. - virtual bool isUninitialized() const = 0; - /// Print the contents of the analysis state. virtual void print(raw_ostream &os) const = 0; diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -20,9 +20,15 @@ //===----------------------------------------------------------------------===// void ConstantValue::print(raw_ostream &os) const { - if (constant) - return constant.print(os); - os << ""; + if (isUninitialized()) { + os << ""; + return; + } + if (getConstantValue() == nullptr) { + os << ""; + return; + } + return getConstantValue().print(os); } //===----------------------------------------------------------------------===// @@ -45,8 +51,11 @@ SmallVector constantOperands; constantOperands.reserve(op->getNumOperands()); - for (auto *operandLattice : operands) + for (auto *operandLattice : operands) { + if (operandLattice->getValue().isUninitialized()) + return; constantOperands.push_back(operandLattice->getValue().getConstantValue()); + } // Save the original operands and attributes just in case the operation // folds in-place. The constant passed in may not correspond to the real diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -318,7 +318,7 @@ for (Value operand : op->getOperands()) { const Lattice *cv = getLattice(operand); // If any of the operands' values are uninitialized, bail out. 
- if (cv->isUninitialized()) + if (cv->getValue().isUninitialized()) return {}; operands.push_back(cv->getValue().getConstantValue()); } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -74,9 +74,6 @@ before = getLatticeFor(op, prev); else before = getLatticeFor(op, op->getBlock()); - // If the incoming lattice is uninitialized, bail out. - if (before->isUninitialized()) - return; // Invoke the operation transfer function. visitOperationImpl(op, *before, after); diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -29,7 +29,7 @@ APInt umax = APInt::getMaxValue(width); APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; - return {{umin, umax, smin, smax}}; + return IntegerValueRange{ConstantIntRanges{umin, umax, smin, smax}}; } void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { @@ -57,6 +57,13 @@ void IntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { + // If the lattice on any operand is uninitialized, bail out. 
+ if (llvm::any_of(operands, [](const IntegerValueRangeLattice *lattice) { + return lattice->getValue().isUninitialized(); + })) { + return; + } + // Ignore non-integer outputs - return early if the op has no scalar // integer results bool hasIntegerResult = false; @@ -91,11 +98,9 @@ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; - Optional oldRange; - if (!lattice->isUninitialized()) - oldRange = lattice->getValue(); + IntegerValueRange oldRange = lattice->getValue(); - ChangeResult changed = lattice->join(attrs); + ChangeResult changed = lattice->join(IntegerValueRange{attrs}); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because @@ -104,8 +109,8 @@ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); - if (isYieldedResult && oldRange.has_value() && - !(lattice->getValue() == *oldRange)) { + if (isYieldedResult && !oldRange.isUninitialized() && + !(lattice->getValue() == oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } @@ -134,11 +139,9 @@ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; - Optional oldRange; - if (!lattice->isUninitialized()) - oldRange = lattice->getValue(); + IntegerValueRange oldRange = lattice->getValue(); - ChangeResult changed = lattice->join(attrs); + ChangeResult changed = lattice->join(IntegerValueRange{attrs}); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because @@ -147,7 +150,8 @@ bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); - if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) { + if 
(isYieldedValue && !oldRange.isUninitialized() && + !(lattice->getValue() == oldRange)) { LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } @@ -212,7 +216,7 @@ IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); auto ivRange = ConstantIntRanges::fromSigned(min, max); - propagateIfChanged(ivEntry, ivEntry->join(ivRange)); + propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); return; } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -117,9 +117,6 @@ for (Value operand : op->getOperands()) { AbstractSparseLattice *operandLattice = getLatticeElement(operand); operandLattice->useDefSubscribe(this); - // If any of the operand states are not initialized, bail out. - if (operandLattice->isUninitialized()) - return; operandLattices.push_back(operandLattice); } diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -43,7 +43,7 @@ OpBuilder &builder, OperationFolder &folder, Value value) { auto *lattice = solver.lookupState>(value); - if (!lattice || lattice->isUninitialized()) + if (!lattice || lattice->getValue().isUninitialized()) return failure(); const ConstantValue &latticeValue = lattice->getValue(); if (!latticeValue.getConstantValue()) diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp @@ -21,16 +21,29 @@ class UnderlyingValue { public: /// Create an underlying value state with a known underlying value. 
- UnderlyingValue(Value underlyingValue) : underlyingValue(underlyingValue) {} + explicit UnderlyingValue(Optional underlyingValue = None) + : underlyingValue(underlyingValue) {} + + /// Whether the state is uninitialized. + bool isUninitialized() const { return !underlyingValue.has_value(); } /// Returns the underlying value. - Value getUnderlyingValue() const { return underlyingValue; } + Value getUnderlyingValue() const { + assert(!isUninitialized()); + return *underlyingValue; + } /// Join two underlying values. If there are conflicting underlying values, /// go to the pessimistic value. static UnderlyingValue join(const UnderlyingValue &lhs, const UnderlyingValue &rhs) { - return lhs.underlyingValue == rhs.underlyingValue ? lhs : Value(); + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + return lhs.underlyingValue == rhs.underlyingValue + ? lhs + : UnderlyingValue(Value{}); } /// Compare underlying values. @@ -41,7 +54,7 @@ void print(raw_ostream &os) const { os << underlyingValue; } private: - Value underlyingValue; + Optional underlyingValue; }; /// This lattice represents, for a given memory resource, the potential last @@ -52,9 +65,6 @@ using AbstractDenseLattice::AbstractDenseLattice; - /// The lattice is always initialized. - bool isUninitialized() const override { return false; } - /// Clear all modifications. 
ChangeResult reset() { if (lastMods.empty()) @@ -169,7 +179,7 @@ const UnderlyingValueLattice *underlying; do { underlying = getUnderlyingValueFn(value); - if (!underlying || underlying->isUninitialized()) + if (!underlying || underlying->getValue().isUninitialized()) return {}; Value underlyingValue = underlying->getValue().getUnderlyingValue(); if (underlyingValue == value) diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -21,7 +21,7 @@ using AnalysisState::AnalysisState; /// Returns true if the state is uninitialized. - bool isUninitialized() const override { return !state; } + bool isUninitialized() const { return !state; } /// Print the integer value or "none" if uninitialized. void print(raw_ostream &os) const override { diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp --- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp +++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp @@ -25,7 +25,7 @@ OperationFolder &folder, Value value) { auto *maybeInferredRange = solver.lookupState(value); - if (!maybeInferredRange || maybeInferredRange->isUninitialized()) + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) return failure(); const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue();