diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h @@ -0,0 +1,44 @@ +//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the dataflow analysis class for integer range inference +// so that it can be used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H +#define MLIR_ANALYSIS_INTRANGEANALYSIS_H + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl; +} // end namespace detail + +class IntRangeAnalysis { +public: + IntRangeAnalysis(MLIRContext *ctx); + ~IntRangeAnalysis(); + IntRangeAnalysis(IntRangeAnalysis &&other); + + /// Analyze all operations rooted under (but not including) + /// `topLevelOperation`. + void run(Operation *topLevelOperation); + /// Get inferred range for value `v` if one exists. 
+ Optional getResult(Value v); + +private: + std::unique_ptr impl; +}; +} // end namespace mlir + +#endif diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) +add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -0,0 +1,106 @@ +//===- InferIntRangeInterface.h - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the integer range inference interface +// defined in `InferIntRange.td` +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE_H +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H + +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/APInt.h" +#include + +namespace mlir { +class Value; + +/// A collection of integer attributes representing minimum and maximum bounds +/// on a scalar integer value. These bounds are inclusive on both ends, so +/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for +/// the unsigned and signed interpretations of values. 
+class IntRangeAttrs { +public: + /// Return a range with no bounds for the signed and unsigned interpretations. + IntRangeAttrs() : uminVal(), umaxVal(), sminVal(), smaxVal() {} + /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax, + /// with llvm::None being considered +-infinity. + IntRangeAttrs(const Optional &umin, const Optional &umax, + const Optional &smin, const Optional &smax) + : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) { + // Unsigned integers that have upper bounds but no lower bound have a lower + // bound of 0. + if (umax && !umin) { + uminVal = APInt::getZero(umax->getBitWidth()); + } + } + + /// Convenience wrapper that allows providing [min, max] ranges. + IntRangeAttrs(const std::tuple, Optional> &urange, + const std::tuple, Optional> &srange) + : IntRangeAttrs(std::get<0>(urange), std::get<1>(urange), + std::get<0>(srange), std::get<1>(srange)) {} + + /// The minimum value of an integer when it is interpreted as unsigned. + const Optional &umin() const; + /// The maximum value of an integer when it is interpreted as unsigned. + const Optional &umax() const; + /// The minimum value of an integer when it is interpreted as signed; + const Optional &smin() const; + /// The maximum value of an integer when it is interpreted as signed. + const Optional &smax() const; + + /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned + /// minimum and `max` is both the signed and unsigned maximum. + static IntRangeAttrs range(const Optional &min, + const Optional &max); + + /// Create an `IntRangeAttrs` with the signed minimum and maximum equal + /// to `smin` and `smax`, where the unsigned bounds are constructed from the + /// signed ones if the signed bounds correspond to a contiguous range of + /// values within [0, int_max] or are left unset otherwise. 
+ static IntRangeAttrs fromSigned(const Optional &smin, + const Optional &smax); + static IntRangeAttrs + fromSigned(const std::tuple, Optional> &srange); + + /// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal + /// to `umin` and `umax` and the signed part equal to `umin` and `umax` + /// unless the sign bit changes between the minimum and maximum. + static IntRangeAttrs fromUnsigned(const Optional &umin, + const Optional &umax); + static IntRangeAttrs + fromUnsigned(const std::tuple, Optional> &urange); + + /// Returns the union (computed separately for signed and unsigned bounds) + /// of `a` and `b`. + static IntRangeAttrs join(const IntRangeAttrs &a, const IntRangeAttrs &b); + /// Helper methods for when IntRangeAttrs is used in a dataflow analysis. + static IntRangeAttrs getPessimisticValueState(MLIRContext *context); + static IntRangeAttrs getPessimisticValueState(Value v); + + /// If either the signed or unsigned interpretations of the range + /// indicate that the value it bounds is a constant, return that constant + /// value. + Optional getConstantValue() const; + + bool operator==(const IntRangeAttrs &other) const; + + friend raw_ostream &operator<<(raw_ostream &os, const IntRangeAttrs &range); + +private: + Optional uminVal, umaxVal, sminVal, smaxVal; +}; + +using SetIntRangeFn = llvm::function_ref; +} // end namespace mlir + +#include "mlir/Interfaces/InferIntRangeInterface.h.inc" + +#endif // MLIR_INTERFACES_INFERINTRANGEINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td @@ -0,0 +1,47 @@ +//===- InferIntRangeInterface.td - Integer Range Inference --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===-----------------------------------------------------===// +// +// Defines the interface for range analysis on scalar integers +// +//===-----------------------------------------------------===// + +#ifndef MLIR_INTERFACES_INFERINTRANGEINTERFACE +#define MLIR_INTERFACES_INFERINTRANGEINTERFACE + +include "mlir/IR/OpBase.td" + +def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> { + let description = [{ + Allows operations to participate in range analysis for scalar integer values by + providing a methods that allows them to specify lower and upper bounds on their + result(s) given lower and upper bounds on their input(s) if known. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Infer the bounds on the results of this op given the bounds on its arguments. + For each result value or block argument (that isn't a branch argument, + since the dataflow analysis handles those case), the method should call + `setValueRange` with that `Value` as an argument. When `setValueRange` + is not called for some value, it will recieve a default value of [-inf, +inf] + (the unbounded range). + + This function will only be called when at least one result of the op is a + scalar integer value or the op has a region. + + `argRanges` contains one `IntRangeAttrs` for each argument to the op in ODS + order. Non-integer arguments will have the unbounded range in their + `argRanges` element. + }], + "void", "inferResultRanges", (ins + "::llvm::ArrayRef<::mlir::IntRangeAttrs>":$argRanges, + "::mlir::SetIntRangeFn":$setResultRanges) + >]; +} +#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -51,6 +51,9 @@ /// Creates a pass to perform common sub expression elimination. 
std::unique_ptr createCSEPass(); +/// Create a pass to constant fold based on the results of range inferrence. +std::unique_ptr createFoldInferredConstantsPass(); + /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. std::unique_ptr createLoopInvariantCodeMotionPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -77,6 +77,11 @@ ]; } +def FoldInferredConstants : Pass<"fold-inferred-constants"> { + let summary = "Constant fold based on the results of integer range inference"; + let constructor = "mlir::createFoldInferredConstantsPass()"; +} + def Inliner : Pass<"inline"> { let summary = "Inline function calls"; let constructor = "mlir::createInlinerPass()"; diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -16,6 +17,7 @@ CallGraph.cpp DataFlowAnalysis.cpp DataLayoutAnalysis.cpp + IntRangeAnalysis.cpp Liveness.cpp SliceAnalysis.cpp @@ -31,6 +33,7 @@ MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces MLIRViewLikeInterface diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -359,11 +359,20 @@ if (auto branch = dyn_cast(op)) return visitRegionBranchOperation(branch, operandLattices); - // If we can't, conservatively mark all regions as executable. - // TODO: Let the `visitOperation` method decide how to propagate - // information to the block arguments. 
- for (Region ®ion : op->getRegions()) - markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/true); + for (Region ®ion : op->getRegions()) { + analysis.visitNonControlFlowArguments(op, RegionSuccessor(®ion), + operandLattices); + // `visitNonControlFlowArguments` is required to define all of the region + // argument lattices. + assert(llvm::none_of( + region.getArguments(), + [&](Value value) { + return analysis.getLatticeElement(value).isUninitialized(); + }) && + "expected `visitNonControlFlowArguments` to define all argument " + "lattices"); + markEntryBlockExecutable(®ion, /*markPessimisticFixpoint=*/false); + } } // If this op produces no results, it can't produce any constants. diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp @@ -0,0 +1,268 @@ +//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; + +namespace mlir { +namespace detail { +class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis { + using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + +public: + /// Define bounds on the results or block arguments of the operation + /// based on the bounds on the arguments given in `operands` + ChangeResult + visitOperation(Operation *op, + ArrayRef *> operands) final; + + /// Skip regions of branch ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do + /// so. + LogicalResult + getSuccessorsForOperands(BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) final; + + /// Skip regions of branch or loop ops when we can statically infer constant + /// values for operands to the branch op and said op tells us it's safe to do + /// so. 
+ void + getSuccessorsForOperands(RegionBranchOpInterface branch, + Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) final; + + /// Infer bounds on loop bounds + ChangeResult visitNonControlFlowArguments( + Operation *op, const RegionSuccessor ®ion, + ArrayRef *> operands) final; +}; +} // end namespace detail +} // end namespace mlir + +/// Given the results of getConstant{Lower,Upper}Bound() +/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for +/// that result if possible. +static Optional +getLoopBoundFromFold(Optional loopBound, + detail::IntRangeAnalysisImpl &analysis, bool getUpper) { + if (!loopBound.hasValue()) + return {}; + Optional result; + if (loopBound->is()) { + if (auto bound = + loopBound->get().dyn_cast_or_null()) + result = bound.getValue(); + } else if (loopBound->is()) { + LatticeElement *lattice = + analysis.lookupLatticeElement(loopBound->get()); + if (lattice != nullptr) + result = + getUpper ? lattice->getValue().smax() : lattice->getValue().smin(); + } + return result; +} + +ChangeResult detail::IntRangeAnalysisImpl::visitOperation( + Operation *op, ArrayRef *> operands) { + ChangeResult result = ChangeResult::NoChange; + // Ignore non-integer outputs - return early if the op has no scalar + // integer results + bool hasIntegerResult = false; + for (Value v : op->getResults()) { + if (v.getType().isIntOrIndex()) + hasIntegerResult = true; + else + result |= markAllPessimisticFixpoint(v); + } + if (!hasIntegerResult) + return result; + + if (auto inferrable = dyn_cast(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector argRanges( + llvm::map_range(operands, [](LatticeElement *val) { + return val->getValue(); + })); + + auto joinCallback = [&](Value v, const IntRangeAttrs &attrs) { + LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement &lattice = 
getLatticeElement(v); + Optional oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedResult && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + } else if (op->getNumRegions() == 0) { + // No regions + no result inference method -> unbounded results (ex. memory + // ops) + result |= markAllPessimisticFixpoint(op->getResults()); + } + return result; +} + +LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + BranchOpInterface branch, + ArrayRef *> operands, + SmallVectorImpl &successors) { + auto toConstantAttr = [&branch](auto enumPair) -> Attribute { + Optional maybeConstValue = + enumPair.value()->getValue().getConstantValue(); + + if (maybeConstValue) { + return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + *maybeConstValue); + } + return {}; + }; + SmallVector inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) { + successors.push_back(singleSucc); + return success(); + } + return failure(); +} + +void detail::IntRangeAnalysisImpl::getSuccessorsForOperands( + RegionBranchOpInterface branch, Optional sourceIndex, + ArrayRef *> operands, + SmallVectorImpl &successors) { + auto toConstantAttr = [&branch](auto enumPair) -> Attribute { + Optional maybeConstValue = + enumPair.value()->getValue().getConstantValue(); + + 
if (maybeConstValue) { + return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + *maybeConstValue); + } + return {}; + }; + SmallVector inferredConsts( + llvm::map_range(llvm::enumerate(operands), toConstantAttr)); + branch.getSuccessorRegions(sourceIndex, inferredConsts, successors); +} + +ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments( + Operation *op, const RegionSuccessor ®ion, + ArrayRef *> operands) { + if (auto inferrable = dyn_cast(op)) { + LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for "); + LLVM_DEBUG(inferrable->print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + SmallVector argRanges( + llvm::map_range(operands, [](LatticeElement *val) { + return val->getValue(); + })); + + ChangeResult result = ChangeResult::NoChange; + auto joinCallback = [&](Value v, const IntRangeAttrs &attrs) { + LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LatticeElement &lattice = getLatticeElement(v); + Optional oldRange; + if (!lattice.isUninitialized()) + oldRange = lattice.getValue(); + result |= lattice.join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). 
+ bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedValue && oldRange.hasValue() && + !(lattice.getValue() == *oldRange)) { + LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + result |= lattice.markPessimisticFixpoint(); + } + }; + + inferrable.inferResultRanges(argRanges, joinCallback); + return result; + } + + // Infer bounds for loop arguments that have static bounds + if (auto loop = dyn_cast(op)) { + Optional iv = loop.getSingleInductionVar(); + if (!iv.hasValue()) { + return ForwardDataFlowAnalysis< + IntRangeAttrs>::visitNonControlFlowArguments(op, region, operands); + } + Optional lowerBound = loop.getSingleLowerBound(); + Optional upperBound = loop.getSingleUpperBound(); + Optional step = loop.getSingleStep(); + Optional min = + getLoopBoundFromFold(lowerBound, *this, /*getUpper=*/false); + Optional max = + getLoopBoundFromFold(upperBound, *this, /*getUpper=*/true); + Optional stepVal = + getLoopBoundFromFold(step, *this, /*getUpper=*/false); + + if (stepVal && stepVal.getValue().isNegative()) { + std::swap(min, max); + } else if (max) { + // Correct the upper bound by subtracting 1 so that it becomes a <= bound, + // because loops do not generally include their upper bound. 
+ max = *max - 1; + } + + LatticeElement &ivEntry = getLatticeElement(*iv); + return ivEntry.join(IntRangeAttrs::fromSigned(min, max)); + } + return ForwardDataFlowAnalysis::visitNonControlFlowArguments( + op, region, operands); +} + +IntRangeAnalysis::IntRangeAnalysis(MLIRContext *ctx) { + impl = std::make_unique(ctx); +} + +IntRangeAnalysis::~IntRangeAnalysis() = default; +IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default; + +void IntRangeAnalysis::run(Operation *topLevelOperation) { + impl->run(topLevelOperation); +} + +Optional IntRangeAnalysis::getResult(Value v) { + LatticeElement *result = impl->lookupLatticeElement(v); + if (result == nullptr || result->isUninitialized()) + return llvm::None; + return result->getValue(); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp + InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp @@ -35,6 +36,7 @@ add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) +add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -0,0 +1,127 @@ +//===- InferIntRangeInterface.cpp - Integer range inference interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" + +using namespace mlir; + +const Optional &IntRangeAttrs::umin() const { return uminVal; } +const Optional &IntRangeAttrs::umax() const { return umaxVal; } +const Optional &IntRangeAttrs::smin() const { return sminVal; } +const Optional &IntRangeAttrs::smax() const { return smaxVal; } + +IntRangeAttrs IntRangeAttrs::range(const Optional &min, + const Optional &max) { + return {min, max, min, max}; +} + +IntRangeAttrs IntRangeAttrs::fromSigned(const Optional &smin, + const Optional &smax) { + Optional umin, umax; + if (smin && smax) { + if (smin->isNonNegative() == smax->isNonNegative()) { + umin = smin->ult(*smax) ? smin : smax; + umax = smin->ugt(*smax) ? smin : smax; + } + } else if (smin) { + if (smin.getValue().isNonNegative()) { + umin = smin; + umax = APInt::getSignedMaxValue(smin->getBitWidth()); + } + } else if (smax) { + if (smax.getValue().isNegative()) { + umin = APInt::getSignedMinValue(smax->getBitWidth()); + umax = smax; + } + } + return {umin, umax, smin, smax}; +} + +IntRangeAttrs IntRangeAttrs::fromSigned( + const std::tuple, Optional> &srange) { + return IntRangeAttrs::fromSigned(std::get<0>(srange), std::get<1>(srange)); +} + +IntRangeAttrs IntRangeAttrs::fromUnsigned(const Optional &umin, + const Optional &umax) { + Optional smin, smax; + // Note, unlike in fromSigned, we have no case for umin <= x < inf, + // since the upper bound is pontentially a negative number. + if (umin && umax) { + if (umin->isNonNegative() == umax->isNonNegative()) { + smin = umin->slt(*umax) ? umin : umax; + smax = umin->sgt(*umax) ? 
umin : umax; + } + } else if (umax) { + if (umax->isNonNegative()) { + smax = umax; + smin = APInt::getZero(umax->getBitWidth()); + } + } + return {umin, umax, smin, smax}; +} + +IntRangeAttrs IntRangeAttrs::fromUnsigned( + const std::tuple, Optional> &urange) { + return IntRangeAttrs::fromUnsigned(std::get<0>(urange), std::get<1>(urange)); +} + +IntRangeAttrs IntRangeAttrs::join(const IntRangeAttrs &a, + const IntRangeAttrs &b) { + Optional umin, umax, smin, smax; + if (a.umin() && b.umin()) + umin = a.umin()->ult(*b.umin()) ? a.umin() : b.umin(); + if (a.umax() && b.umax()) + umax = a.umax()->ugt(*b.umax()) ? a.umax() : b.umax(); + if (a.smin() && b.smin()) + smin = a.smin()->slt(*b.smin()) ? a.smin() : b.smin(); + if (a.smax() && b.smax()) + smax = a.smax()->sgt(*b.smax()) ? a.smax() : b.smax(); + + return {umin, umax, smin, smax}; +} + +IntRangeAttrs IntRangeAttrs::getPessimisticValueState(MLIRContext *context) { + return {}; +} + +IntRangeAttrs IntRangeAttrs::getPessimisticValueState(Value v) { + return getPessimisticValueState(v.getContext()); +} + +Optional IntRangeAttrs::getConstantValue() const { + if (umin() && umax() && umin() == umax()) + return umin(); + if (smin() && smax() && smin() == smax()) + return smin(); + return None; +} + +bool IntRangeAttrs::operator==(const IntRangeAttrs &other) const { + return umin() == other.umin() && umax() == other.umax() && + smin() == other.smin() && smax() == other.smax(); +} + +raw_ostream &mlir::operator<<(raw_ostream &os, const IntRangeAttrs &range) { + auto printValIfPresent = [&os](StringRef sign, const Optional &val) { + if (val) + os << val.getValue(); + else + os << sign << "inf"; + }; + os << "unsigned : ["; + printValIfPresent("-", range.umin()); + os << ", "; + printValIfPresent("+", range.umax()); + os << "] signed : ["; + printValIfPresent("-", range.smin()); + os << ", "; + printValIfPresent("+", range.smax()); + return os << "]"; +} diff --git a/mlir/lib/Transforms/CMakeLists.txt 
b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ Canonicalizer.cpp ControlFlowSink.cpp CSE.cpp + FoldInferredConstants.cpp Inliner.cpp LocationSnapshot.cpp LoopInvariantCodeMotion.cpp @@ -23,7 +24,9 @@ LINK_LIBS PUBLIC MLIRAnalysis MLIRCopyOpInterface + MLIRInferIntRangeInterface MLIRLoopLikeInterface + MLIRSideEffectInterfaces MLIRPass MLIRSupport MLIRTransformUtils diff --git a/mlir/lib/Transforms/FoldInferredConstants.cpp b/mlir/lib/Transforms/FoldInferredConstants.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/FoldInferredConstants.cpp @@ -0,0 +1,104 @@ +//===- FoldInferredConstants.cpp - Pass to materialize constants that can be +// inferred by range analysis --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/IntRangeAnalysis.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +/// Patterned after SCCP.cpp +static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis, + OpBuilder &b, OperationFolder &folder, + Value value) { + Optional maybeInferredRange = analysis.getResult(value); + if (!maybeInferredRange) + return failure(); + const IntRangeAttrs &inferredRange = maybeInferredRange.getValue(); + Optional maybeConstValue = inferredRange.getConstantValue(); + if (!maybeConstValue.hasValue()) + return failure(); + + Operation *maybeDefiningOp = value.getDefiningOp(); + Dialect *valueDialect = + maybeDefiningOp ? 
maybeDefiningOp->getDialect() + : value.getParentRegion()->getParentOp()->getDialect(); + Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); + Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr, + value.getType(), value.getLoc()); + if (!constant) + return failure(); + + value.replaceAllUsesWith(constant); + return success(); +} + +static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context, + MutableArrayRef<Region> initialRegions) { + SmallVector<Block *> worklist; + auto addToWorklist = [&](MutableArrayRef<Region> regions) { + for (Region &region : regions) + for (Block &block : llvm::reverse(region)) + worklist.push_back(&block); + }; + + OpBuilder builder(context); + OperationFolder folder(context); + + addToWorklist(initialRegions); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + for (Operation &op : llvm::make_early_inc_range(*block)) { + builder.setInsertionPoint(&op); + + // Try to replace each result of the operation with a constant. + bool replacedAll = op.getNumResults() != 0; + for (Value res : op.getResults()) + replacedAll &= + succeeded(replaceWithConstant(analysis, builder, folder, res)); + + // If all of the results of the operation were replaced, try to erase + // the operation completely. + if (replacedAll && wouldOpBeTriviallyDead(&op)) { + assert(op.use_empty() && "expected all uses to be replaced"); + op.erase(); + continue; + } + + // Add the regions of this operation to the worklist. + addToWorklist(op.getRegions()); + } + + // Replace any block arguments with constants.
+ builder.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + (void)replaceWithConstant(analysis, builder, folder, arg); + } +} + +namespace { +struct FoldInferredConstantsPass + : public FoldInferredConstantsBase<FoldInferredConstantsPass> { + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + + IntRangeAnalysis analysis(ctx); + analysis.run(op); + rewrite(analysis, ctx, op->getRegions()); + } +}; +} // end anonymous namespace + +std::unique_ptr<Pass> mlir::createFoldInferredConstantsPass() { + return std::make_unique<FoldInferredConstantsPass>(); +} diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -0,0 +1,102 @@ +// RUN: mlir-opt -fold-inferred-constants %s | FileCheck %s + +// CHECK-LABEL: func @constant +// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index} +// CHECK: return %[[cst]] +func.func @constant() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + func.return %0 : index +} + +// CHECK-LABEL: func @increment +// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index} +// CHECK: return %[[cst]] +func.func @increment() -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index } + %1 = test.increment %0 + func.return %1 : index +} + +// CHECK-LABEL: func @maybe_increment +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + %1 = scf.if %arg0 -> index { + scf.yield %0 : index + } else { + %2 = test.increment %0 + scf.yield %2 : index + } + %3 = test.reflect_bounds %1 + func.return %3 : index +} + +// CHECK-LABEL: 
func @maybe_increment_br +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @maybe_increment_br(%arg0 : i1) -> index { + %0 = test.with_bounds { umin = 3 : index, umax = 3 : index, + smin = 3 : index, smax = 3 : index} + cf.cond_br %arg0, ^bb0, ^bb1 +^bb0: + %1 = test.increment %0 + cf.br ^bb2(%1 : index) +^bb1: + cf.br ^bb2(%0 : index) +^bb2(%2 : index): + %3 = test.reflect_bounds %2 + func.return %3 : index +} + +// CHECK-LABEL: func @for_bounds +// CHECK: test.reflect_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} +func.func @for_bounds() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + scf.yield %arg0 : index + } + %1 = test.reflect_bounds %0 + func.return %1 : index +} + +// CHECK-LABEL: func @no_analysis_of_loop_variants +// CHECK: test.reflect_bounds %{{.*}} +func.func @no_analysis_of_loop_variants() -> index { + %c0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index} + %c1 = test.with_bounds { umin = 1 : index, umax = 1 : index, + smin = 1 : index, smax = 1 : index} + %c2 = test.with_bounds { umin = 2 : index, umax = 2 : index, + smin = 2 : index, smax = 2 : index} + + %0 = scf.for %arg0 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + %1 = test.increment %arg2 + scf.yield %1 : index + } + %2 = test.reflect_bounds %0 + func.return %2 : index +} + +// CHECK-LABEL: func @region_args +// CHECK: test.reflect_bounds {smax = 4 : index, smin = 3 : index, umax = 4 : index, umin = 3 : index} +func.func @region_args() { + test.with_bounds_region { umin = 3 : index, umax = 4 : 
index, + smin = 3 : index, smax = 4 : index } %arg0 { + %0 = test.reflect_bounds %arg0 + } + func.return +} + +// CHECK-LABEL: func @func_args_unbound +// CHECK: test.reflect_bounds %{{.*}} +func.func @func_args_unbound(%arg0 : index) -> index { + %0 = test.reflect_bounds %arg0 + func.return %0 : index +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -62,6 +62,7 @@ MLIRFunc MLIRFuncTransforms MLIRIR + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRLinalg MLIRLinalgTransforms diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -33,6 +33,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DerivedAttributeOpInterface.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,15 +14,21 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include 
"mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -1402,6 +1408,81 @@ return success(); } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// + +void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse the input argument + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + parser.parseArgument(argInfo); + + // Parse the body region, and reuse the operand info as the argument info. 
+ Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); +} + +void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict((*this)->getAttrs()); + p << ' '; + p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, + /*omitType=*/true); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +void TestWithBoundsRegionOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + Value arg = getRegion().getArgument(0); + setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +void TestIncrementOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + const IntRangeAttrs &range = argRanges[0]; + auto inc = [](const APInt &value) -> APInt { return value + 1; }; + setResultRanges(getResult(), {range.umin().map(inc), range.umax().map(inc), + range.smin().map(inc), range.smax().map(inc)}); +} + +void TestReflectBoundsOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + const IntRangeAttrs &range = argRanges[0]; + MLIRContext *ctx = getContext(); + Builder b(ctx); + if (range.umin().hasValue()) + setUminAttr(b.getIndexAttr(range.umin()->getZExtValue())); + else + removeUminAttr(); + + if (range.umax().hasValue()) + setUmaxAttr(b.getIndexAttr(range.umax()->getZExtValue())); + else + removeUmaxAttr(); + + if (range.smin().hasValue()) + setSminAttr(b.getIndexAttr(range.smin()->getSExtValue())); + else + removeSminAttr(); + + if (range.smax().hasValue()) + setSmaxAttr(b.getIndexAttr(range.smax()->getSExtValue())); + else + removeSmaxAttr(); + + setResultRanges(getResult(), range); +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestOpStructs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ include 
"mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -789,7 +790,7 @@ def CustomResultsNameOp : TEST_Op<"custom_result_name", [DeclareOpInterfaceMethods]> { - let arguments = (ins + let arguments = (ins Variadic:$optional, StrArrayAttr:$names ); @@ -2873,4 +2874,57 @@ }]; } +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// +def TestWithBoundsOp : TEST_Op<"with_bounds", + [DeclareOpInterfaceMethods, + NoSideEffect]> { + let arguments = (ins OptionalAttr:$umin, + OptionalAttr:$umax, + OptionalAttr:$smin, + OptionalAttr:$smax); + let results = (outs Index:$fakeVal); + + let assemblyFormat = [{ + attr-dict + }]; +} + +def TestWithBoundsRegionOp : TEST_Op<"with_bounds_region", + [DeclareOpInterfaceMethods, + SingleBlock, NoTerminator]> { + let arguments = (ins OptionalAttr:$umin, + OptionalAttr:$umax, + OptionalAttr:$smin, + OptionalAttr:$smax); + // The region has one argument of index type + let regions = (region SizedRegion<1>:$region); + let hasCustomAssemblyFormat = 1; +} + +def TestIncrementOp : TEST_Op<"increment", + [DeclareOpInterfaceMethods, + NoSideEffect]> { + let arguments = (ins Index:$value); + let results = (outs Index:$result); + + let assemblyFormat = [{ + attr-dict $value + }]; +} + +def TestReflectBoundsOp : TEST_Op<"reflect_bounds", + [DeclareOpInterfaceMethods]> { + let arguments = (ins Index:$value, + OptionalAttr:$umin, + OptionalAttr:$umax, + OptionalAttr:$smin, + OptionalAttr:$smax); + let results = (outs Index:$result); + + let assemblyFormat = [{ + attr-dict $value + }]; +} #endif // TEST_OPS diff --git 
a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt --- a/mlir/unittests/Interfaces/CMakeLists.txt +++ b/mlir/unittests/Interfaces/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRInterfacesTests ControlFlowInterfacesTest.cpp DataLayoutInterfacesTest.cpp + InferIntRangeInterfaceTest.cpp InferTypeOpInterfaceTest.cpp ) @@ -10,6 +11,7 @@ MLIRDataLayoutInterfaces MLIRDLTI MLIRFunc + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRParser ) diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp @@ -0,0 +1,100 @@ +//===- InferIntRangeInterfaceTest.cpp - Unit Tests for InferIntRange... --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/APInt.h" +#include <limits> + +#include <gtest/gtest.h> + +using namespace mlir; + +TEST(IntRangeAttrs, BasicConstructors) { + IntRangeAttrs blank; + EXPECT_FALSE(blank.umin().hasValue()); + EXPECT_FALSE(blank.umax().hasValue()); + EXPECT_FALSE(blank.smin().hasValue()); + EXPECT_FALSE(blank.smax().hasValue()); + + llvm::APInt zero = llvm::APInt::getZero(64); + llvm::APInt two(64, 2); + llvm::APInt three(64, 3); + IntRangeAttrs boundedAbove({}, two, {}, three); + EXPECT_EQ(boundedAbove.umin(), zero); + EXPECT_EQ(boundedAbove.umax(), two); + EXPECT_FALSE(boundedAbove.smin().hasValue()); + EXPECT_EQ(boundedAbove.smax(), three); +} + +TEST(IntRangeAttrs, FromUnsigned) { + llvm::APInt zero = llvm::APInt::getZero(64); + llvm::APInt maxInt = llvm::APInt::getSignedMaxValue(64); + llvm::APInt minInt = llvm::APInt::getSignedMinValue(64); + 
llvm::APInt minIntPlusOne = minInt + 1; + + IntRangeAttrs canPortToSigned = IntRangeAttrs::fromUnsigned(zero, maxInt); + EXPECT_EQ(canPortToSigned.smin(), zero); + EXPECT_EQ(canPortToSigned.smax(), maxInt); + + IntRangeAttrs cantPortToSigned = IntRangeAttrs::fromUnsigned(zero, minInt); + EXPECT_FALSE(cantPortToSigned.smin().hasValue()); + EXPECT_FALSE(cantPortToSigned.smax().hasValue()); + + IntRangeAttrs signedNegative = + IntRangeAttrs::fromUnsigned(minInt, minIntPlusOne); + EXPECT_EQ(signedNegative.smin(), minInt); + EXPECT_EQ(signedNegative.smax(), minIntPlusOne); +} + +TEST(IntRangeAttrs, FromSigned) { + llvm::APInt zero = llvm::APInt::getZero(64); + llvm::APInt one = zero + 1; + llvm::APInt negOne = zero - 1; + llvm::APInt intMax = llvm::APInt::getSignedMaxValue(64); + llvm::APInt intMin = llvm::APInt::getSignedMinValue(64); + + IntRangeAttrs noUnsignedBound = IntRangeAttrs::fromSigned(negOne, one); + EXPECT_FALSE(noUnsignedBound.umin().hasValue()); + EXPECT_FALSE(noUnsignedBound.umax().hasValue()); + + IntRangeAttrs positive = IntRangeAttrs::fromSigned(one, {}); + EXPECT_EQ(positive.umin(), one); + EXPECT_EQ(positive.umax(), intMax); + + IntRangeAttrs negative = IntRangeAttrs::fromSigned({}, negOne); + EXPECT_EQ(negative.umin(), intMin); + EXPECT_EQ(negative.umax(), negOne); + + IntRangeAttrs preserved = IntRangeAttrs::fromSigned(zero, one); + EXPECT_EQ(preserved.umin(), zero); + EXPECT_EQ(preserved.umax(), one); +} + +TEST(IntRangeAttrs, Join) { + llvm::APInt zero = llvm::APInt::getZero(64); + llvm::APInt one = zero + 1; + llvm::APInt two = zero + 2; + + IntRangeAttrs blank; + IntRangeAttrs zeroOne(zero, one, zero, one); + + // Joining with a range that has no bounds yields a range with no bounds. + EXPECT_EQ(IntRangeAttrs::join(zeroOne, blank), blank); + EXPECT_EQ(IntRangeAttrs::join(blank, zeroOne), blank); + + EXPECT_EQ(IntRangeAttrs::join(zeroOne, zeroOne), zeroOne); + + IntRangeAttrs oneTwo(one, two, one, two); + IntRangeAttrs zeroTwo(zero, two, zero, two); + 
EXPECT_EQ(IntRangeAttrs::join(zeroOne, oneTwo), zeroTwo); + + // Bounds known in only one of the joined ranges are dropped from the result. + 
IntRangeAttrs zeroOneUnsigned(zero, one, {}, {}); + IntRangeAttrs zeroOneSigned({}, {}, zero, one); + EXPECT_EQ(IntRangeAttrs::join(zeroOneUnsigned, zeroOneSigned), blank); +}