diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h @@ -0,0 +1,41 @@ +//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the dataflow analysis class for integer range inference +// so that it can be used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H +#define MLIR_ANALYSIS_INTRANGEANALYSIS_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl; +} // end namespace detail + +class IntRangeAnalysis { +public: + /// Analyze all operations rooted under (but not including) + /// `topLevelOperation`. + IntRangeAnalysis(Operation *topLevelOperation); + IntRangeAnalysis(IntRangeAnalysis &&other); + ~IntRangeAnalysis(); + + /// Get inferred range for value `v` if one exists. 
+ Optional getResult(Value v); + +private: + std::unique_ptr impl; +}; +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) +add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -0,0 +1,98 @@ +//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the integer range inference interface +// defined in `InferIntRangeInterface.td` +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +/// A set of arbitrary-precision integers representing bounds on a given integer +/// value. These bounds are inclusive on both ends, so +/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for +/// the unsigned and signed interpretations of values in order to enable more +/// precise inference of the interplay between operations with signed and +/// unsigned semantics.
+class ConstantIntRanges { +public: + /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax. + /// Non-integer values should be bounded by APInts of bitwidth 0. + ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin, + const APInt &smax) + : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) { + assert(uminVal.getBitWidth() == umaxVal.getBitWidth() && + umaxVal.getBitWidth() == sminVal.getBitWidth() && + sminVal.getBitWidth() == smaxVal.getBitWidth() && + "All bounds in the ranges must have the same bitwidth"); + } + + bool operator==(const ConstantIntRanges &other) const; + + /// The minimum value of an integer when it is interpreted as unsigned. + const APInt &umin() const; + + /// The maximum value of an integer when it is interpreted as unsigned. + const APInt &umax() const; + + /// The minimum value of an integer when it is interpreted as signed. + const APInt &smin() const; + + /// The maximum value of an integer when it is interpreted as signed. + const APInt &smax() const; + + /// Return the bitwidth that should be used for integer ranges describing + /// `type`. For concrete integer types, this is their bitwidth, for `index`, + /// this is the internal storage bitwidth of `index` attributes, and for + /// non-integer types this is 0. + static unsigned getStorageBitwidth(Type type); + + /// Create a `ConstantIntRanges` where `min` is both the signed and unsigned + /// minimum and `max` is both the signed and unsigned maximum. + static ConstantIntRanges range(const APInt &min, const APInt &max); + + /// Create a `ConstantIntRanges` with the signed minimum and maximum equal + /// to `smin` and `smax`, where the unsigned bounds are constructed from the + /// signed ones if they correspond to a contiguous range of bit patterns when + /// viewed as unsigned values and are left at [0, uint_max()] otherwise.
+ static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax); + + /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal + /// to `umin` and `umax` and the signed part equal to `umin` and `umax` + /// unless the sign bit changes between the minimum and maximum. + static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax); + + /// Returns the union (computed separately for signed and unsigned bounds) + /// of this range and `other`. + ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const; + + /// If either the signed or unsigned interpretations of the range + /// indicate that the value it bounds is a constant, return that constant + /// value. + Optional getConstantValue() const; + + friend raw_ostream &operator<<(raw_ostream &os, + const ConstantIntRanges &range); + +private: + APInt uminVal, umaxVal, sminVal, smaxVal; +}; + +/// The type of the `setResultRanges` callback provided to ops implementing +/// InferIntRangeInterface. It should be called once for each integer result +/// value and be passed the ConstantIntRanges corresponding to that value. +using SetIntRangeFn = function_ref; +} // end namespace mlir + +#include "mlir/Interfaces/InferIntRangeInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td @@ -0,0 +1,52 @@ +//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------===// +// +// Defines the interface for range analysis on scalar integers +// +//===-----------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE + +include "mlir/IR/OpBase.td" + +def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> { + let description = [{ + Allows operations to participate in range analysis for scalar integer values by + providing a method that allows them to specify lower and upper bounds on their + result(s) given lower and upper bounds on their input(s) if known. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Infer the bounds on the results of this op given the bounds on its arguments. + For each result value or block argument (that isn't a branch argument, + since the dataflow analysis handles those cases), the method should call + `setResultRanges` with that `Value` as an argument. When `setResultRanges` + is not called for some value, it will receive a default value of the minimum + and maximum values for its type (the unbounded range). + + When called on an op that also implements the RegionBranchOpInterface + or BranchOpInterface, this method should not attempt to infer the values + of the branch results, as this will be handled by the analyses that use + this interface. + + This function will only be called when at least one result of the op is a + scalar integer value or the op has a region. + + `argRanges` contains one `ConstantIntRanges` for each argument to the op in ODS + order. Non-integer arguments will have an unbounded range of width-0 + APInts in their `argRanges` element.
+ }], + "void", "inferResultRanges", (ins + "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges, + "::mlir::SetIntRangeFn":$setResultRanges) + >]; +} +#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -16,6 +17,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -31,7 +33,9 @@ MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface + MLIRLoopLikeInterface MLIRSideEffectInterfaces MLIRViewLikeInterface ) diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -359,11 +359,20 @@ if (auto branch = dyn_cast(op)) return visitRegionBranchOperation(branch, operandLattices); - // If we can't, conservatively mark all regions as executable. - // TODO: Let the `visitOperation` method decide how to propagate - // information to the block arguments. - for (Region ®ion : op->getRegions()) - markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + for (Region ®ion : op->getRegions()) { + analysis.visitNonControlFlowArguments(op, RegionSuccessor(®ion), + operandLattices); + // `visitNonControlFlowArguments` is required to define all of the region + // argument lattices.
+ assert(llvm::none_of( + region.getArguments(), + [&](Value value) { + return analysis.getLatticeElement(value).isUninitialized(); + }) && + "expected `visitNonControlFlowArguments` to define all argument " + "lattices"); + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/false); + } } // If this op produces no results, it can't produce any constants. @@ -567,12 +576,45 @@ if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) return; + // If the branch is a RegionBranchTerminatorOpInterface, + // construct the set of operand lattices as the set of non control-flow + // arguments of the parent and the values this op returns. This allows + // for the correct lattices to be passed to getSuccessorsForOperands() + // in cases such as scf.while. + ArrayRef branchOpLattices = operandLattices; + SmallVector parentLattices; + if (auto regionTerminator = + dyn_cast(op)) { + parentLattices.reserve(regionInterface->getNumOperands()); + for (Value parentOperand : regionInterface->getOperands()) { + AbstractLatticeElement *operandLattice = + analysis.lookupLatticeElement(parentOperand); + if (!operandLattice || operandLattice->isUninitialized()) + return; + parentLattices.push_back(operandLattice); + } + unsigned regionNumber = parentRegion->getRegionNumber(); + OperandRange iterArgs = + regionInterface.getSuccessorEntryOperands(regionNumber); + OperandRange terminatorArgs = + regionTerminator.getSuccessorOperands(regionNumber); + assert(iterArgs.size() == terminatorArgs.size() && + "Number of iteration arguments for region should equal number of " + "those arguments defined by terminator"); + if (!iterArgs.empty()) { + unsigned iterStart = iterArgs.getBeginOperandIndex(); + unsigned terminatorStart = terminatorArgs.getBeginOperandIndex(); + for (unsigned i = 0, e = iterArgs.size(); i < e; ++i) + parentLattices[iterStart + i] = operandLattices[terminatorStart + i]; + } + branchOpLattices = parentLattices; + } // Query the set of successors of the current 
region using the current // optimistic lattice state. SmallVector regionSuccessors; analysis.getSuccessorsForOperands(regionInterface, parentRegion->getRegionNumber(), - operandLattices, regionSuccessors); + branchOpLattices, regionSuccessors); if (regionSuccessors.empty()) return; @@ -584,7 +626,7 @@ // region index (if any). return *getRegionBranchSuccessorOperands(op, regionIndex); }; - return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices, + return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices, getOperands); } diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp @@ -0,0 +1,325 @@ +//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; + +namespace { +/// A wrapper around ConstantIntRanges that provides the lattice functions +/// expected by dataflow analysis. 
+struct IntRangeLattice { + IntRangeLattice(const ConstantIntRanges &value) : value(value){}; + IntRangeLattice(ConstantIntRanges &&value) : value(value){}; + + bool operator==(const IntRangeLattice &other) const { + return value == other.value; + } + + /// wrapper around rangeUnion() + static IntRangeLattice join(const IntRangeLattice &a, + const IntRangeLattice &b) { + return a.value.rangeUnion(b.value); + } + + /// Creates a range with bitwidth 0 to represent that we don't know if the + /// value being marked overdefined is even an integer. + static IntRangeLattice getPessimisticValueState(MLIRContext *context) { + APInt noIntValue = APInt::getZeroWidth(); + return ConstantIntRanges::range(noIntValue, noIntValue); + } + + /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) + /// range that is used to mark the value v as unable to be analyzed further, + /// where t is the type of v. + static IntRangeLattice getPessimisticValueState(Value v) { + unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType()); + APInt umin = APInt::getMinValue(width); + APInt umax = APInt::getMaxValue(width); + APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; + APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; + return ConstantIntRanges{umin, umax, smin, smax}; + } + + ConstantIntRanges value; +}; +} // end anonymous namespace + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis { + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + +public: + /// Define bounds on the results or block arguments of the operation + /// based on the bounds on the arguments given in `operands` + ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) final; + + /// Skip regions of branch ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do + /// so. 
+ LogicalResult + getSuccessorsForOperands(BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) final; + + /// Skip regions of branch or loop ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do + /// so. + void + getSuccessorsForOperands(RegionBranchOpInterface branch, + Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) final; + + /// Call the InferIntRangeInterface implementation for region-using ops + /// that implement it, and infer the bounds of loop induction variables + /// for ops that implement LoopLikeOPInterface. + ChangeResult visitNonControlFlowArguments( + Operation *op, const RegionSuccessor ®ion, + ArrayRef *> operands) final; +}; +} // end namespace detail +} // end namespace mlir + +/// Given the results of getConstant{Lower,Upper}Bound() +/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for +/// that result if possible. +static APInt getLoopBoundFromFold(Optional loopBound, + Type boundType, + detail::IntRangeAnalysisImpl &analysis, + bool getUpper) { + unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); + if (loopBound.hasValue()) { + if (loopBound->is()) { + if (auto bound = + loopBound->get().dyn_cast_or_null()) + return bound.getValue(); + } else if (loopBound->is()) { + LatticeElement *lattice = + analysis.lookupLatticeElement(loopBound->get()); + if (lattice != nullptr) + return getUpper ? lattice->getValue().value.smax() + : lattice->getValue().value.smin(); + } + } + return getUpper ? 
APInt::getSignedMaxValue(width) + : APInt::getSignedMinValue(width); +} + +ChangeResult detail::IntRangeAnalysisImpl::visitOperation( + Operation *op, ArrayRef *> operands) { + ChangeResult result = ChangeResult::NoChange; + // Ignore non-integer outputs - return early if the op has no scalar + // integer results + bool hasIntegerResult = false; + for (Value v : op->getResults()) { + if (v.getType().isIntOrIndex()) + hasIntegerResult = true; + else + result |= markAllPessimisticFixpoint(v); + } + if (!hasIntegerResult) + return result; + + if (auto inferrable = dyn_cast(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector argRanges( + llvm::map_range(operands, [](LatticeElement *val) { + return val->getValue().value; + })); + + auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { + LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement &lattice = getLatticeElement(v); + Optional oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(IntRangeLattice(attrs)); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedResult && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + for (Value opResult : op->getResults()) { + LatticeElement &lattice = getLatticeElement(opResult); + // setResultRange() not called, make pessimistic. 
+ if (lattice.isUninitialized()) + result |= lattice.markPessimisticFixpoint(); + } + } else if (op->getNumRegions() == 0) { + // No regions + no result inference method -> unbounded results (ex. memory + // ops) + result |= markAllPessimisticFixpoint(op->getResults()); + } + return result; +} + +LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) { + auto toConstantAttr = [&branch](auto enumPair) -> Attribute { + Optional maybeConstValue = + enumPair.value()->getValue().value.getConstantValue(); + + if (maybeConstValue) { + return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + *maybeConstValue); + } + return {}; + }; + SmallVector inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) { + successors.push_back(singleSucc); + return success(); + } + return failure(); +} + +void detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + RegionBranchOpInterface branch, Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) { + auto toConstantAttr = [&branch](auto enumPair) -> Attribute { + Optional maybeConstValue = + enumPair.value()->getValue().value.getConstantValue(); + + if (maybeConstValue) { + return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + *maybeConstValue); + } + return {}; + }; + SmallVector inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + branch.getSuccessorRegions(sourceIndex, inferredConsts, successors); +} + +ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments( + Operation *op, const RegionSuccessor ®ion, + ArrayRef *> operands) { + if (auto inferrable = dyn_cast(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector argRanges( + 
llvm::map_range(operands, [](LatticeElement *val) { + return val->getValue().value; + })); + + ChangeResult result = ChangeResult::NoChange; + auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { + LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement &lattice = getLatticeElement(v); + Optional oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(IntRangeLattice(attrs)); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedValue && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + for (Value regionArg : region.getSuccessor()->getArguments()) { + LatticeElement &lattice = getLatticeElement(regionArg); + // setResultRange() not called, make pessimistic. 
+ if (lattice.isUninitialized()) + result |= lattice.markPessimisticFixpoint(); + } + + return result; + } + + // Infer bounds for loop arguments that have static bounds + if (auto loop = dyn_cast(op)) { + Optional iv = loop.getSingleInductionVar(); + if (!iv.hasValue()) { + return ForwardDataFlowAnalysis< + IntRangeLattice>::visitNonControlFlowArguments(op, region, operands); + } + Optional lowerBound = loop.getSingleLowerBound(); + Optional upperBound = loop.getSingleUpperBound(); + Optional step = loop.getSingleStep(); + APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this, + /*getUpper=*/false); + APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this, + /*getUpper=*/true); + // Assume positivity for undiscoverable steps by way of getUpper = true. + APInt stepVal = + getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true); + + if (stepVal.isNegative()) { + std::swap(min, max); + } else { + // Correct the upper bound by subtracting 1 so that it becomes a <= bound, + // because loops do not generally include their upper bound.
+ max -= 1; + } + + LatticeElement &ivEntry = getLatticeElement(*iv); + return ivEntry.join(ConstantIntRanges::fromSigned(min, max)); + } + return ForwardDataFlowAnalysis::visitNonControlFlowArguments( + op, region, operands); +} + +IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) { + impl = std::make_unique( + topLevelOperation->getContext()); + impl->run(topLevelOperation); +} + +IntRangeAnalysis::~IntRangeAnalysis() = default; +IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default; + +Optional IntRangeAnalysis::getResult(Value v) { + LatticeElement *result = impl->lookupLatticeElement(v); + if (result == nullptr || result->isUninitialized()) + return llvm::None; + return result->getValue().value; +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp + InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp @@ -35,6 +36,7 @@ add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) +add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -0,0 +1,99 @@ +//===- InferIntRangeInterface.cpp - Integer range inference interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" + +using namespace mlir; + +bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const { + return umin().getBitWidth() == other.umin().getBitWidth() && + umin() == other.umin() && umax() == other.umax() && + smin() == other.smin() && smax() == other.smax(); +} + +const APInt &ConstantIntRanges::umin() const { return uminVal; } + +const APInt &ConstantIntRanges::umax() const { return umaxVal; } + +const APInt &ConstantIntRanges::smin() const { return sminVal; } + +const APInt &ConstantIntRanges::smax() const { return smaxVal; } + +unsigned ConstantIntRanges::getStorageBitwidth(Type type) { + if (type.isIndex()) + return IndexType::kInternalStorageBitWidth; + if (auto integerType = type.dyn_cast()) + return integerType.getWidth(); + // Non-integer types have their bounds stored in width 0 `APInt`s. + return 0; +} + +ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) { + return {min, max, min, max}; +} + +ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin, + const APInt &smax) { + unsigned int width = smin.getBitWidth(); + APInt umin, umax; + if (smin.isNonNegative() == smax.isNonNegative()) { + umin = smin.ult(smax) ? smin : smax; + umax = smin.ugt(smax) ? smin : smax; + } else { + umin = APInt::getMinValue(width); + umax = APInt::getMaxValue(width); + } + return {umin, umax, smin, smax}; +} + +ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin, + const APInt &umax) { + unsigned int width = umin.getBitWidth(); + APInt smin, smax; + if (umin.isNonNegative() == umax.isNonNegative()) { + smin = umin.slt(umax) ? umin : umax; + smax = umin.sgt(umax) ? 
umin : umax; + } else { + smin = APInt::getSignedMinValue(width); + smax = APInt::getSignedMaxValue(width); + } + return {umin, umax, smin, smax}; +} + +ConstantIntRanges +ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const { + // "Not an integer" poisons everything and also cannot be fed to comparison + // operators. + if (umin().getBitWidth() == 0) + return *this; + if (other.umin().getBitWidth() == 0) + return other; + + const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin(); + const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax(); + const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin(); + const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax(); + + return {uminUnion, umaxUnion, sminUnion, smaxUnion}; +} + +Optional ConstantIntRanges::getConstantValue() const { + // Note: we need to exclude the trivially-equal width 0 values here. + if (umin() == umax() && umin().getBitWidth() != 0) + return umin(); + if (smin() == smax() && smin().getBitWidth() != 0) + return smin(); + return None; +} + +raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) { + return os << "unsigned : [" << range.umin() << ", " << range.umax() + << "] signed : [" << range.smin() << ", " << range.smax() << "]"; +} diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s + +// CHECK-LABEL: func @constant +// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index} +// CHECK: return %[[cst]] +func.func @constant() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + func.return %0 : index +} + +// CHECK-LABEL: func 
@increment +// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index} +// CHECK: return %[[cst]] +func.func @increment() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index } + %1 = test.increment %0 + func.return %1 : index +} + +// CHECK-LABEL: func @maybe_increment +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + %1 = scf.if %arg0 -> index { + scf.yield %0 : index + } else { + %2 = test.increment %0 + scf.yield %2 : index + } + %3 = test.reflect_bounds %1 + func.return %3 : index +} + +// CHECK-LABEL: func @maybe_increment_br +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment_br(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + cf.cond_br %arg0, ^bb0, ^bb1 +^bb0: + %1 = test.increment %0 + cf.br ^bb2(%1 : index) +^bb1: + cf.br ^bb2(%0 : index) +^bb2(%2 : index): + %3 = test.reflect_bounds %2 + func.return %3 : index +} + +// CHECK-LABEL: func @for_bounds +// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} +func.func @for_bounds() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + scf.yield %arg0 : index + } + %1 = test.reflect_bounds %0 + func.return %1 : index +} + +// CHECK-LABEL: func @no_analysis_of_loop_variants +// CHECK: 
test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index} +func.func @no_analysis_of_loop_variants() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + %1 = test.increment %arg2 + scf.yield %1 : index + } + %2 = test.reflect_bounds %0 + func.return %2 : index +} + +// CHECK-LABEL: func @region_args +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @region_args() { + test.with_bounds_region { umin = 3 : index, umax = 4 : index, + smin = 3 : index, smax = 4 : index } %arg0 { + %0 = test.reflect_bounds %arg0 + } + func.return +} + +// CHECK-LABEL: func @func_args_unbound +// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index} +func.func @func_args_unbound(%arg0 : index) -> index { + %0 = test.reflect_bounds %arg0 + func.return %0 : index +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -62,6 +62,7 @@ MLIRFunc MLIRFuncTransforms MLIRIR + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRLinalg MLIRLinalgTransforms diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -33,6 +33,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" #include 
"mlir/Interfaces/DerivedAttributeOpInterface.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,15 +14,21 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -1396,6 +1402,67 @@ return success(); } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// + +void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse the input argument + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + if (failed(parser.parseArgument(argInfo))) + return 
failure(); + + // Parse the body region, and reuse the operand info as the argument info. + Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); +} + +void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict((*this)->getAttrs()); + p << ' '; + p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, + /*omitType=*/true); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +void TestWithBoundsRegionOp::inferResultRanges( + ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { + Value arg = getRegion().getArgument(0); + setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + APInt one(range.umin().getBitWidth(), 1); + setResultRanges(getResult(), + {range.umin().uadd_sat(one), range.umax().uadd_sat(one), + range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); +} + +void TestReflectBoundsOp::inferResultRanges( + ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + MLIRContext *ctx = getContext(); + Builder b(ctx); + setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); + setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); + setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); + setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); + setResultRanges(getResult(), range); +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestOpStructs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include 
"mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -789,7 +790,7 @@ def CustomResultsNameOp : TEST_Op<"custom_result_name", [DeclareOpInterfaceMethods]> { - let arguments = (ins + let arguments = (ins Variadic:$optional, StrArrayAttr:$names ); @@ -2885,4 +2886,51 @@ }]; } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// +def TestWithBoundsOp : TEST_Op<"with_bounds", + [DeclareOpInterfaceMethods, + NoSideEffect]> { + let arguments = (ins IndexAttr:$umin, + IndexAttr:$umax, + IndexAttr:$smin, + IndexAttr:$smax); + let results = (outs Index:$fakeVal); + + let assemblyFormat = "attr-dict"; +} + +def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region", + [DeclareOpInterfaceMethods, + SingleBlock, NoTerminator]> { + let arguments = (ins IndexAttr:$umin, + IndexAttr:$umax, + IndexAttr:$smin, + IndexAttr:$smax); + // The region has one argument of index type + let regions = (region SizedRegion<1>:$region); + let hasCustomAssemblyFormat = 1; +} + +def TestIncrementOp : TEST_Op<"increment", + [DeclareOpInterfaceMethods, + NoSideEffect]> { + let arguments = (ins Index:$value); + let results = (outs Index:$result); + + let assemblyFormat = "attr-dict $value"; +} + +def TestReflectBoundsOp : TEST_Op<"reflect_bounds", + [DeclareOpInterfaceMethods]> { + let arguments = (ins Index:$value, + OptionalAttr:$umin, + OptionalAttr:$umax, + OptionalAttr:$smin, + OptionalAttr:$smax); + let results = (outs Index:$result); + + let assemblyFormat = "attr-dict $value"; +} #endif // TEST_OPS diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ 
TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp + TestIntRangeInference.cpp EXCLUDE_FROM_LIBMLIR @@ -10,6 +11,8 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms LINK_LIBS PUBLIC + MLIRAnalysis + MLIRInferIntRangeInterface MLIRTestDialect MLIRTransforms ) diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp @@ -0,0 +1,115 @@ +//===- TestIntRangeInference.cpp - Create consts from range inference ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// TODO: This pass is needed to test integer range inference until that +// functionality has been integrated into SCCP. +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/FoldUtils.h" + +using namespace mlir; + +/// Patterned after SCCP +static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, + OpBuilder &b, OperationFolder &folder, + Value value) { + Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value); + if (!maybeInferredRange) + return failure(); + const ConstantIntRanges &inferredRange = maybeInferredRange.getValue(); + Optional<APInt> maybeConstValue = inferredRange.getConstantValue(); + if (!maybeConstValue.hasValue()) + return failure(); + + Operation *maybeDefiningOp = value.getDefiningOp(); + Dialect *valueDialect = + maybeDefiningOp ? 
maybeDefiningOp->getDialect() + : value.getParentRegion()->getParentOp()->getDialect(); + Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); + Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, + value.getType(), value.getLoc()); + if (!constant) + return failure(); + + value.replaceAllUsesWith(constant); + return success(); +} + +static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context, + MutableArrayRef<Region> initialRegions) { + SmallVector<Block *> worklist; + auto addToWorklist = [&](MutableArrayRef<Region> regions) { + for (Region &region : regions) + for (Block &block : llvm::reverse(region)) + worklist.push_back(&block); + }; + + OpBuilder builder(context); + OperationFolder folder(context); + + addToWorklist(initialRegions); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + for (Operation &op : llvm::make_early_inc_range(*block)) { + builder.setInsertionPoint(&op); + + // Replace any result with constants. + bool replacedAll = op.getNumResults() != 0; + for (Value res : op.getResults()) + replacedAll &= + succeeded(replaceWithConstant(analysis, builder, folder, res)); + + // If all of the results of the operation were replaced, try to erase + // the operation completely. + if (replacedAll && wouldOpBeTriviallyDead(&op)) { + assert(op.use_empty() && "expected all uses to be replaced"); + op.erase(); + continue; + } + + // Add any of the regions of this operation to the worklist. + addToWorklist(op.getRegions()); + } + + // Replace any block arguments with constants. 
+ builder.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + (void)replaceWithConstant(analysis, builder, folder, arg); + } +} + +namespace { +struct TestIntRangeInference + : PassWrapper<TestIntRangeInference, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) + + StringRef getArgument() const final { return "test-int-range-inference"; } + StringRef getDescription() const final { + return "Test integer range inference analysis"; + } + + void runOnOperation() override { + Operation *op = getOperation(); + IntRangeAnalysis analysis(op); + rewrite(analysis, op->getContext(), op->getRegions()); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestIntRangeInference() { + PassRegistration<TestIntRangeInference>(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -79,6 +79,7 @@ void registerTestExpandMathPass(); void registerTestComposeSubView(); void registerTestMultiBuffering(); +void registerTestIntRangeInference(); void registerTestIRVisitorsPass(); void registerTestGenericIRVisitorsPass(); void registerTestGenericIRVisitorsInterruptPass(); @@ -175,6 +176,7 @@ mlir::test::registerTestExpandMathPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); + mlir::test::registerTestIntRangeInference(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt --- a/mlir/unittests/Interfaces/CMakeLists.txt +++ b/mlir/unittests/Interfaces/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRInterfacesTests ControlFlowInterfacesTest.cpp DataLayoutInterfacesTest.cpp + InferIntRangeInterfaceTest.cpp InferTypeOpInterfaceTest.cpp ) @@ -10,6 +11,7 @@ 
MLIRDataLayoutInterfaces MLIRDLTI MLIRFunc + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRParser ) diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp @@ -0,0 +1,99 @@ +//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/APInt.h" +#include + +#include + +using namespace mlir; + +TEST(IntRangeAttrs, BasicConstructors) { + APInt zero = APInt::getZero(64); + APInt two(64, 2); + APInt three(64, 3); + ConstantIntRanges boundedAbove(zero, two, zero, three); + EXPECT_EQ(boundedAbove.umin(), zero); + EXPECT_EQ(boundedAbove.umax(), two); + EXPECT_EQ(boundedAbove.smin(), zero); + EXPECT_EQ(boundedAbove.smax(), three); +} + +TEST(IntRangeAttrs, FromUnsigned) { + APInt zero = APInt::getZero(64); + APInt maxInt = APInt::getSignedMaxValue(64); + APInt minInt = APInt::getSignedMinValue(64); + APInt minIntPlusOne = minInt + 1; + + ConstantIntRanges canPortToSigned = + ConstantIntRanges::fromUnsigned(zero, maxInt); + EXPECT_EQ(canPortToSigned.smin(), zero); + EXPECT_EQ(canPortToSigned.smax(), maxInt); + + ConstantIntRanges cantPortToSigned = + ConstantIntRanges::fromUnsigned(zero, minInt); + EXPECT_EQ(cantPortToSigned.smin(), minInt); + EXPECT_EQ(cantPortToSigned.smax(), maxInt); + + ConstantIntRanges signedNegative = + ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne); + EXPECT_EQ(signedNegative.smin(), minInt); + EXPECT_EQ(signedNegative.smax(), minIntPlusOne); +} + +TEST(IntRangeAttrs, FromSigned) { + 
APInt zero = APInt::getZero(64); + APInt one = zero + 1; + APInt negOne = zero - 1; + APInt intMax = APInt::getSignedMaxValue(64); + APInt intMin = APInt::getSignedMinValue(64); + APInt uintMax = APInt::getMaxValue(64); + + ConstantIntRanges noUnsignedBound = + ConstantIntRanges::fromSigned(negOne, one); + EXPECT_EQ(noUnsignedBound.umin(), zero); + EXPECT_EQ(noUnsignedBound.umax(), uintMax); + + ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax); + EXPECT_EQ(positive.umin(), one); + EXPECT_EQ(positive.umax(), intMax); + + ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne); + EXPECT_EQ(negative.umin(), intMin); + EXPECT_EQ(negative.umax(), negOne); + + ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one); + EXPECT_EQ(preserved.umin(), zero); + EXPECT_EQ(preserved.umax(), one); +} + +TEST(IntRangeAttrs, Join) { + APInt zero = APInt::getZero(64); + APInt one = zero + 1; + APInt two = zero + 2; + APInt intMin = APInt::getSignedMinValue(64); + APInt intMax = APInt::getSignedMaxValue(64); + APInt uintMax = APInt::getMaxValue(64); + + ConstantIntRanges maximal(zero, uintMax, intMin, intMax); + ConstantIntRanges zeroOne(zero, one, zero, one); + + EXPECT_EQ(zeroOne.rangeUnion(maximal), maximal); + EXPECT_EQ(maximal.rangeUnion(zeroOne), maximal); + + EXPECT_EQ(zeroOne.rangeUnion(zeroOne), zeroOne); + + ConstantIntRanges oneTwo(one, two, one, two); + ConstantIntRanges zeroTwo(zero, two, zero, two); + EXPECT_EQ(zeroOne.rangeUnion(oneTwo), zeroTwo); + + ConstantIntRanges zeroOneUnsignedOnly(zero, one, intMin, intMax); + ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one); + EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal); +}