diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h @@ -0,0 +1,45 @@ +//===- IntRangeAnalysis.h - Integer range analysis ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the dataflow analysis class for integer range inference +// so that it can be used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H +#define MLIR_ANALYSIS_INTRANGEANALYSIS_H + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl; +} // end namespace detail + +class IntRangeAnalysis { +public: + /// Analyze all operations rooted under (but not including) + /// `topLevelOperation`. + IntRangeAnalysis(Operation *topLevelOperation); + + ~IntRangeAnalysis(); + + IntRangeAnalysis(IntRangeAnalysis &&other); + + /// Get the range inferred for value `v`, or None if it was not assigned one.
+ Optional<ConstantIntRanges> getResult(Value v); + +private: + std::unique_ptr<detail::IntRangeAnalysisImpl> impl; +}; +} // end namespace mlir + +#endif // MLIR_ANALYSIS_INTRANGEANALYSIS_H diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) +add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -0,0 +1,121 @@ +//===- InferIntRangeInterface.h - Integer Range Inference --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the integer range inference interface +// defined in `InferIntRangeInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H + +#include "mlir/IR/OpDefinition.h" +#include <tuple> + +namespace mlir { +/// A collection of `APInt`s representing minimum and maximum bounds +/// on a scalar integer value. These bounds are inclusive on both ends, so +/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for +/// the unsigned and signed interpretations of values. +class ConstantIntRanges { +public: + /// Bound umin <= unsigned(x) <= umax and smin <= signed(x) <= smax.
+ /// Non-integer values should be bounded by APInts of bitwidth 0. + ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin, + const APInt &smax) + : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) { + assert(uminVal.getBitWidth() == umaxVal.getBitWidth() && + umaxVal.getBitWidth() == sminVal.getBitWidth() && + sminVal.getBitWidth() == smaxVal.getBitWidth() && + "All bounds in the ranges must have the same bitwidth"); + } + + /// Convenience wrapper that allows providing [min, max] ranges if one has + /// a utility function that returns tuples of `APInt`s. + ConstantIntRanges(const std::tuple<APInt, APInt> &urange, + const std::tuple<APInt, APInt> &srange) + : ConstantIntRanges(std::get<0>(urange), std::get<1>(urange), + std::get<0>(srange), std::get<1>(srange)) {} + + /// The minimum value of an integer when it is interpreted as unsigned. + const APInt &umin() const; + + /// The maximum value of an integer when it is interpreted as unsigned. + const APInt &umax() const; + + /// The minimum value of an integer when it is interpreted as signed. + const APInt &smin() const; + + /// The maximum value of an integer when it is interpreted as signed. + const APInt &smax() const; + + /// Return the bitwidth that should be used for integer ranges describing + /// `type`. For concrete integer types, this is their bitwidth, for `index`, + /// this is the internal storage bitwidth of `index` attributes, and for + /// non-integer types this is 0. + static unsigned int getStorageBitwidth(Type type); + + /// Return the value used to indicate that a value is not an integer, which is + /// a range made up of width-0 APInts. + static ConstantIntRanges notAnInteger(); + + /// Create a `ConstantIntRanges` where `min` is both the signed and unsigned + /// minimum and `max` is both the signed and unsigned maximum.
+ static ConstantIntRanges range(const APInt &min, const APInt &max); + + /// Create a `ConstantIntRanges` with the signed minimum and maximum equal + /// to `smin` and `smax`, where the unsigned bounds are constructed from the + /// signed ones if they correspond to a contiguous range of bit patterns when + /// viewed as unsigned values and are left at [0, uint_max()] otherwise. + static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax); + + /// Wrapper to allow for [min, max] tuples as arguments to fromSigned(). + static ConstantIntRanges fromSigned(const std::tuple<APInt, APInt> &srange); + + /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal + /// to `umin` and `umax` and the signed part equal to `umin` and `umax` + /// unless the sign bit changes between the minimum and maximum. + static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax); + + /// Wrapper to allow for [min, max] tuples as arguments to fromUnsigned(). + static ConstantIntRanges fromUnsigned(const std::tuple<APInt, APInt> &urange); + + /// Returns the smallest ranges (computed separately for signed and unsigned + /// bounds) containing `a` and `b`; a width-0 (non-integer) input wins. + static ConstantIntRanges join(const ConstantIntRanges &a, + const ConstantIntRanges &b); + + /// Helper methods for when ConstantIntRanges is used in a dataflow analysis. + /// Creates a range with bitwidth 0. + static ConstantIntRanges getPessimisticValueState(MLIRContext *context); + + /// Create a maximal ([0, uint_max(t)] / [int_min(t), int_max(t)]) range + /// that is used to mark the value v as unable to be analyzed further, + /// where t is the type of v. + static ConstantIntRanges getPessimisticValueState(Value v); + + /// If either the signed or unsigned interpretations of the range + /// indicate that the value it bounds is a constant, return that constant + /// value.
+ Optional<APInt> getConstantValue() const; + + bool operator==(const ConstantIntRanges &other) const; + + friend raw_ostream &operator<<(raw_ostream &os, + const ConstantIntRanges &range); + +private: + APInt uminVal, umaxVal, sminVal, smaxVal; +}; + +using SetIntRangeFn = function_ref<void(Value, const ConstantIntRanges &)>; +} // end namespace mlir + +#include "mlir/Interfaces/InferIntRangeInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td @@ -0,0 +1,47 @@ +//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------===// +// +// Defines the interface for range analysis on scalar integers +// +//===-----------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE + +include "mlir/IR/OpBase.td" + +def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> { + let description = [{ + Allows operations to participate in range analysis for scalar integer values by + providing a method that allows them to specify lower and upper bounds on their + result(s) given lower and upper bounds on their input(s) if known. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Infer the bounds on the results of this op given the bounds on its arguments. + For each result value or block argument (that isn't a branch argument, + since the dataflow analysis handles those cases), the method should call + `setResultRanges` with that `Value` as an argument.
When `setResultRanges` + is not called for some value, it will receive a default value of [-inf, +inf] + (the unbounded range). + + This function will only be called when at least one result of the op is a + scalar integer value or the op has a region. + + `argRanges` contains one `ConstantIntRanges` for each argument to the op in ODS + order. Non-integer arguments will have a range made of bitwidth-0 APInts in + their `argRanges` element. + }], + "void", "inferResultRanges", (ins + "::llvm::ArrayRef<::mlir::ConstantIntRanges>":$argRanges, + "::mlir::SetIntRangeFn":$setResultRanges) + >]; +} +#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -51,6 +51,9 @@ /// Creates a pass to perform common sub expression elimination. std::unique_ptr<Pass> createCSEPass(); +/// Create a pass to constant fold based on the results of range inference. +std::unique_ptr<Pass> createFoldInferredConstantsPass(); + /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -77,6 +77,11 @@ ]; } +def FoldInferredConstants : Pass<"fold-inferred-constants"> { + let summary = "Constant fold based on the results of integer range inference"; + let constructor = "mlir::createFoldInferredConstantsPass()"; +} + def Inliner : Pass<"inline"> { let summary = "Inline function calls"; let constructor = "mlir::createInlinerPass()"; diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -16,6 +17,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -31,6 +33,7 @@ MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces MLIRViewLikeInterface diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -359,11 +359,20 @@ if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) return visitRegionBranchOperation(branch, operandLattices); - // If we can't, conservatively mark all regions as executable. - // TODO: Let the `visitOperation` method decide how to propagate - // information to the block arguments. - for (Region &region : op->getRegions()) - markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true); + for (Region &region : op->getRegions()) { + analysis.visitNonControlFlowArguments(op, RegionSuccessor(&region), + operandLattices); + // `visitNonControlFlowArguments` is required to define all of the region + // argument lattices.
+ assert(llvm::none_of( + region.getArguments(), + [&](Value value) { + return analysis.getLatticeElement(value).isUninitialized(); + }) && + "expected `visitNonControlFlowArguments` to define all argument " + "lattices"); + markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/false); + } } // If this op produces no results, it can't produce any constants. diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp @@ -0,0 +1,271 @@ +//===- IntRangeAnalysis.cpp - Integer range analysis -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<ConstantIntRanges> { + using ForwardDataFlowAnalysis<ConstantIntRanges>::ForwardDataFlowAnalysis; + +public: + /// Define bounds on the results or block arguments of the operation + /// based on the bounds on the arguments given in `operands`. + ChangeResult + visitOperation(Operation *op, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands) final; + + /// Skip regions of branch ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do
+ /// so. + LogicalResult getSuccessorsForOperands( + BranchOpInterface branch, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands, + SmallVectorImpl<Block *> &successors) final; + + /// Skip regions of branch or loop ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do + /// so. + void getSuccessorsForOperands( + RegionBranchOpInterface branch, Optional<unsigned> sourceIndex, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands, + SmallVectorImpl<RegionSuccessor> &successors) final; + + /// Call the InferIntRangeInterface implementation for region-using ops + /// that implement it, and infer the bounds of loop induction variables + /// for ops that implement LoopLikeOpInterface. + ChangeResult visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &region, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands) final; +}; +} // end namespace detail +} // end namespace mlir + +/// Given the results of getSingle{LowerBound,UpperBound,Step}() on a +/// LoopLikeOpInterface, return the lower/upper bound for that result if +/// possible, falling back to the signed min/max of `boundType`'s storage width. +static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound, + Type boundType, + detail::IntRangeAnalysisImpl &analysis, + bool getUpper) { + unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); + if (loopBound.hasValue()) { + if (loopBound->is<Attribute>()) { + if (auto bound = + loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>()) + return bound.getValue(); + } else if (loopBound->is<Value>()) { + LatticeElement<ConstantIntRanges> *lattice = + analysis.lookupLatticeElement(loopBound->get<Value>()); + if (lattice != nullptr && !lattice->isUninitialized()) + return getUpper ? lattice->getValue().smax() + : lattice->getValue().smin(); + } + } + return getUpper ?
APInt::getSignedMaxValue(width) + : APInt::getSignedMinValue(width); +} + +ChangeResult detail::IntRangeAnalysisImpl::visitOperation( + Operation *op, ArrayRef<LatticeElement<ConstantIntRanges> *> operands) { + ChangeResult result = ChangeResult::NoChange; + // Ignore non-integer outputs - return early if the op has no scalar + // integer results. + bool hasIntegerResult = false; + for (Value v : op->getResults()) { + if (v.getType().isIntOrIndex()) + hasIntegerResult = true; + else + result |= markAllPessimisticFixpoint(v); + } + if (!hasIntegerResult) + return result; + + if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector<ConstantIntRanges> argRanges( + llvm::map_range(operands, [](LatticeElement<ConstantIntRanges> *val) { + return val->getValue(); + })); + + auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { + LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement<ConstantIntRanges> &lattice = getLatticeElement(v); + Optional<ConstantIntRanges> oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait<OpTrait::IsTerminator>(); + }); + if (isYieldedResult && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + } else if (op->getNumRegions() == 0) { + // No regions + no result inference method -> unbounded results (ex.
memory + // ops) + result |= markAllPessimisticFixpoint(op->getResults()); + } + return result; +} + +LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + BranchOpInterface branch, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands, + SmallVectorImpl<Block *> &successors) { + auto toConstantAttr = [&branch](auto enumPair) -> Attribute { + Optional<APInt> maybeConstValue = + enumPair.value()->getValue().getConstantValue(); + + if (maybeConstValue) { + return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + *maybeConstValue); + } + return {}; + }; + SmallVector<Attribute> inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) { + successors.push_back(singleSucc); + return success(); + } + return failure(); +} + +void detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + RegionBranchOpInterface branch, Optional<unsigned> sourceIndex, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands, + SmallVectorImpl<RegionSuccessor> &successors) { + // When `sourceIndex` is set, `operands` are the lattices of the operands of + // a terminator in that region, not of `branch` itself, so the type of each + // constant must be taken from the matching terminator operand. + auto getOperandType = [&branch, &sourceIndex](unsigned index) -> Type { + if (!sourceIndex) + return branch->getOperand(index).getType(); + for (Block &block : branch->getRegion(*sourceIndex)) { + Operation *terminator = block.getTerminator(); + if (getRegionBranchSuccessorOperands(terminator, *sourceIndex) && + index < terminator->getNumOperands()) + return terminator->getOperand(index).getType(); + } + return Type(); + }; + auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute { + Optional<APInt> maybeConstValue = + enumPair.value()->getValue().getConstantValue(); + if (!maybeConstValue) + return {}; + Type type = getOperandType(enumPair.index()); + if (!type || !type.isIntOrIndex()) + return {}; + return IntegerAttr::get(type, *maybeConstValue); + }; + SmallVector<Attribute> inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + branch.getSuccessorRegions(sourceIndex, inferredConsts, successors); +} + +ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &region, + ArrayRef<LatticeElement<ConstantIntRanges> *> operands) { + if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector<ConstantIntRanges> argRanges( + llvm::map_range(operands, [](LatticeElement<ConstantIntRanges> *val) { + return val->getValue(); + })); + + ChangeResult result = ChangeResult::NoChange; + auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { +
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement<ConstantIntRanges> &lattice = getLatticeElement(v); + Optional<ConstantIntRanges> oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait<OpTrait::IsTerminator>(); + }); + if (isYieldedValue && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + return result; + } + + // Infer bounds for the single induction variable of loop-like ops. + if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { + Optional<Value> iv = loop.getSingleInductionVar(); + if (!iv.hasValue()) { + return ForwardDataFlowAnalysis< + ConstantIntRanges>::visitNonControlFlowArguments(op, region, + operands); + } + Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); + Optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); + Optional<OpFoldResult> step = loop.getSingleStep(); + APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this, + /*getUpper=*/false); + APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this, + /*getUpper=*/true); + // Assume positivity for undiscoverable steps by way of getUpper = true. + APInt stepVal = + getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true); + + if (stepVal.isNegative()) { + std::swap(min, max); + } else { + // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
+ max -= 1; + } + + LatticeElement &ivEntry = getLatticeElement(*iv); + return ivEntry.join(ConstantIntRanges::fromSigned(min, max)); + } + return ForwardDataFlowAnalysis< + ConstantIntRanges>::visitNonControlFlowArguments(op, region, operands); +} + +IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) { + impl = std::make_unique( + topLevelOperation->getContext()); + impl->run(topLevelOperation); +} + +IntRangeAnalysis::~IntRangeAnalysis() = default; +IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default; + +Optional IntRangeAnalysis::getResult(Value v) { + LatticeElement *result = impl->lookupLatticeElement(v); + if (result == nullptr || result->isUninitialized()) + return llvm::None; + return result->getValue(); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp + InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp @@ -35,6 +36,7 @@ add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) +add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -0,0 +1,130 @@ +//===- InferIntRangeInterface.cpp - Integer range inference interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" + +using namespace mlir; + +const APInt &ConstantIntRanges::umin() const { return uminVal; } + +const APInt &ConstantIntRanges::umax() const { return umaxVal; } + +const APInt &ConstantIntRanges::smin() const { return sminVal; } + +const APInt &ConstantIntRanges::smax() const { return smaxVal; } + +unsigned int ConstantIntRanges::getStorageBitwidth(Type type) { + if (type.isIndex()) + return IndexType::kInternalStorageBitWidth; + if (auto integerType = type.dyn_cast<IntegerType>()) + return integerType.getWidth(); + // Non-integer types have their bounds stored in width 0 `APInt`s. + return 0; +} + +ConstantIntRanges ConstantIntRanges::notAnInteger() { + APInt noValue = APInt(/*numBits=*/0U, /*val=*/0LL); + return {noValue, noValue, noValue, noValue}; +} + +ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) { + return {min, max, min, max}; +} + +ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin, + const APInt &smax) { + unsigned int width = smin.getBitWidth(); + APInt umin, umax; + if (smin.isNonNegative() == smax.isNonNegative()) { + umin = smin.ult(smax) ? smin : smax; + umax = smin.ugt(smax) ? smin : smax; + } else { + umin = APInt::getMinValue(width); + umax = APInt::getMaxValue(width); + } + return {umin, umax, smin, smax}; +} + +ConstantIntRanges +ConstantIntRanges::fromSigned(const std::tuple<APInt, APInt> &srange) { + return ConstantIntRanges::fromSigned(std::get<0>(srange), + std::get<1>(srange)); +} + +ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin, + const APInt &umax) { + unsigned int width = umin.getBitWidth(); + APInt smin, smax; + if (umin.isNonNegative() == umax.isNonNegative()) { + smin = umin.slt(umax) ?
umin : umax; + smax = umin.sgt(umax) ? umin : umax; + } else { + smin = APInt::getSignedMinValue(width); + smax = APInt::getSignedMaxValue(width); + } + return {umin, umax, smin, smax}; +} + +ConstantIntRanges +ConstantIntRanges::fromUnsigned(const std::tuple<APInt, APInt> &urange) { + return ConstantIntRanges::fromUnsigned(std::get<0>(urange), + std::get<1>(urange)); +} + +ConstantIntRanges ConstantIntRanges::join(const ConstantIntRanges &a, + const ConstantIntRanges &b) { + // "Not an integer" poisons everything and also cannot be fed to comparison + // operators. + if (a.umin().getBitWidth() == 0) + return a; + if (b.umin().getBitWidth() == 0) + return b; + + const APInt &umin = a.umin().ult(b.umin()) ? a.umin() : b.umin(); + const APInt &umax = a.umax().ugt(b.umax()) ? a.umax() : b.umax(); + const APInt &smin = a.smin().slt(b.smin()) ? a.smin() : b.smin(); + const APInt &smax = a.smax().sgt(b.smax()) ? a.smax() : b.smax(); + + return {umin, umax, smin, smax}; +} + +ConstantIntRanges +ConstantIntRanges::getPessimisticValueState(MLIRContext *context) { + return notAnInteger(); +} + +ConstantIntRanges ConstantIntRanges::getPessimisticValueState(Value v) { + unsigned int width = getStorageBitwidth(v.getType()); + APInt umin = APInt::getMinValue(width); + APInt umax = APInt::getMaxValue(width); + APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin; + APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax; + return {umin, umax, smin, smax}; +} + +Optional<APInt> ConstantIntRanges::getConstantValue() const { + // Note: we need to exclude the trivially-equal width 0 values here.
+ if (umin() == umax() && umin().getBitWidth() != 0) + return umin(); + if (smin() == smax() && smin().getBitWidth() != 0) + return smin(); + return None; +} + +bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const { + return umin().getBitWidth() == other.umin().getBitWidth() && + umin() == other.umin() && umax() == other.umax() && + smin() == other.smin() && smax() == other.smax(); +} + +raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) { + return os << "unsigned : [" << range.umin() << ", " << range.umax() + << "] signed : [" << range.smin() << ", " << range.smax() << "]"; +} diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ Canonicalizer.cpp ControlFlowSink.cpp CSE.cpp + FoldInferredConstants.cpp Inliner.cpp LocationSnapshot.cpp LoopInvariantCodeMotion.cpp @@ -23,7 +24,9 @@ LINK_LIBS PUBLIC MLIRAnalysis MLIRCopyOpInterface + MLIRInferIntRangeInterface MLIRLoopLikeInterface + MLIRSideEffectInterfaces MLIRPass MLIRSupport MLIRTransformUtils diff --git a/mlir/lib/Transforms/FoldInferredConstants.cpp b/mlir/lib/Transforms/FoldInferredConstants.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/FoldInferredConstants.cpp @@ -0,0 +1,100 @@ +//===- FoldInferredConstants.cpp - Fold constants from range inference ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +/// Replace `value` with an inferred constant. Patterned after SCCP.cpp. +static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, + OpBuilder &b, OperationFolder &folder, + Value value) { + Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value); + if (!maybeInferredRange) + return failure(); + const ConstantIntRanges &inferredRange = maybeInferredRange.getValue(); + Optional<APInt> maybeConstValue = inferredRange.getConstantValue(); + if (!maybeConstValue.hasValue()) + return failure(); + + Operation *maybeDefiningOp = value.getDefiningOp(); + Dialect *valueDialect = + maybeDefiningOp ? maybeDefiningOp->getDialect() + : value.getParentRegion()->getParentOp()->getDialect(); + Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); + Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, + value.getType(), value.getLoc()); + if (!constant) + return failure(); + + value.replaceAllUsesWith(constant); + return success(); +} + +static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context, + MutableArrayRef<Region> initialRegions) { + SmallVector<Block *> worklist; + auto addToWorklist = [&](MutableArrayRef<Region> regions) { + for (Region &region : regions) + for (Block &block : llvm::reverse(region)) + worklist.push_back(&block); + }; + + OpBuilder builder(context); + OperationFolder folder(context); + + addToWorklist(initialRegions); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + for (Operation &op : llvm::make_early_inc_range(*block)) { + builder.setInsertionPoint(&op); + + // Replace any result with constants.
+ bool replacedAll = op.getNumResults() != 0; + for (Value res : op.getResults()) + replacedAll &= + succeeded(replaceWithConstant(analysis, builder, folder, res)); + + // If all of the results of the operation were replaced, try to erase + // the operation completely. + if (replacedAll && wouldOpBeTriviallyDead(&op)) { + assert(op.use_empty() && "expected all uses to be replaced"); + op.erase(); + continue; + } + + // Add the regions of this operation to the worklist. + addToWorklist(op.getRegions()); + } + + // Replace any block arguments with constants. + builder.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + (void)replaceWithConstant(analysis, builder, folder, arg); + } +} + +namespace { +struct FoldInferredConstantsPass + : public FoldInferredConstantsBase<FoldInferredConstantsPass> { + void runOnOperation() override { + Operation *op = getOperation(); + IntRangeAnalysis analysis(op); + rewrite(analysis, op->getContext(), op->getRegions()); + } +}; +} // end anonymous namespace + +std::unique_ptr<Pass> mlir::createFoldInferredConstantsPass() { + return std::make_unique<FoldInferredConstantsPass>(); +} diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt -fold-inferred-constants %s | FileCheck %s + +// CHECK-LABEL: func @constant +// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index} +// CHECK: return %[[cst]] +func.func @constant() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + func.return %0 : index +} + +// CHECK-LABEL: func @increment +// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index} +// CHECK: return %[[cst]] +func.func @increment() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax =
0x7fffffffffffffff : index } + %1 = test.increment %0 + func.return %1 : index +} + +// CHECK-LABEL: func @maybe_increment +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + %1 = scf.if %arg0 -> index { + scf.yield %0 : index + } else { + %2 = test.increment %0 + scf.yield %2 : index + } + %3 = test.reflect_bounds %1 + func.return %3 : index +} + +// CHECK-LABEL: func @maybe_increment_br +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment_br(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + cf.cond_br %arg0, ^bb0, ^bb1 +^bb0: + %1 = test.increment %0 + cf.br ^bb2(%1 : index) +^bb1: + cf.br ^bb2(%0 : index) +^bb2(%2 : index): + %3 = test.reflect_bounds %2 + func.return %3 : index +} + +// CHECK-LABEL: func @for_bounds +// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} +func.func @for_bounds() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + scf.yield %arg0 : index + } + %1 = test.reflect_bounds %0 + func.return %1 : index +} + +// CHECK-LABEL: func @no_analysis_of_loop_variants +// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index} +func.func @no_analysis_of_loop_variants() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 
0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + %1 = test.increment %arg2 + scf.yield %1 : index + } + %2 = test.reflect_bounds %0 + func.return %2 : index +} + +// CHECK-LABEL: func @region_args +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @region_args() { + test.with_bounds_region { umin = 3 : index, umax = 4 : index, + smin = 3 : index, smax = 4 : index } %arg0 { + %0 = test.reflect_bounds %arg0 + } + func.return +} + +// CHECK-LABEL: func @func_args_unbound +// CHECK: test.reflect_bounds {smax = 9223372036854775807 : index, smin = -9223372036854775808 : index, umax = -1 : index, umin = 0 : index} +func.func @func_args_unbound(%arg0 : index) -> index { + %0 = test.reflect_bounds %arg0 + func.return %0 : index +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -62,6 +62,7 @@ MLIRFunc MLIRFuncTransforms MLIRIR + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRLinalg MLIRLinalgTransforms diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -33,6 +33,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DerivedAttributeOpInterface.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git 
a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,15 +14,21 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -1402,6 +1408,66 @@ return success(); } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// + +void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse the input argument + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + parser.parseArgument(argInfo); + + // Parse the body region, and reuse the operand info as the argument info. 
+ Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); +} + +void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict((*this)->getAttrs()); + p << ' '; + p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, + /*omitType=*/true); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +void TestWithBoundsRegionOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + Value arg = getRegion().getArgument(0); + setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +void TestIncrementOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + APInt one(range.umin().getBitWidth(), 1); + setResultRanges(getResult(), + {range.umin().uadd_sat(one), range.umax().uadd_sat(one), + range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); +} + +void TestReflectBoundsOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + MLIRContext *ctx = getContext(); + Builder b(ctx); + setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); + setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); + setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); + setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); + setResultRanges(getResult(), range); +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestOpStructs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" 
include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -789,7 +790,7 @@ def CustomResultsNameOp : TEST_Op<"custom_result_name", [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { - let arguments = (ins + let arguments = (ins Variadic<AnyInteger>:$optional, StrArrayAttr:$names ); @@ -2873,4 +2874,55 @@ }]; } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// +// Result range is taken verbatim from the umin/umax/smin/smax attributes. +def TestWithBoundsOp : TEST_Op<"with_bounds", + [DeclareOpInterfaceMethods<InferIntRangeInterface>, + NoSideEffect]> { + let arguments = (ins IndexAttr:$umin, + IndexAttr:$umax, + IndexAttr:$smin, + IndexAttr:$smax); + let results = (outs Index:$fakeVal); + + let assemblyFormat = "attr-dict"; +} + +// Like with_bounds, but the range is assigned to the region's block argument. +def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region", + [DeclareOpInterfaceMethods<InferIntRangeInterface>, + SingleBlock, NoTerminator]> { + let arguments = (ins IndexAttr:$umin, + IndexAttr:$umax, + IndexAttr:$smin, + IndexAttr:$smax); + // The region has one argument of index type + let regions = (region SizedRegion<1>:$region); + let hasCustomAssemblyFormat = 1; +} + +// Result range is the operand range with each bound incremented (saturating). +def TestIncrementOp : TEST_Op<"increment", + [DeclareOpInterfaceMethods<InferIntRangeInterface>, + NoSideEffect]> { + let arguments = (ins Index:$value); + let results = (outs Index:$result); + + let assemblyFormat = "attr-dict $value"; +} + +// Records the operand's inferred range in the op's attributes and forwards it. +def TestReflectBoundsOp : TEST_Op<"reflect_bounds", + [DeclareOpInterfaceMethods<InferIntRangeInterface>]> { + let arguments = (ins Index:$value, + OptionalAttr<IndexAttr>:$umin, + OptionalAttr<IndexAttr>:$umax, + OptionalAttr<IndexAttr>:$smin, + OptionalAttr<IndexAttr>:$smax); + let results = (outs Index:$result); + + let assemblyFormat = "attr-dict $value"; +} #endif // TEST_OPS diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt --- a/mlir/unittests/Interfaces/CMakeLists.txt +++ b/mlir/unittests/Interfaces/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRInterfacesTests ControlFlowInterfacesTest.cpp DataLayoutInterfacesTest.cpp +
InferIntRangeInterfaceTest.cpp InferTypeOpInterfaceTest.cpp ) @@ -10,6 +11,7 @@ MLIRDataLayoutInterfaces MLIRDLTI MLIRFunc + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRParser ) diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp @@ -0,0 +1,104 @@ +//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/APInt.h" +#include <limits> + +#include <gtest/gtest.h> + +using namespace mlir; + +TEST(IntRangeAttrs, BasicConstructors) { + APInt zero = APInt::getZero(64); + APInt two(64, 2); + APInt three(64, 3); + ConstantIntRanges boundedAbove(zero, two, zero, three); + EXPECT_EQ(boundedAbove.umin(), zero); + EXPECT_EQ(boundedAbove.umax(), two); + EXPECT_EQ(boundedAbove.smin(), zero); + EXPECT_EQ(boundedAbove.smax(), three); +} + +TEST(IntRangeAttrs, FromUnsigned) { + APInt zero = APInt::getZero(64); + APInt maxInt = APInt::getSignedMaxValue(64); + APInt minInt = APInt::getSignedMinValue(64); + APInt minIntPlusOne = minInt + 1; + + // [0, INT_MAX] never sets the sign bit, so the signed range is the same. + ConstantIntRanges canPortToSigned = + ConstantIntRanges::fromUnsigned(zero, maxInt); + EXPECT_EQ(canPortToSigned.smin(), zero); + EXPECT_EQ(canPortToSigned.smax(), maxInt); + + // [0, INT_MIN] crosses the sign bit, so the signed range is maximal. + ConstantIntRanges cantPortToSigned = + ConstantIntRanges::fromUnsigned(zero, minInt); + EXPECT_EQ(cantPortToSigned.smin(), minInt); + EXPECT_EQ(cantPortToSigned.smax(), maxInt); + + // Both bounds have the sign bit set, so the values are negative as signed. + ConstantIntRanges signedNegative = + ConstantIntRanges::fromUnsigned(minInt, minIntPlusOne); + EXPECT_EQ(signedNegative.smin(), minInt); +
EXPECT_EQ(signedNegative.smax(), minIntPlusOne); +} + +TEST(IntRangeAttrs, FromSigned) { + APInt zero = APInt::getZero(64); + APInt one = zero + 1; + APInt negOne = zero - 1; + APInt intMax = APInt::getSignedMaxValue(64); + APInt intMin = APInt::getSignedMinValue(64); + APInt uintMax = APInt::getMaxValue(64); + + // [-1, 1] wraps around zero as unsigned, so the unsigned range is maximal. + ConstantIntRanges noUnsignedBound = + ConstantIntRanges::fromSigned(negOne, one); + EXPECT_EQ(noUnsignedBound.umin(), zero); + EXPECT_EQ(noUnsignedBound.umax(), uintMax); + + ConstantIntRanges positive = ConstantIntRanges::fromSigned(one, intMax); + EXPECT_EQ(positive.umin(), one); + EXPECT_EQ(positive.umax(), intMax); + + ConstantIntRanges negative = ConstantIntRanges::fromSigned(intMin, negOne); + EXPECT_EQ(negative.umin(), intMin); + EXPECT_EQ(negative.umax(), negOne); + + ConstantIntRanges preserved = ConstantIntRanges::fromSigned(zero, one); + EXPECT_EQ(preserved.umin(), zero); + EXPECT_EQ(preserved.umax(), one); +} + +TEST(IntRangeAttrs, Join) { + APInt zero = APInt::getZero(64); + APInt one = zero + 1; + APInt two = zero + 2; + APInt intMin = APInt::getSignedMinValue(64); + APInt intMax = APInt::getSignedMaxValue(64); + APInt uintMax = APInt::getMaxValue(64); + + ConstantIntRanges maximal(zero, uintMax, intMin, intMax); + ConstantIntRanges zeroOne(zero, one, zero, one); + + EXPECT_EQ(ConstantIntRanges::join(zeroOne, maximal), maximal); + EXPECT_EQ(ConstantIntRanges::join(maximal, zeroOne), maximal); + + EXPECT_EQ(ConstantIntRanges::join(zeroOne, zeroOne), zeroOne); + + ConstantIntRanges oneTwo(one, two, one, two); + ConstantIntRanges zeroTwo(zero, two, zero, two); + EXPECT_EQ(ConstantIntRanges::join(zeroOne, oneTwo), zeroTwo); + + // Unsigned and signed bounds are joined independently of each other. + ConstantIntRanges zeroOneUnsigned(zero, one, intMin, intMax); + ConstantIntRanges zeroOneSigned(zero, uintMax, zero, one); + EXPECT_EQ(ConstantIntRanges::join(zeroOneUnsigned, zeroOneSigned), maximal); +}