diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -13,7 +13,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -12,7 +12,6 @@ include "mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/Dialect/Index/IR/IndexEnums.td" include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -24,8 +23,7 @@ /// Base class for Index dialect operations. class IndexOp traits = []> - : Op] # traits>; + : Op; //===----------------------------------------------------------------------===// // IndexBinaryOp diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h deleted file mode 100644 --- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h +++ /dev/null @@ -1,126 +0,0 @@ -//===- InferIntRangeCommon.cpp - Inference for common ops --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares implementations of range inference for operations that are -// common to both the `arith` and `index` dialects to facilitate reuse. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H -#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H - -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "llvm/ADT/ArrayRef.h" - -namespace mlir { -namespace intrange { -/// Function that performs inference on an array of `ConstantIntRanges`, -/// abstracted away here to permit writing the function that handles both -/// 64- and 32-bit index types. -using InferRangeFn = - function_ref)>; - -static constexpr unsigned indexMinWidth = 32; -static constexpr unsigned indexMaxWidth = 64; - -enum class CmpMode : uint32_t { Both, Signed, Unsigned }; - -/// Compute `inferFn` on `ranges`, whose size should be the index storage -/// bitwidth. Then, compute the function on `argRanges` again after truncating -/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is -/// equal to the 32-bit result, use it (to preserve compatibility with folders -/// and inference precision), and take the union of the results otherwise. -/// -/// The `mode` argument specifies if the unsigned, signed, or both results of -/// the inference computation should be used when comparing the results. -ConstantIntRanges inferIndexOp(InferRangeFn inferFn, - ArrayRef argRanges, - CmpMode mode); - -/// Independently zero-extend the unsigned values and sign-extend the signed -/// values in `range` to `destWidth` bits, returning the resulting range. -ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); - -/// Use the unsigned values in `range` to zero-extend it to `destWidth`. -ConstantIntRanges extUIRange(const ConstantIntRanges &range, - unsigned destWidth); - -/// Use the signed values in `range` to sign-extend it to `destWidth`. -ConstantIntRanges extSIRange(const ConstantIntRanges &range, - unsigned destWidth); - -/// Truncate `range` to `destWidth` bits, taking care to handle cases such as -/// the truncation of [255, 256] to i8 not being a uniform range. -ConstantIntRanges truncRange(const ConstantIntRanges &range, - unsigned destWidth); - -ConstantIntRanges inferAdd(ArrayRef argRanges); - -ConstantIntRanges inferSub(ArrayRef argRanges); - -ConstantIntRanges inferMul(ArrayRef argRanges); - -ConstantIntRanges inferDivS(ArrayRef argRanges); - -ConstantIntRanges inferDivU(ArrayRef argRanges); - -ConstantIntRanges inferCeilDivS(ArrayRef argRanges); - -ConstantIntRanges inferCeilDivU(ArrayRef argRanges); - -ConstantIntRanges inferFloorDivS(ArrayRef argRanges); - -ConstantIntRanges inferRemS(ArrayRef argRanges); - -ConstantIntRanges inferRemU(ArrayRef argRanges); - -ConstantIntRanges inferMaxS(ArrayRef argRanges); - -ConstantIntRanges inferMaxU(ArrayRef argRanges); - -ConstantIntRanges inferMinS(ArrayRef argRanges); - -ConstantIntRanges inferMinU(ArrayRef argRanges); - -ConstantIntRanges inferAnd(ArrayRef argRanges); - -ConstantIntRanges inferOr(ArrayRef argRanges); - -ConstantIntRanges inferXor(ArrayRef argRanges); - -ConstantIntRanges inferShl(ArrayRef argRanges); - -ConstantIntRanges inferShrS(ArrayRef argRanges); - -ConstantIntRanges inferShrU(ArrayRef argRanges); - -/// Copy of the enum from `arith` and `index` to allow the common integer range -/// infrastructure to not depend on either dialect. -enum class CmpPredicate : uint64_t { - eq, - ne, - slt, - sle, - sgt, - sge, - ult, - ule, - ugt, - uge, -}; - -/// Returns a boolean value if `pred` is statically true or false for -/// anypossible inputs falling within `lhs` and `rhs`, and std::nullopt if the -/// value of the predicate cannot be determined. -Optional evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs); - -} // namespace intrange -} // namespace mlir - -#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -16,7 +16,6 @@ LINK_LIBS PUBLIC MLIRDialect - MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include @@ -17,7 +16,48 @@ using namespace mlir; using namespace mlir::arith; -using namespace mlir::intrange; + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns std::nullopt on overflow. +using ConstArithFn = + function_ref(const APInt &, const APInt &)>; + +/// Return the maxmially wide signed or unsigned range for a given bitwidth. + +/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, +/// If either computation overflows, make the result unbounded. +static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, + const APInt &minRight, + const APInt &maxLeft, + const APInt &maxRight, bool isSigned) { + std::optional maybeMin = op(minLeft, minRight); + std::optional maybeMax = op(maxLeft, maxRight); + if (maybeMin && maybeMax) + return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); + return ConstantIntRanges::maxRange(minLeft.getBitWidth()); +} + +/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +/// ignoring unbounded values. Returns the maximal range if `op` overflows. +static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool isSigned) { + unsigned width = lhs[0].getBitWidth(); + APInt min = + isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); + APInt max = + isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); + for (const APInt &left : lhs) { + for (const APInt &right : rhs) { + std::optional maybeThisResult = op(left, right); + if (!maybeThisResult) + return ConstantIntRanges::maxRange(width); + APInt result = std::move(*maybeThisResult); + min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; + max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; + } + } + return ConstantIntRanges::range(min, max, isSigned); +} //===----------------------------------------------------------------------===// // ConstantOp @@ -38,7 +78,25 @@ void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferAdd(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn uadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.uadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn sadd = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = computeBoundsBy( + uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -47,7 +105,25 @@ void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferSub(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn usub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.usub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn ssub = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstantIntRanges urange = computeBoundsBy( + usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); + ConstantIntRanges srange = computeBoundsBy( + ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -56,25 +132,96 @@ void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMul(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn umul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.umul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + ConstArithFn smul = [](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.smul_ov(b, overflowed); + return overflowed ? std::optional() : result; + }; + + ConstantIntRanges urange = + minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, + /*isSigned=*/true); + + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// +/// Fix up division results (ex. for ceiling and floor), returning an APInt +/// if there has been no overflow +using DivisionFixupFn = function_ref( + const APInt &lhs, const APInt &rhs, const APInt &result)>; + +static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), + &rhsMax = rhs.umax(); + + if (!rhsMin.isZero()) { + auto udiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + return fixup(a, b, a.udiv(b)); + }; + return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/false); + } + // Otherwise, it's possible we might divide by 0. + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferDivU(argRanges)); + setResultRange(getResult(), + inferDivUIRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); + + if (canDivide) { + auto sdiv = [&fixup](const APInt &a, + const APInt &b) -> std::optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + return overflowed ? std::optional() : fixup(a, b, result); + }; + return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*isSigned=*/true); + } + return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); +} + void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferDivS(argRanges)); + setResultRange(getResult(), + inferDivSIRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// @@ -83,7 +230,20 @@ void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferCeilDivU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); } //===----------------------------------------------------------------------===// @@ -92,7 +252,20 @@ void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferCeilDivS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); } //===----------------------------------------------------------------------===// @@ -101,7 +274,20 @@ void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - return setResultRange(getResult(), inferFloorDivS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn floorDivSIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? std::optional() : corrected; + } + return result; + }; + setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); } //===----------------------------------------------------------------------===// @@ -110,7 +296,29 @@ void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferRemU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); + + unsigned width = rhsMin.getBitWidth(); + APInt umin = APInt::getZero(width); + APInt umax = APInt::getMaxValue(width); + + if (!rhsMin.isZero()) { + umax = rhsMax - 1; + // Special case: sweeping out a contiguous range in N/[modulus] + if (rhsMin == rhsMax) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); + if ((lhsMax - lhsMin).ult(rhsMax)) { + APInt minRem = lhsMin.urem(rhsMax); + APInt maxRem = lhsMax.urem(rhsMax); + if (minRem.ule(maxRem)) { + umin = minRem; + umax = maxRem; + } + } + } + } + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// @@ -119,16 +327,67 @@ void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferRemS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), + &rhsMax = rhs.smax(); + + unsigned width = rhsMax.getBitWidth(); + APInt smin = APInt::getSignedMinValue(width); + APInt smax = APInt::getSignedMaxValue(width); + // No bounds if zero could be a divisor. + bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); + bool canNegativeDividend = lhsMin.isNegative(); + bool canPositiveDividend = lhsMax.isStrictlyPositive(); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + smin = canNegativeDividend ? minNegativeResult : zero; + smax = canPositiveDividend ? maxPositiveResult : zero; + // Special case: sweeping out a contiguous range in N/[modulus]. + if (rhsMin == rhsMax) { + if ((lhsMax - lhsMin).ult(maxDivisor)) { + APInt minRem = lhsMin.srem(maxDivisor); + APInt maxRem = lhsMax.srem(maxDivisor); + if (minRem.sle(maxRem)) { + smin = minRem; + smax = maxRem; + } + } + } + } + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// +/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bonuds have in common. This gives us a consertive approximation +/// for what values can be passed to bitwise operations. +static std::tuple +widenBitwiseBounds(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + return std::make_tuple(std::move(leftVal), std::move(rightVal)); +} + void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferAnd(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto andi = [](const APInt &a, const APInt &b) -> std::optional { + return a & b; + }; + setResultRange(getResult(), + minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -137,7 +396,14 @@ void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferOr(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto ori = [](const APInt &a, const APInt &b) -> std::optional { + return a | b; + }; + setResultRange(getResult(), + minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -146,7 +412,14 @@ void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferXor(argRanges)); + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto xori = [](const APInt &a, const APInt &b) -> std::optional { + return a ^ b; + }; + setResultRange(getResult(), + minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -155,7 +428,11 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMaxS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// @@ -164,7 +441,11 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMaxU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// @@ -173,7 +454,11 @@ void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMinS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); + const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); + setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// @@ -182,40 +467,94 @@ void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferMinU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); + const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges extUIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), extUIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges extSIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return ConstantIntRanges::fromSigned(smin, smax); +} + void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), extSIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// +static ConstantIntRanges truncIRange(const ConstantIntRanges &range, + Type destType) { + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], + // the range of the resulting value is not contiguous ind includes 0. + // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], + // but you can't truncate [255, 257] similarly. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed post-truncation rollover will not occur when either: + // - The high parts of the min and max, plus the sign bit, are the same + // - The high halves + sign bit of the min and max are either all 1s or all 0s + // and you won't create a [positive, negative] range by truncating. + // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 + // but not [255, 257]_i16 to a range of i8s. You can also truncate + // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. + // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) + // will truncate to 0x7e, which is greater than 0 + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); + return {umin, umax, smin, smax}; +} + void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - unsigned destWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + Type destType = getResult().getType(); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// @@ -230,9 +569,9 @@ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); + setResultRange(getResult(), extSIRange(argRanges[0], destType)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); else setResultRange(getResult(), argRanges[0]); } @@ -249,9 +588,9 @@ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); + setResultRange(getResult(), extUIRange(argRanges[0], destType)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncRange(argRanges[0], destWidth)); + setResultRange(getResult(), truncIRange(argRanges[0], destType)); else setResultRange(getResult(), argRanges[0]); } @@ -260,19 +599,51 @@ // CmpIOp //===----------------------------------------------------------------------===// +bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + switch (pred) { + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: + return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: + return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); + case arith::CmpIPredicate::eq: { + std::optional lhsConst = lhs.getConstantValue(); + std::optional rhsConst = rhs.getConstantValue(); + return lhsConst && rhsConst && lhsConst == rhsConst; + } + case arith::CmpIPredicate::ne: { + // While equality requires that there is an interpration of the preceeding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. + bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} + void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - arith::CmpIPredicate arithPred = getPredicate(); - intrange::CmpPredicate pred = static_cast(arithPred); + arith::CmpIPredicate pred = getPredicate(); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); - - Optional truthValue = intrange::evaluatePred(pred, lhs, rhs); - if (truthValue.has_value() && *truthValue) + if (isStaticallyTrue(pred, lhs, rhs)) min = max; - else if (truthValue.has_value() && !(*truthValue)) + else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); @@ -302,7 +673,18 @@ void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShl(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + ConstArithFn shl = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); + }; + ConstantIntRanges urange = + minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); + ConstantIntRanges srange = + minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); + setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// @@ -311,7 +693,15 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShrU(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn lshr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); + }; + setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, + {rhs.umin(), rhs.umax()}, + /*isSigned=*/false)); } //===----------------------------------------------------------------------===// @@ -320,5 +710,14 @@ void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferShrS(argRanges)); + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + ConstArithFn ashr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); + }; + + setResultRange(getResult(), + minMaxBy(ashr, {lhs.smin(), lhs.smax()}, + {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); } diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt @@ -2,7 +2,6 @@ IndexAttrs.cpp IndexDialect.cpp IndexOps.cpp - InferIntRangeInterfaceImpls.cpp DEPENDS MLIRIndexOpsIncGen @@ -11,7 +10,6 @@ MLIRDialect MLIRIR MLIRCastInterfaces - MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp +++ /dev/null @@ -1,252 +0,0 @@ -//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Interfaces/InferIntRangeInterface.h" -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "int-range-analysis" - -using namespace mlir; -using namespace mlir::index; -using namespace mlir::intrange; - -//===----------------------------------------------------------------------===// -// Constants -//===----------------------------------------------------------------------===// - -void ConstantOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - const APInt &value = getValue(); - setResultRange(getResult(), ConstantIntRanges::constant(value)); -} - -void BoolConstantOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - bool value = getValue(); - APInt asInt(/*numBits=*/1, value); - setResultRange(getResult(), ConstantIntRanges::constant(asInt)); -} - -//===----------------------------------------------------------------------===// -// Arithmec operations. All of these operations will have their results inferred -// using both the 64-bit values and truncated 32-bit values of their inputs, -// with the results being the union of those inferences, except where the -// truncation of the 64-bit result is equal to the 32-bit result (at which time -// we take the 64-bit result). -//===----------------------------------------------------------------------===// - -void AddOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); -} - -void SubOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); -} - -void MulOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); -} - -void DivUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); -} - -void DivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); -} - -void CeilDivUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); -} - -void CeilDivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); -} - -void FloorDivSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - return setResultRange( - getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); -} - -void RemSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); -} - -void RemUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); -} - -void MaxSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); -} - -void MaxUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); -} - -void MinSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); -} - -void MinUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); -} - -void ShlOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); -} - -void ShrSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); -} - -void ShrUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); -} - -void AndOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); -} - -void OrOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); -} - -void XOrOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); -} - -//===----------------------------------------------------------------------===// -// Casts -//===----------------------------------------------------------------------===// - -static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, - unsigned srcWidth, unsigned destWidth, - bool isSigned) { - if (srcWidth < destWidth) - return isSigned ? extSIRange(range, destWidth) - : extUIRange(range, destWidth); - if (srcWidth > destWidth) - return truncRange(range, destWidth); - return range; -} - -// When casting to `index`, we will take the union of the possible fixed-width -// casts. -static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, - Type sourceType, Type destType, - bool isSigned) { - unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - if (sourceType.isIndex()) - return makeLikeDest(range, srcWidth, destWidth, isSigned); - // We are casting to indexs, so use the union of the 32-bit and 64-bit casts - ConstantIntRanges storageRange = - makeLikeDest(range, srcWidth, destWidth, isSigned); - ConstantIntRanges minWidthRange = - makeLikeDest(range, srcWidth, indexMinWidth, isSigned); - ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); - ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); - return ret; -} - -void CastSOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, - /*isSigned=*/true)); -} - -void CastUOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, - /*isSigned=*/false)); -} - -//===----------------------------------------------------------------------===// -// CmpOp -//===----------------------------------------------------------------------===// - -void CmpOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - index::IndexCmpPredicate indexPred = getPred(); - intrange::CmpPredicate pred = static_cast(indexPred); - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - APInt min = APInt::getZero(1); - APInt max = APInt::getAllOnesValue(1); - - Optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); - - ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), - rhsTrunc = truncRange(rhs, indexMinWidth); - Optional truthValue32 = - intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); - - if (truthValue64 == truthValue32) { - if (truthValue64.has_value() && *truthValue64) - min = max; - else if (truthValue64.has_value() && !(*truthValue64)) - max = min; - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); -} - -//===----------------------------------------------------------------------===// -// SizeOf, which is bounded between the two supported bitwidth (32 and 64). -//===----------------------------------------------------------------------===// - -void SizeOfOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - unsigned storageWidth = - ConstantIntRanges::getStorageBitwidth(getResult().getType()); - APInt min(/*numBits=*/storageWidth, indexMinWidth); - APInt max(/*numBits=*/storageWidth, indexMaxWidth); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); -} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -51,5 +51,3 @@ add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) - -add_subdirectory(Utils) diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Interfaces/Utils/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_mlir_library(MLIRInferIntRangeCommon - InferIntRangeCommon.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils - - DEPENDS - MLIRInferIntRangeInterfaceIncGen - - LINK_LIBS PUBLIC - MLIRInferIntRangeInterface - MLIRIR -) diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp deleted file mode 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ /dev/null @@ -1,663 +0,0 @@ -//===- InferIntRangeCommon.cpp - Inference for common ops ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains implementations of range inference for operations that are -// common to both the `arith` and `index` dialects to facilitate reuse. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" - -#include "mlir/Interfaces/InferIntRangeInterface.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" - -#include "llvm/Support/Debug.h" - -#include -#include - -using namespace mlir; - -#define DEBUG_TYPE "int-range-analysis" - -//===----------------------------------------------------------------------===// -// General utilities -//===----------------------------------------------------------------------===// - -/// Function that evaluates the result of doing something on arithmetic -/// constants and returns std::nullopt on overflow. -using ConstArithFn = - function_ref(const APInt &, const APInt &)>; - -/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, -/// If either computation overflows, make the result unbounded. -static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, - const APInt &minRight, - const APInt &maxLeft, - const APInt &maxRight, bool isSigned) { - std::optional maybeMin = op(minLeft, minRight); - std::optional maybeMax = op(maxLeft, maxRight); - if (maybeMin && maybeMax) - return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); - return ConstantIntRanges::maxRange(minLeft.getBitWidth()); -} - -/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, -/// ignoring unbounded values. Returns the maximal range if `op` overflows. -static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, - ArrayRef rhs, bool isSigned) { - unsigned width = lhs[0].getBitWidth(); - APInt min = - isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); - APInt max = - isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); - for (const APInt &left : lhs) { - for (const APInt &right : rhs) { - std::optional maybeThisResult = op(left, right); - if (!maybeThisResult) - return ConstantIntRanges::maxRange(width); - APInt result = std::move(*maybeThisResult); - min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; - max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; - } - } - return ConstantIntRanges::range(min, max, isSigned); -} - -//===----------------------------------------------------------------------===// -// Ext, trunc, index op handling -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferIndexOp(InferRangeFn inferFn, - ArrayRef argRanges, - intrange::CmpMode mode) { - ConstantIntRanges sixtyFour = inferFn(argRanges); - SmallVector truncated; - llvm::transform(argRanges, std::back_inserter(truncated), - [](const ConstantIntRanges &range) { - return truncRange(range, /*destWidth=*/indexMinWidth); - }); - ConstantIntRanges thirtyTwo = inferFn(truncated); - ConstantIntRanges thirtyTwoAsSixtyFour = - extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); - ConstantIntRanges sixtyFourAsThirtyTwo = - truncRange(sixtyFour, /*destWidth=*/indexMinWidth); - - LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour - << " 32-bit = " << thirtyTwo << "\n"); - bool truncEqual = false; - switch (mode) { - case intrange::CmpMode::Both: - truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); - break; - case intrange::CmpMode::Signed: - truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && - thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); - break; - case intrange::CmpMode::Unsigned: - truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && - thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); - break; - } - if (truncEqual) - // Returing the 64-bit result preserves more information. - return sixtyFour; - ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); - return merged; -} - -ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, - unsigned int destWidth) { - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return {umin, umax, smin, smax}; -} - -ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, - unsigned destWidth) { - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, - unsigned destWidth) { - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, - unsigned int destWidth) { - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. - // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. - // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - -//===----------------------------------------------------------------------===// -// Addition -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferAdd(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn uadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.uadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn sadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = computeBoundsBy( - uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// Subtraction -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferSub(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn usub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.usub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn ssub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.ssub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstantIntRanges urange = computeBoundsBy( - usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// Multiplication -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferMul(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn umul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.umul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn smul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.smul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = - minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, - /*isSigned=*/true); - return urange.intersection(srange); -} - -//===----------------------------------------------------------------------===// -// DivU, CeilDivU (Unsigned division) -//===----------------------------------------------------------------------===// - -/// Fix up division results (ex. for ceiling and floor), returning an APInt -/// if there has been no overflow -using DivisionFixupFn = function_ref( - const APInt &lhs, const APInt &rhs, const APInt &result)>; - -static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), - &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { - auto udiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - return fixup(a, b, a.udiv(b)); - }; - return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/false); - } - // Otherwise, it's possible we might divide by 0. - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - -ConstantIntRanges -mlir::intrange::inferDivU(ArrayRef argRanges) { - return inferDivURange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; }); -} - -ConstantIntRanges -mlir::intrange::inferCeilDivU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - return inferDivURange(lhs, rhs, ceilDivUIFix); -} - -//===----------------------------------------------------------------------===// -// DivS, CeilDivS, FloorDivS (Signed division) -//===----------------------------------------------------------------------===// - -static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); - - if (canDivide) { - auto sdiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sdiv_ov(b, overflowed); - return overflowed ? std::optional() : fixup(a, b, result); - }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); - } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - -ConstantIntRanges -mlir::intrange::inferDivS(ArrayRef argRanges) { - return inferDivSRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; }); -} - -ConstantIntRanges -mlir::intrange::inferCeilDivS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - return inferDivSRange(lhs, rhs, ceilDivSIFix); -} - -ConstantIntRanges -mlir::intrange::inferFloorDivS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn floorDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - return inferDivSRange(lhs, rhs, floorDivSIFix); -} - -//===----------------------------------------------------------------------===// -// Signed remainder (RemS) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferRemS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - - unsigned width = rhsMax.getBitWidth(); - APInt smin = APInt::getSignedMinValue(width); - APInt smax = APInt::getSignedMaxValue(width); - // No bounds if zero could be a divisor. - bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); - if (canBound) { - APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); - bool canNegativeDividend = lhsMin.isNegative(); - bool canPositiveDividend = lhsMax.isStrictlyPositive(); - APInt zero = APInt::getZero(maxDivisor.getBitWidth()); - APInt maxPositiveResult = maxDivisor - 1; - APInt minNegativeResult = -maxPositiveResult; - smin = canNegativeDividend ? minNegativeResult : zero; - smax = canPositiveDividend ? maxPositiveResult : zero; - // Special case: sweeping out a contiguous range in N/[modulus]. - if (rhsMin == rhsMax) { - if ((lhsMax - lhsMin).ult(maxDivisor)) { - APInt minRem = lhsMin.srem(maxDivisor); - APInt maxRem = lhsMax.srem(maxDivisor); - if (minRem.sle(maxRem)) { - smin = minRem; - smax = maxRem; - } - } - } - } - return ConstantIntRanges::fromSigned(smin, smax); -} - -//===----------------------------------------------------------------------===// -// Unsigned remainder (RemU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferRemU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -//===----------------------------------------------------------------------===// -// Max and min (MaxS, MaxU, MinS, MinU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferMaxS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges -mlir::intrange::inferMaxU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -ConstantIntRanges -mlir::intrange::inferMinS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); - return ConstantIntRanges::fromSigned(smin, smax); -} - -ConstantIntRanges -mlir::intrange::inferMinU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - -//===----------------------------------------------------------------------===// -// Bitwise operators (And, Or, Xor) -//===----------------------------------------------------------------------===// - -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. This gives us a consertive approximation -/// for what values can be passed to bitwise operations. -static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - -ConstantIntRanges -mlir::intrange::inferAnd(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -ConstantIntRanges -mlir::intrange::inferOr(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -ConstantIntRanges -mlir::intrange::inferXor(ArrayRef argRanges) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false); -} - -//===----------------------------------------------------------------------===// -// Shifts (Shl, ShrS, ShrU) -//===----------------------------------------------------------------------===// - -ConstantIntRanges -mlir::intrange::inferShl(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn shl = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); - }; - ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); - return urange.intersection(srange); -} - -ConstantIntRanges -mlir::intrange::inferShrS(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.ashr(r); - }; - - return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); -} - -ConstantIntRanges -mlir::intrange::inferShrU(ArrayRef argRanges) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn lshr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); - }; - return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); -} - -//===----------------------------------------------------------------------===// -// Comparisons (Cmp) -//===----------------------------------------------------------------------===// - -static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { - switch (pred) { - case intrange::CmpPredicate::eq: - return intrange::CmpPredicate::ne; - case intrange::CmpPredicate::ne: - return intrange::CmpPredicate::eq; - case intrange::CmpPredicate::slt: - return intrange::CmpPredicate::sge; - case intrange::CmpPredicate::sle: - return intrange::CmpPredicate::sgt; - case intrange::CmpPredicate::sgt: - return intrange::CmpPredicate::sle; - case intrange::CmpPredicate::sge: - return intrange::CmpPredicate::slt; - case intrange::CmpPredicate::ult: - return intrange::CmpPredicate::uge; - case intrange::CmpPredicate::ule: - return intrange::CmpPredicate::ugt; - case intrange::CmpPredicate::ugt: - return intrange::CmpPredicate::ule; - case intrange::CmpPredicate::uge: - return intrange::CmpPredicate::ult; - } - llvm_unreachable("unknown cmp predicate value"); -} - -static bool isStaticallyTrue(intrange::CmpPredicate pred, - const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case intrange::CmpPredicate::sle: - return lhs.smax().sle(rhs.smin()); - case intrange::CmpPredicate::slt: - return lhs.smax().slt(rhs.smin()); - case intrange::CmpPredicate::ule: - return lhs.umax().ule(rhs.umin()); - case intrange::CmpPredicate::ult: - return lhs.umax().ult(rhs.umin()); - case intrange::CmpPredicate::sge: - return lhs.smin().sge(rhs.smax()); - case intrange::CmpPredicate::sgt: - return lhs.smin().sgt(rhs.smax()); - case intrange::CmpPredicate::uge: - return lhs.umin().uge(rhs.umax()); - case intrange::CmpPredicate::ugt: - return lhs.umin().ugt(rhs.umax()); - case intrange::CmpPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case intrange::CmpPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. - bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || - isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || - isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; -} - -std::optional mlir::intrange::evaluatePred(CmpPredicate pred, - const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - if (isStaticallyTrue(pred, lhs, rhs)) - return true; - if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) - return false; - return std::nullopt; -} diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Index/int-range-inference.mlir +++ /dev/null @@ -1,66 +0,0 @@ -// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s - -// Most operations are covered by the `arith` tests, which use the same code -// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling -// code is operating as expected. - -// CHECK-LABEL: func @add_same_for_both -// CHECK: %[[true:.*]] = index.bool.constant true -// CHECK: return %[[true]] -func.func @add_same_for_both(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %calmostBig = index.constant 0xfffffffe - %0 = index.minu %arg0, %calmostBig - %1 = index.add %0, %c1 - %2 = index.cmp uge(%1, %c1) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_unsigned_ov -// CHECK: %[[uge:.*]] = index.cmp uge -// CHECK: return %[[uge]] -func.func @add_unsigned_ov(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %cu32_max = index.constant 0xffffffff - %0 = index.minu %arg0, %cu32_max - %1 = index.add %0, %c1 - // On 32-bit, the add could wrap, so the result doesn't have to be >= 1 - %2 = index.cmp uge(%1, %c1) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_signed_ov -// CHECK: %[[sge:.*]] = index.cmp sge -// CHECK: return %[[sge]] -func.func @add_signed_ov(%arg0 : index) -> i1 { - %c0 = index.constant 0 - %c1 = index.constant 1 - %ci32_max = index.constant 0x7fffffff - %0 = index.minu %arg0, %ci32_max - %1 = index.add %0, %c1 - // On 32-bit, the add could wrap, so the result doesn't have to be positive - %2 = index.cmp sge(%1, %c0) - func.return %2 : i1 -} - -// CHECK-LABEL: func @add_big -// CHECK: %[[true:.*]] = index.bool.constant true -// CHECK: return %[[true]] -func.func @add_big(%arg0 : index) -> i1 { - %c1 = index.constant 1 - %cmin = index.constant 0x300000000 - %cmax = index.constant 0x30000ffff - // Note: the order of the clamps matters. - // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff] - // and then [0x30...0000, 0x30...ffff] - // If you switch the order of the below operations, you instead first infer - // the range [0,0x3...ffff]. Then, the min inference can't constraint - // this intermediate, since in the 32-bit case we could have, for example - // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff - // which means we can't do any inference. - %0 = index.maxu %arg0, %cmin - %1 = index.minu %0, %cmax - %2 = index.add %1, %c1 - %3 = index.cmp uge(%1, %cmin) - func.return %3 : i1 -}