diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -0,0 +1,97 @@
+//===- IntegerRangeAnalysis.h - Integer range analysis ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
+public:
+  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+  /// range that is used to mark the value as unable to be analyzed further,
+  /// where `t` is the type of `value`.
+  static IntegerValueRange getPessimisticValueState(Value value);
+
+  /// Create an integer value range lattice value.
+  IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+  /// Get the known integer value range.
+  const ConstantIntRanges &getValue() const { return value; }
+
+  /// Compare two ranges.
+  bool operator==(const IntegerValueRange &rhs) const {
+    return value == rhs.value;
+  }
+
+  /// Take the union of two ranges.
+  static IntegerValueRange join(const IntegerValueRange &lhs,
+                                const IntegerValueRange &rhs) {
+    return lhs.value.rangeUnion(rhs.value);
+  }
+
+  /// Print the integer value range.
+  void print(raw_ostream &os) const { os << value; }
+
+private:
+  /// The known integer value range.
+  ConstantIntRanges value;
+};
+
+/// This lattice element represents the integer value range of an SSA value.
+/// When this lattice is updated, it automatically updates the constant value
+/// of the SSA value (if the range can be narrowed to one).
+class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
+public:
+  using Lattice::Lattice;
+
+  /// If the range can be narrowed to an integer constant, update the constant
+  /// value of the SSA value.
+  void onUpdate(DataFlowSolver *solver) const override;
+};
+
+/// Integer range analysis determines the integer value range of SSA values
+/// using operations that define `InferIntRangeInterface` and also sets the
+/// range of iteration indices of loops with known bounds.
+class IntegerRangeAnalysis
+    : public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
+public:
+  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+  /// Visit an operation. Invoke the transfer function on each operation that
+  /// implements `InferIntRangeInterface`.
+  void visitOperation(Operation *op,
+                      ArrayRef<const IntegerValueRangeLattice *> operands,
+                      ArrayRef<IntegerValueRangeLattice *> results) override;
+
+  /// Visit block arguments or operation results of an operation with region
+  /// control-flow for which values are not defined by region control-flow. This
+  /// function calls `InferIntRangeInterface` to provide values for block
+  /// arguments or tries to reduce the range on loop induction variables with
+  /// known bounds.
+  void
+  visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
+                               ArrayRef<IntegerValueRangeLattice *> argLattices,
+                               unsigned firstIndex) override;
+};
+
+} // end namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -211,6 +211,14 @@
                      ArrayRef<const AbstractSparseLattice *> operandLattices,
                      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
 
+  /// Given an operation with region control-flow, the lattices of the operands,
+  /// and a region successor, compute the lattice values for block arguments
+  /// that are not accounted for by the branching control flow (ex. the bounds
+  /// of loops).
+  virtual void visitNonControlFlowArgumentsImpl(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+
   /// Get the lattice element of a value.
   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
 
@@ -271,6 +279,21 @@
   virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
                               ArrayRef<StateT *> results) = 0;
 
+  /// Given an operation with possible region control-flow, the lattices of the
+  /// operands, and a region successor, compute the lattice values for block
+  /// arguments that are not accounted for by the branching control flow (ex.
+  /// the bounds of loops). By default, this method marks all such lattice
+  /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
+  /// index of the first element of `argLattices` that is set by control-flow.
+  virtual void visitNonControlFlowArguments(Operation *op,
+                                            const RegionSuccessor &successor,
+                                            ArrayRef<StateT *> argLattices,
+                                            unsigned firstIndex) {
+    markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
+    markAllPessimisticFixpoint(argLattices.drop_front(
+        firstIndex + successor.getSuccessorInputs().size()));
+  }
+
 protected:
   /// Get the lattice element for a value.
   StateT *getLatticeElement(Value value) override {
@@ -305,6 +328,16 @@
         {reinterpret_cast<StateT *const *>(resultLattices.begin()),
          resultLattices.size()});
   }
+  void visitNonControlFlowArgumentsImpl(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<AbstractSparseLattice *> argLattices,
+      unsigned firstIndex) override {
+    visitNonControlFlowArguments(
+        op, successor,
+        {reinterpret_cast<StateT *const *>(argLattices.begin()),
+         argLattices.size()},
+        firstIndex);
+  }
 };
 
 } // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -226,7 +226,6 @@
   /// Push a work item onto the worklist.
   void enqueue(WorkItem item) { worklist.push(std::move(item)); }
 
-protected:
   /// Get the state associated with the given program point. If it does not
   /// exist, create an uninitialized state.
   template <typename StateT, typename PointT>
diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
deleted file mode 100644
--- a/mlir/include/mlir/Analysis/IntRangeAnalysis.h
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares the dataflow analysis class for integer range inference
-// so that it can be used in transformations over the `arith` dialect such as
-// branch elimination or signed->unsigned rewriting
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
-#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
-
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-
-namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl;
-} // end namespace detail
-
-class IntRangeAnalysis {
-public:
-  /// Analyze all operations rooted under (but not including)
-  /// `topLevelOperation`.
-  IntRangeAnalysis(Operation *topLevelOperation);
-  IntRangeAnalysis(IntRangeAnalysis &&other);
-  ~IntRangeAnalysis();
-
-  /// Get inferred range for value `v` if one exists.
-  Optional<ConstantIntRanges> getResult(Value v);
-
-private:
-  std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
-};
-} // end namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -30,7 +30,7 @@
       since the dataflow analysis handles those case), the method should call
       `setValueRange` with that `Value` as an argument. When `setValueRange`
-      is not called for some value, it will recieve a default value of the mimimum
-      and maximum values forits type (the unbounded range).
+      is not called for some value, it will receive a default value of the minimum
+      and maximum values for its type (the unbounded range).
 
       When called on an op that also implements the RegionBranchOpInterface
       or BranchOpInterface, this method should not attempt to infer the values
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,7 +4,6 @@
   CallGraph.cpp
   DataFlowAnalysis.cpp
   DataLayoutAnalysis.cpp
-  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -13,6 +12,7 @@
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
   DataFlow/DenseAnalysis.cpp
+  DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
   )
 
@@ -23,7 +23,6 @@
   DataFlowAnalysis.cpp
   DataFlowFramework.cpp
   DataLayoutAnalysis.cpp
-  IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
 
@@ -32,6 +31,7 @@
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
   DataFlow/DenseAnalysis.cpp
+  DataFlow/IntegerRangeAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -168,10 +168,22 @@
                                 walkFn);
 }
 
+/// Returns true if the operation terminates a block. It is insufficient to
+/// check for `OpTrait::IsTerminator` because unregistered operations can be
+/// terminators.
+static bool isTerminator(Operation *op) {
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    return true;
+  // Otherwise, only an unregistered operation at the end of its block may be
+  // a terminator. A top-level operation has no block.
+  Block *block = op->getBlock();
+  return !op->isRegistered() && block && &block->back() == op;
+}
+
 LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
   // Initialize the analysis by visiting every op with control-flow semantics.
-  if (op->getNumRegions() || op->getNumSuccessors() ||
-      op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
+  if (op->getNumRegions() || op->getNumSuccessors() || isTerminator(op) ||
+      isa<CallOpInterface>(op)) {
     // When the liveness of the parent block changes, make sure to re-invoke the
     // analysis on the op.
     if (op->getBlock())
@@ -241,7 +253,7 @@
     }
   }
 
-  if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+  if (isTerminator(op) && !op->getNumSuccessors()) {
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
       // Visit the exiting terminator of a region.
       visitRegionTerminator(op, branch);
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -0,0 +1,254 @@
+//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+  APInt umin = APInt::getMinValue(width);
+  APInt umax = APInt::getMaxValue(width);
+  APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+  APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+  return {{umin, umax, smin, smax}};
+}
+
+void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
+  Lattice::onUpdate(solver);
+
+  // If the integer range can be narrowed to a constant, update the constant
+  // value of the SSA value.
+  Optional<APInt> constant = getValue().getValue().getConstantValue();
+  auto value = point.get<Value>();
+  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
+  if (!constant)
+    return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
+
+  Dialect *dialect;
+  if (auto *parent = value.getDefiningOp())
+    dialect = parent->getDialect();
+  else
+    dialect = value.getParentBlock()->getParentOp()->getDialect();
+  solver->propagateIfChanged(
+      cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
+                                 dialect)));
+}
+
+void IntegerRangeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
+    ArrayRef<IntegerValueRangeLattice *> results) {
+  // Ignore non-integer outputs - return early if the op has no scalar
+  // integer results
+  bool hasIntegerResult = false;
+  for (auto it : llvm::zip(results, op->getResults())) {
+    if (std::get<1>(it).getType().isIntOrIndex()) {
+      hasIntegerResult = true;
+    } else {
+      propagateIfChanged(std::get<0>(it),
+                         std::get<0>(it)->markPessimisticFixpoint());
+    }
+  }
+  if (!hasIntegerResult)
+    return;
+
+  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+  if (!inferrable)
+    return markAllPessimisticFixpoint(results);
+
+  LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+  SmallVector<ConstantIntRanges> argRanges(
+      llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
+        return val->getValue().getValue();
+      }));
+
+  auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+    auto result = v.dyn_cast<OpResult>();
+    if (!result)
+      return;
+    assert(llvm::is_contained(op->getResults(), result));
+
+    LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+    IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
+    Optional<IntegerValueRange> oldRange;
+    if (!lattice->isUninitialized())
+      oldRange = lattice->getValue();
+
+    ChangeResult changed = lattice->join(attrs);
+
+    // Catch loop results with loop variant bounds and conservatively make
+    // them [-inf, inf] so we don't circle around infinitely often (because
+    // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+    // and often can't).
+    bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+      return op->hasTrait<OpTrait::IsTerminator>();
+    });
+    if (isYieldedResult && oldRange && !(lattice->getValue() == *oldRange)) {
+      LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+      changed |= lattice->markPessimisticFixpoint();
+    }
+    propagateIfChanged(lattice, changed);
+  };
+
+  inferrable.inferResultRanges(argRanges, joinCallback);
+
+  // `InferIntRangeInterface` documents that a value whose range is not set
+  // receives the unbounded range. Mark such results explicitly: the solver
+  // no longer default-initializes states left uninitialized at fixpoint.
+  for (IntegerValueRangeLattice *lattice : results)
+    if (lattice->isUninitialized())
+      propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+}
+
+void IntegerRangeAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &successor,
+    ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
+  // The values computed here are the entry block arguments of the successor
+  // region, or the results of `op` when the successor is the parent op. Any
+  // lattice read below must re-trigger the visit that computes those values.
+  Region *region = successor.getSuccessor();
+  ProgramPoint dependent =
+      region ? ProgramPoint(&region->front()) : ProgramPoint(op);
+
+  // Give the unbounded range to every argument that is neither fed by
+  // control flow nor assigned a range by this function.
+  auto markUnsetArgumentsPessimistic = [&]() {
+    unsigned numInputs = successor.getSuccessorInputs().size();
+    for (unsigned i = 0, e = argLattices.size(); i != e; ++i) {
+      // Arguments [firstIndex, firstIndex + numInputs) come from control flow.
+      if (i >= firstIndex && i < firstIndex + numInputs)
+        continue;
+      if (argLattices[i]->isUninitialized())
+        propagateIfChanged(argLattices[i],
+                           argLattices[i]->markPessimisticFixpoint());
+    }
+  };
+
+  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+    // An uninitialized operand lattice carries no range yet. Bail out; the
+    // dependency recorded by `getLatticeElementFor` revisits `dependent`
+    // once the operand is updated.
+    SmallVector<ConstantIntRanges> argRanges;
+    for (Value operand : op->getOperands()) {
+      const IntegerValueRangeLattice *lattice =
+          getLatticeElementFor(dependent, operand);
+      if (lattice->isUninitialized())
+        return;
+      argRanges.push_back(lattice->getValue().getValue());
+    }
+
+    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+      // Only entry block arguments of the successor region are set here;
+      // a parent-operation successor has no region.
+      auto arg = v.dyn_cast<BlockArgument>();
+      if (!arg || !region || !llvm::is_contained(region->getArguments(), arg))
+        return;
+
+      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+      IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
+      Optional<IntegerValueRange> oldRange;
+      if (!lattice->isUninitialized())
+        oldRange = lattice->getValue();
+
+      ChangeResult changed = lattice->join(attrs);
+
+      // Catch loop results with loop variant bounds and conservatively make
+      // them [-inf, inf] so we don't circle around infinitely often (because
+      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+      // and often can't).
+      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+        return op->hasTrait<OpTrait::IsTerminator>();
+      });
+      if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
+        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+        changed |= lattice->markPessimisticFixpoint();
+      }
+      propagateIfChanged(lattice, changed);
+    };
+
+    inferrable.inferResultRanges(argRanges, joinCallback);
+    markUnsetArgumentsPessimistic();
+    return;
+  }
+
+  /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
+  /// on a LoopLikeInterface return the lower/upper bound for that result if
+  /// possible.
+  auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
+                                  Type boundType, bool getUpper) {
+    unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+    if (loopBound.hasValue()) {
+      if (loopBound->is<Attribute>()) {
+        if (auto bound =
+                loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+          return bound.getValue();
+      } else if (auto value = loopBound->dyn_cast<Value>()) {
+        // Use the range of a dynamic bound once it is known.
+        const IntegerValueRangeLattice *lattice =
+            getLatticeElementFor(dependent, value);
+        if (!lattice->isUninitialized())
+          return getUpper ? lattice->getValue().getValue().smax()
+                          : lattice->getValue().getValue().smin();
+      }
+    }
+    // Otherwise, the bound is the widest signed value in that direction.
+    return getUpper ? APInt::getSignedMaxValue(width)
+                    : APInt::getSignedMinValue(width);
+  };
+
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    Optional<Value> iv = loop.getSingleInductionVar();
+    if (!region || !iv) {
+      return SparseDataFlowAnalysis::visitNonControlFlowArguments(
+          op, successor, argLattices, firstIndex);
+    }
+    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+    Optional<OpFoldResult> step = loop.getSingleStep();
+    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+                                     /*getUpper=*/false);
+    APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+                                     /*getUpper=*/true);
+    // Assume positivity for undiscoverable steps by way of getUpper = true.
+    APInt stepVal =
+        getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+
+    if (stepVal.isNegative()) {
+      std::swap(min, max);
+    } else {
+      // Correct the upper bound by subtracting 1 so that it becomes a <=
+      // bound, because loops do not generally include their upper bound.
+      max -= 1;
+    }
+
+    IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
+    auto ivRange = ConstantIntRanges::fromSigned(min, max);
+    propagateIfChanged(ivEntry, ivEntry->join(ivRange));
+    // No range is inferred for the remaining non-control-flow arguments.
+    markUnsetArgumentsPessimistic();
+    return;
+  }
+
+  return SparseDataFlowAnalysis::visitNonControlFlowArguments(
+      op, successor, argLattices, firstIndex);
+}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -181,7 +181,9 @@
     }
 
     // Otherwise, we can't reason about the data-flow.
-    return markAllPessimisticFixpoint(argLattices);
+    return visitNonControlFlowArgumentsImpl(block->getParentOp(),
+                                            RegionSuccessor(block->getParent()),
+                                            argLattices, /*firstIndex=*/0);
   }
 
   // Iterate over the predecessors of the non-entry block.
@@ -234,7 +236,6 @@
       operands = branch.getSuccessorEntryOperands(successorIndex);
       // Otherwise, try to deduce the operands from a region return-like op.
     } else {
-      assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
       if (isRegionReturnLike(op))
         operands = getRegionBranchSuccessorOperands(op, successorIndex);
     }
@@ -248,17 +249,26 @@
     assert(inputs.size() == operands->size() &&
            "expected the same number of successor inputs as operands");
 
-    // TODO: This was updated to be exposed upstream.
     unsigned firstIndex = 0;
     if (inputs.size() != lattices.size()) {
-      if (inputs.empty()) {
-        markAllPessimisticFixpoint(lattices);
-        return;
+      if (point.is<Operation *>()) {
+        if (!inputs.empty())
+          firstIndex = inputs.front().cast<OpResult>().getResultNumber();
+        visitNonControlFlowArgumentsImpl(
+            branch,
+            RegionSuccessor(
+                branch->getResults().slice(firstIndex, inputs.size())),
+            lattices, firstIndex);
+      } else {
+        if (!inputs.empty())
+          firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+        Region *region = point.get<Block *>()->getParent();
+        visitNonControlFlowArgumentsImpl(
+            branch,
+            RegionSuccessor(region, region->getArguments().slice(
+                                        firstIndex, inputs.size())),
+            lattices, firstIndex);
       }
-      firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
-      markAllPessimisticFixpoint(lattices.take_front(firstIndex));
-      markAllPessimisticFixpoint(
-          lattices.drop_front(firstIndex + inputs.size()));
     }
 
     for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -87,19 +87,6 @@
         return failure();
     }
 
-    // "Nudge" the state of the analysis by forcefully initializing states that
-    // are still uninitialized. All uninitialized states in the graph can be
-    // initialized in any order because the analysis reached fixpoint, meaning
-    // that there are no work items that would have further nudged the analysis.
-    for (AnalysisState &state :
-         llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
-      if (!state.isUninitialized())
-        continue;
-      DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
-                                  << " of " << state.point << "\n");
-      propagateIfChanged(&state, state.defaultInitialize());
-    }
-
     // Iterate until all states are in some initialized state and the worklist
     // is exhausted.
   } while (!worklist.empty());
diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
deleted file mode 100644
--- a/mlir/lib/Analysis/IntRangeAnalysis.cpp
+++ /dev/null
@@ -1,335 +0,0 @@
-//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the dataflow analysis class for integer range inference
-// which is used in transformations over the `arith` dialect such as
-// branch elimination or signed->unsigned rewriting
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/IntRangeAnalysis.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "int-range-analysis"
-
-using namespace mlir;
-
-namespace {
-/// A wrapper around ConstantIntRanges that provides the lattice functions
-/// expected by dataflow analysis.
-struct IntRangeLattice {
-  IntRangeLattice(const ConstantIntRanges &value) : value(value){};
-  IntRangeLattice(ConstantIntRanges &&value) : value(value){};
-
-  bool operator==(const IntRangeLattice &other) const {
-    return value == other.value;
-  }
-
-  /// wrapper around rangeUnion()
-  static IntRangeLattice join(const IntRangeLattice &a,
-                              const IntRangeLattice &b) {
-    return a.value.rangeUnion(b.value);
-  }
-
-  /// Creates a range with bitwidth 0 to represent that we don't know if the
-  /// value being marked overdefined is even an integer.
-  static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
-    APInt noIntValue = APInt::getZeroWidth();
-    return ConstantIntRanges(noIntValue, noIntValue, noIntValue, noIntValue);
-  }
-
-  /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
-  /// range that is used to mark the value v as unable to be analyzed further,
-  /// where t is the type of v.
-  static IntRangeLattice getPessimisticValueState(Value v) {
-    unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
-    APInt umin = APInt::getMinValue(width);
-    APInt umax = APInt::getMaxValue(width);
-    APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
-    APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
-    return ConstantIntRanges{umin, umax, smin, smax};
-  }
-
-  ConstantIntRanges value;
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
-  using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
-
-public:
-  /// Define bounds on the results or block arguments of the operation
-  /// based on the bounds on the arguments given in `operands`
-  ChangeResult
-  visitOperation(Operation *op,
-                 ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-
-  /// Skip regions of branch ops when we can statically infer constant
-  /// values for operands to the branch op and said op tells us it's safe to do
-  /// so.
-  LogicalResult
-  getSuccessorsForOperands(BranchOpInterface branch,
-                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
-                           SmallVectorImpl<Block *> &successors) final;
-
-  /// Skip regions of branch or loop ops when we can statically infer constant
-  /// values for operands to the branch op and said op tells us it's safe to do
-  /// so.
-  void
-  getSuccessorsForOperands(RegionBranchOpInterface branch,
-                           Optional<unsigned> sourceIndex,
-                           ArrayRef<LatticeElement<IntRangeLattice> *> operands,
-                           SmallVectorImpl<RegionSuccessor> &successors) final;
-
-  /// Call the InferIntRangeInterface implementation for region-using ops
-  /// that implement it, and infer the bounds of loop induction variables
-  /// for ops that implement LoopLikeOPInterface.
-  ChangeResult visitNonControlFlowArguments(
-      Operation *op, const RegionSuccessor &region,
-      ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-};
-} // end namespace detail
-} // end namespace mlir
-
-/// Given the results of getConstant{Lower,Upper}Bound()
-/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
-/// that result if possible.
-static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
-                                  Type boundType,
-                                  detail::IntRangeAnalysisImpl &analysis,
-                                  bool getUpper) {
-  unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
-  if (loopBound) {
-    if (loopBound->is<Attribute>()) {
-      if (auto bound =
-              loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
-        return bound.getValue();
-    } else if (loopBound->is<Value>()) {
-      LatticeElement<IntRangeLattice> *lattice =
-          analysis.lookupLatticeElement(loopBound->get<Value>());
-      if (lattice != nullptr)
-        return getUpper ? lattice->getValue().value.smax()
-                        : lattice->getValue().value.smin();
-    }
-  }
-  return getUpper ? APInt::getSignedMaxValue(width)
-                  : APInt::getSignedMinValue(width);
-}
-
-ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
-    Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
-  ChangeResult result = ChangeResult::NoChange;
-  // Ignore non-integer outputs - return early if the op has no scalar
-  // integer results
-  bool hasIntegerResult = false;
-  for (Value v : op->getResults()) {
-    if (v.getType().isIntOrIndex())
-      hasIntegerResult = true;
-    else
-      result |= markAllPessimisticFixpoint(v);
-  }
-  if (!hasIntegerResult)
-    return result;
-
-  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
-    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
-    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
-    LLVM_DEBUG(llvm::dbgs() << "\n");
-    SmallVector<ConstantIntRanges> argRanges(
-        llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
-          return val->getValue().value;
-        }));
-
-    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
-      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
-      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
-      Optional<IntRangeLattice> oldRange;
-      if (!lattice.isUninitialized())
-        oldRange = lattice.getValue();
-      result |= lattice.join(IntRangeLattice(attrs));
-
-      // Catch loop results with loop variant bounds and conservatively make
-      // them [-inf, inf] so we don't circle around infinitely often (because
-      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-      // and often can't).
-      bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
-        return op->hasTrait<OpTrait::IsTerminator>();
-      });
-      if (isYieldedResult && oldRange && !(lattice.getValue() == *oldRange)) {
-        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-        result |= lattice.markPessimisticFixpoint();
-      }
-    };
-
-    inferrable.inferResultRanges(argRanges, joinCallback);
-    for (Value opResult : op->getResults()) {
-      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
-      // setResultRange() not called, make pessimistic.
-      if (lattice.isUninitialized())
-        result |= lattice.markPessimisticFixpoint();
-    }
-  } else if (op->getNumRegions() == 0) {
-    // No regions + no result inference method -> unbounded results (ex. memory
-    // ops)
-    result |= markAllPessimisticFixpoint(op->getResults());
-  }
-  return result;
-}
-
-LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
-    BranchOpInterface branch,
-    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
-    SmallVectorImpl<Block *> &successors) {
-  auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
-    Optional<APInt> maybeConstValue =
-        enumPair.value()->getValue().value.getConstantValue();
-
-    if (maybeConstValue) {
-      return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
-                              *maybeConstValue);
-    }
-    return {};
-  };
-  SmallVector<Attribute> inferredConsts(
-      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
-  if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
-    successors.push_back(singleSucc);
-    return success();
-  }
-  return failure();
-}
-
-void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
-    RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
-    ArrayRef<LatticeElement<IntRangeLattice> *> operands,
-    SmallVectorImpl<RegionSuccessor> &successors) {
-  // Get a type with which to construct a constant.
-  auto getOperandType = [branch, sourceIndex](unsigned index) {
-    // The types of all return-like operations are the same.
-    if (!sourceIndex)
-      return branch->getOperand(index).getType();
-
-    for (Block &block : branch->getRegion(*sourceIndex)) {
-      Operation *terminator = block.getTerminator();
-      if (getRegionBranchSuccessorOperands(terminator, *sourceIndex))
-        return terminator->getOperand(index).getType();
-    }
-    return Type();
-  };
-
-  auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute {
-    if (Optional<APInt> maybeConstValue =
-            enumPair.value()->getValue().value.getConstantValue()) {
-      return IntegerAttr::get(getOperandType(enumPair.index()),
-                              *maybeConstValue);
-    }
-    return {};
-  };
-  SmallVector<Attribute> inferredConsts(
-      llvm::map_range(llvm::enumerate(operands), toConstantAttr));
-  branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
-}
-
-ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
-    Operation *op, const RegionSuccessor &region,
-    ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
-  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
-    LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
-    LLVM_DEBUG(inferrable->print(llvm::dbgs()));
-    LLVM_DEBUG(llvm::dbgs() << "\n");
-    SmallVector<ConstantIntRanges> argRanges(
-        llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
-          return val->getValue().value;
-        }));
-
-    ChangeResult result = ChangeResult::NoChange;
-    auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
-      LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
-      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
-      Optional<IntRangeLattice> oldRange;
-      if (!lattice.isUninitialized())
-        oldRange = lattice.getValue();
-      result |= lattice.join(IntRangeLattice(attrs));
-
-      // Catch loop results with loop variant bounds and conservatively make
-      // them [-inf, inf] so we don't circle around infinitely often (because
-      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-      // and often can't).
-      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
-        return op->hasTrait<OpTrait::IsTerminator>();
-      });
-      if (isYieldedValue && oldRange && !(lattice.getValue() == *oldRange)) {
-        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-        result |= lattice.markPessimisticFixpoint();
-      }
-    };
-
-    inferrable.inferResultRanges(argRanges, joinCallback);
-    for (Value regionArg : region.getSuccessor()->getArguments()) {
-      LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
-      // setResultRange() not called, make pessimistic.
-      if (lattice.isUninitialized())
-        result |= lattice.markPessimisticFixpoint();
-    }
-
-    return result;
-  }
-
-  // Infer bounds for loop arguments that have static bounds
-  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
-    Optional<Value> iv = loop.getSingleInductionVar();
-    if (!iv) {
-      return ForwardDataFlowAnalysis<
-          IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
-    }
-    Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
-    Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
-    Optional<OpFoldResult> step = loop.getSingleStep();
-    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
-                                     /*getUpper=*/false);
-    APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
-                                     /*getUpper=*/true);
-    // Assume positivity for uniscoverable steps by way of getUpper = true.
-    APInt stepVal =
-        getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
-
-    if (stepVal.isNegative()) {
-      std::swap(min, max);
-    } else {
-      // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
-      // because loops do not generally include their upper bound.
-      max -= 1;
-    }
-
-    LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
-    return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
-  }
-  return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
-      op, region, operands);
-}
-
-IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
-  impl = std::make_unique<detail::IntRangeAnalysisImpl>(
-      topLevelOperation->getContext());
-  impl->run(topLevelOperation);
-}
-
-IntRangeAnalysis::~IntRangeAnalysis() = default;
-IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
-
-Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
-  LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
-  if (result == nullptr || result->isUninitialized())
-    return llvm::None;
-  return result->getValue().value;
-}
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -9,33 +9,34 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::arith;
+using namespace mlir::dataflow;
 
 /// Succeeds when a value is statically non-negative in that it has a lower
 /// bound on its value (if it is treated as signed) and that bound is
 /// non-negative.
-static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, - Value v) { - Optional result = analysis.getResult(v); - if (!result.hasValue()) +static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { + auto *result = solver.lookupState(v); + if (!result || result->isUninitialized()) return failure(); - const ConstantIntRanges &range = result.getValue(); + const ConstantIntRanges &range = result->getValue().getValue(); return success(range.smin().isNonNegative()); } /// Succeeds if an op can be converted to its unsigned equivalent without /// changing its semantics. This is the case when none of its openands or /// results can be below 0 when analyzed from a signed perspective. -static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, +static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { - auto nonNegativePred = [&analysis](Value v) -> bool { - return succeeded(staticallyNonNegative(analysis, v)); + auto nonNegativePred = [&solver](Value v) -> bool { + return succeeded(staticallyNonNegative(solver, v)); }; return success(llvm::all_of(op->getOperands(), nonNegativePred) && llvm::all_of(op->getResults(), nonNegativePred)); @@ -44,15 +45,15 @@ /// Succeeds when the comparison predicate is a signed operation and all the /// operands are non-negative, indicating that the cmpi operation `op` can have /// its predicate changed to an unsigned equivalent.
-static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) { +static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) { CmpIPredicate pred = op.getPredicate(); switch (pred) { case CmpIPredicate::sle: case CmpIPredicate::slt: case CmpIPredicate::sge: case CmpIPredicate::sgt: - return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool { - return succeeded(staticallyNonNegative(analysis, v)); + return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool { + return succeeded(staticallyNonNegative(solver, v)); })); default: return failure(); @@ -109,19 +110,23 @@ void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); - IntRangeAnalysis analysis(op); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); ConversionTarget target(*ctx); target.addLegalDialect(); target .addDynamicallyLegalOp( - [&analysis](Operation *op) -> Optional { - return failed(staticallyNonNegative(analysis, op)); + [&solver](Operation *op) -> Optional { + return failed(staticallyNonNegative(solver, op)); }); target.addDynamicallyLegalOp( - [&analysis](CmpIOp op) -> Optional { - return failed(isCmpIConvertable(analysis, op)); + [&solver](CmpIOp op) -> Optional { + return failed(isCmpIConvertable(solver, op)); }); RewritePatternSet patterns(ctx); diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -38,7 +38,7 @@ OpBuilder &builder, OperationFolder &folder, Value value) { auto *lattice = solver.lookupState>(value); - if (!lattice) + if (!lattice || lattice->isUninitialized()) return failure(); const ConstantValue &latticeValue = lattice->getValue(); if (!latticeValue.getConstantValue()) diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp --- 
a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -68,9 +68,8 @@ LogicalResult initialize(Operation *top) override { WalkResult result = top->walk([&](Operation *op) { - if (op->hasTrait()) - if (failed(visit(op))) - return WalkResult::interrupt(); + if (failed(visit(op))) + return WalkResult::interrupt(); return WalkResult::advance(); }); return success(!result.wasInterrupted()); @@ -83,13 +82,27 @@ auto *constant = getOrCreate>(op->getResult(0)); propagateIfChanged( constant, constant->join(ConstantValue(value, op->getDialect()))); + return success(); } + markAllPessimisticFixpoint(op->getResults()); + for (Region ®ion : op->getRegions()) + markAllPessimisticFixpoint(region.getArguments()); return success(); } + + /// Mark the constant values of all given values as having reached a + /// pessimistic fixpoint. + void markAllPessimisticFixpoint(ValueRange values) { + for (Value value : values) { + auto *constantValue = getOrCreate>(value); + propagateIfChanged(constantValue, + constantValue->markPessimisticFixpoint()); + } + } }; -/// This is a simple pass that runs dead code analysis with no constant value -/// provider. It marks everything as live. +/// This is a simple pass that runs dead code analysis with a constant value +/// provider that only understands constant operations. struct TestDeadCodeAnalysisPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass) diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp --- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp +++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp @@ -9,7 +9,8 @@ // functionality has been integrated into SCCP. 
//===----------------------------------------------------------------------===// -#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" @@ -17,15 +18,17 @@ #include "mlir/Transforms/FoldUtils.h" using namespace mlir; +using namespace mlir::dataflow; /// Patterned after SCCP -static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, - OpBuilder &b, OperationFolder &folder, - Value value) { - Optional maybeInferredRange = analysis.getResult(value); - if (!maybeInferredRange) +static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, + OperationFolder &folder, Value value) { + auto *maybeInferredRange = + solver.lookupState(value); + if (!maybeInferredRange || maybeInferredRange->isUninitialized()) return failure(); - const ConstantIntRanges &inferredRange = maybeInferredRange.getValue(); + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); Optional maybeConstValue = inferredRange.getConstantValue(); if (!maybeConstValue.hasValue()) return failure(); @@ -44,7 +47,7 @@ return success(); } -static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context, +static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef initialRegions) { SmallVector worklist; auto addToWorklist = [&](MutableArrayRef regions) { @@ -67,7 +70,7 @@ bool replacedAll = op.getNumResults() != 0; for (Value res : op.getResults()) replacedAll &= - succeeded(replaceWithConstant(analysis, builder, folder, res)); + succeeded(replaceWithConstant(solver, builder, folder, res)); // If all of the results of the operation were replaced, try to erase // the operation completely. @@ -84,7 +87,7 @@ // Replace any block arguments with constants. 
builder.setInsertionPointToStart(block); for (BlockArgument arg : block->getArguments()) - (void)replaceWithConstant(analysis, builder, folder, arg); + (void)replaceWithConstant(solver, builder, folder, arg); } } @@ -100,8 +103,12 @@ void runOnOperation() override { Operation *op = getOperation(); - IntRangeAnalysis analysis(op); - rewrite(analysis, op->getContext(), op->getRegions()); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + rewrite(solver, op->getContext(), op->getRegions()); } }; } // end anonymous namespace